diff --git a/README.md b/README.md index 1e2f22907..8804ec306 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,34 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot ## Documentation - Dependency Security: `docs/dependency-security.md` -- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md` + +--- + +## Codex CLI WebSocket v2 Example + +To enable OpenAI WebSocket Mode v2 in Codex CLI with Sub2API, add the following to `~/.codex/config.toml`: + +```toml +model_provider = "aicodx2api" +model = "gpt-5.3-codex" +review_model = "gpt-5.3-codex" +model_reasoning_effort = "xhigh" +disable_response_storage = true +network_access = "enabled" +windows_wsl_setup_acknowledged = true + +[model_providers.aicodx2api] +name = "aicodx2api" +base_url = "https://api.sub2api.ai" +wire_api = "responses" +supports_websockets = true +requires_openai_auth = true + +[features] +responses_websockets_v2 = true +``` + +After updating the config, restart Codex CLI. 
--- diff --git a/README_CN.md b/README_CN.md index 9da089b74..22a772b5a 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,6 +62,34 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 - 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 +## Codex CLI 开启 OpenAI WebSocket Mode v2 示例配置 + +如需在 Codex CLI 中通过 Sub2API 启用 OpenAI WebSocket Mode v2,可将以下配置写入 `~/.codex/config.toml`: + +```toml +model_provider = "aicodx2api" +model = "gpt-5.3-codex" +review_model = "gpt-5.3-codex" +model_reasoning_effort = "xhigh" +disable_response_storage = true +network_access = "enabled" +windows_wsl_setup_acknowledged = true + +[model_providers.aicodx2api] +name = "aicodx2api" +base_url = "https://api.sub2api.ai" +wire_api = "responses" +supports_websockets = true +requires_openai_auth = true + +[features] +responses_websockets_v2 = true +``` + +配置更新后,重启 Codex CLI 使其生效。 + +--- + ## 部署方式 ### 方式一:脚本安装(推荐) diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 68b76751f..3ec692a84 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -5,7 +5,6 @@ linters: enable: - depguard - errcheck - - gosec - govet - ineffassign - staticcheck @@ -43,22 +42,6 @@ linters: desc: "handler must not import gorm" - pkg: github.com/redis/go-redis/v9 desc: "handler must not import redis" - gosec: - excludes: - - G101 - - G103 - - G104 - - G109 - - G115 - - G201 - - G202 - - G301 - - G302 - - G304 - - G306 - - G404 - severity: high - confidence: high errcheck: # Report about not checking of errors in type assertions: `a := b.(MyStruct)`. # Such cases aren't reported by default. 
diff --git a/backend/.gosec.json b/backend/.gosec.json new file mode 100644 index 000000000..7a8ccb6a1 --- /dev/null +++ b/backend/.gosec.json @@ -0,0 +1,5 @@ +{ + "global": { + "exclude": "G704,G101,G103,G104,G109,G115,G201,G202,G301,G302,G304,G306,G404" + } +} diff --git a/backend/Dockerfile b/backend/Dockerfile index aeb20fdb6..6db2b1756 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -1,4 +1,4 @@ -FROM golang:1.25.7-alpine +FROM registry-1.docker.io/library/golang:1.25.7-alpine WORKDIR /app diff --git a/backend/THIRD_PARTY_NOTICES.md b/backend/THIRD_PARTY_NOTICES.md new file mode 100644 index 000000000..0524259ad --- /dev/null +++ b/backend/THIRD_PARTY_NOTICES.md @@ -0,0 +1,26 @@ +# Third-Party Notices + +## caddyserver/caddy + +- Project: https://github.com/caddyserver/caddy +- License: Apache License 2.0 +- Copyright: + Copyright 2015 Matthew Holt and The Caddy Authors + +### Usage in this repository + +OpenAI WS v2 passthrough relay adopts the Caddy reverse proxy streaming architecture +(bidirectional tunnel + convergence shutdown) and is adapted with minimal changes for +`coder/websocket` frame-based forwarding. + +- Referenced commit: + `f283062d37c50627d53ca682ebae2ce219b35515` +- Referenced upstream files: + - `modules/caddyhttp/reverseproxy/streaming.go` + - `modules/caddyhttp/reverseproxy/reverseproxy.go` +- Local adaptation files: + - `backend/internal/service/openai_ws_v2/caddy_adapter.go` + - `backend/internal/service/openai_ws_v2/passthrough_relay.go` + +The adaptation preserves Apache-2.0 license obligations. 
+ diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index bc0016939..2ff7358b1 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 32844913e..6ba552b88 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.88 \ No newline at end of file +0.1.87.4 diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index cbf89ba3b..5044f7ee0 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -210,9 +210,9 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, - {"OpenAIWSPool", func() error { + {"OpenAIWSCtxPool", func() error { if openAIGateway != nil { - openAIGateway.CloseOpenAIWSPool() + openAIGateway.CloseOpenAIWSCtxPool() } return nil }}, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index cbeb9a693..5bc0d35c3 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -48,8 +48,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { redisClient := repository.ProvideRedis(configConfig) refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) - groupRepository := repository.NewGroupRepository(client, db) - settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig) + settingService := service.NewSettingService(settingRepository, configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := 
service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -60,15 +59,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client, db) billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) + groupRepository := repository.NewGroupRepository(client, db) userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) - apiKeyService.SetRateLimitCacheInvalidator(billingCache) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) - authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache) + subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig) redeemCache := repository.NewRedeemCache(redisClient) redeemService := 
service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator) secretEncryptor, err := repository.NewAESEncryptor(configConfig) @@ -104,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -139,8 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig) crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig) sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) - rpmCache := repository.NewRPMCache(redisClient) - accountHandler := admin.NewAccountHandler(adminService, 
oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator) + accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) dataManagementService := service.NewDataManagementService() dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) @@ -162,7 +160,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) digestSessionStore := service.NewDigestSessionStore() - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, 
openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) @@ -197,9 +195,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) - userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient) - userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, nil, 
configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) @@ -396,9 +392,9 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, - {"OpenAIWSPool", func() error { + {"OpenAIWSCtxPool", func() error { if openAIGateway != nil { - openAIGateway.CloseOpenAIWSPool() + openAIGateway.CloseOpenAIWSCtxPool() } return nil }}, diff --git a/backend/go.mod b/backend/go.mod index ab76258a2..08c4e26f0 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -109,7 +109,6 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect @@ -178,7 +177,6 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect diff --git a/backend/go.sum b/backend/go.sum index 32e389a76..98914a83e 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -182,8 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof 
v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c1f54ab69..365cb0f7c 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -273,13 +273,8 @@ type CSPConfig struct { } type ProxyFallbackConfig struct { - // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 - // 仅影响以下非 AI 账号连接的辅助服务: - // - GitHub Release 更新检查 - // - 定价数据拉取 - // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), - // 这些关键路径的代理失败始终返回错误,不会回退直连。 - // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 + // AllowDirectOnError 当代理初始化失败时是否允许回退直连。 + // 默认 false:避免因代理配置错误导致 IP 泄露/关联。 AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` } @@ -379,6 +374,8 @@ type GatewayConfig struct { OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` + // OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1) + OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -511,12 +508,27 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string { return "" } +// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。 +// 默认启用 HTTP/2,多路复用提升并发效率;在部分代理不兼容时按策略回退 HTTP/1.1。 +type GatewayOpenAIHTTP2Config struct { + // Enabled: 是否启用 OpenAI HTTP/2 优先策略 + Enabled bool `mapstructure:"enabled"` + // AllowProxyFallbackToHTTP1: 代理不兼容 HTTP/2 
时是否允许回退 HTTP/1.1 + AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"` + // FallbackErrorThreshold: 在窗口期内触发回退所需的连续错误次数 + FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"` + // FallbackWindowSeconds: 连续错误计数窗口(秒) + FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"` + // FallbackTTLSeconds: 进入 HTTP/1.1 回退态后的持续时间(秒) + FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"` +} + // GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 // 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 type GatewayOpenAIWSConfig struct { - // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) + // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 true;关闭时保持 legacy 行为) ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` - // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + // IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough) IngressModeDefault string `mapstructure:"ingress_mode_default"` // Enabled: 全局总开关(默认 true) Enabled bool `mapstructure:"enabled"` @@ -554,12 +566,13 @@ type GatewayOpenAIWSConfig struct { // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor)) OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor)) - APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` - DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` - ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` - WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` - PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` - QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` + APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + ClientReadIdleTimeoutSeconds int `mapstructure:"client_read_idle_timeout_seconds"` + 
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` + QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` // EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数) EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发 @@ -579,6 +592,11 @@ type GatewayOpenAIWSConfig struct { // PayloadLogSampleRate: payload_schema 日志采样率(0-1) PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` + // UpstreamConnMaxAgeSeconds: 上游 WebSocket 连接最大存活时间(秒)。 + // OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。 + // 默认 3300(55 分钟);设为 0 则禁用超龄轮换。 + UpstreamConnMaxAgeSeconds int `mapstructure:"upstream_conn_max_age_seconds"` + // 账号调度与粘连参数 LBTopK int `mapstructure:"lb_top_k"` // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL @@ -595,6 +613,32 @@ type GatewayOpenAIWSConfig struct { StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` + + // SchedulerP2CEnabled: 启用 P2C(Power-of-Two-Choices)选择算法替代 Top-K 加权采样 + SchedulerP2CEnabled bool `mapstructure:"scheduler_p2c_enabled"` + + // Softmax 温度采样:替代线性平移的概率选择策略 + SchedulerSoftmaxEnabled bool `mapstructure:"scheduler_softmax_enabled"` + SchedulerSoftmaxTemperature float64 `mapstructure:"scheduler_softmax_temperature"` + + // 账号级熔断器 + SchedulerCircuitBreakerEnabled bool `mapstructure:"scheduler_circuit_breaker_enabled"` + SchedulerCircuitBreakerFailThreshold int `mapstructure:"scheduler_circuit_breaker_fail_threshold"` + SchedulerCircuitBreakerCooldownSec int `mapstructure:"scheduler_circuit_breaker_cooldown_sec"` + SchedulerCircuitBreakerHalfOpenMax int `mapstructure:"scheduler_circuit_breaker_half_open_max"` + + // 条件性 Sticky Session 释放:当粘连账号不健康时主动释放,回退到负载均衡 + StickyReleaseEnabled bool `mapstructure:"sticky_release_enabled"` + 
StickyReleaseErrorThreshold float64 `mapstructure:"sticky_release_error_threshold"` + + // Per-model TTFT tracking + SchedulerPerModelTTFTEnabled bool `mapstructure:"scheduler_per_model_ttft_enabled"` + SchedulerPerModelTTFTMaxModels int `mapstructure:"scheduler_per_model_ttft_max_models"` + + // SchedulerTrendEnabled: 启用负载趋势预测(线性回归外推),在打分时对 loadFactor 施加趋势修正 + SchedulerTrendEnabled bool `mapstructure:"scheduler_trend_enabled"` + // SchedulerTrendMaxSlope: 趋势斜率归一化上限(每秒负载百分比变化率);0 或负数使用默认值 5.0 + SchedulerTrendMaxSlope float64 `mapstructure:"scheduler_trend_max_slope"` } // GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。 @@ -1172,9 +1216,6 @@ func setDefaults() { viper.SetDefault("security.csp.policy", DefaultCSPPolicy) viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) - // Security - disable direct fallback on proxy error - viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) - // Billing viper.SetDefault("billing.circuit_breaker.enabled", true) viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) @@ -1332,8 +1373,8 @@ func setDefaults() { viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) viper.SetDefault("gateway.openai_ws.enabled", true) - viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) - viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", true) + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") viper.SetDefault("gateway.openai_ws.oauth_enabled", true) viper.SetDefault("gateway.openai_ws.apikey_enabled", true) viper.SetDefault("gateway.openai_ws.force_http", false) @@ -1364,7 +1405,8 @@ func setDefaults() { viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) - 
viper.SetDefault("gateway.openai_ws.lb_top_k", 7) + viper.SetDefault("gateway.openai_ws.upstream_conn_max_age_seconds", 3300) + viper.SetDefault("gateway.openai_ws.lb_top_k", 999) viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) @@ -1376,6 +1418,12 @@ func setDefaults() { viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) + // OpenAI HTTP upstream protocol strategy + viper.SetDefault("gateway.openai_http2.enabled", true) + viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true) + viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2) + viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60) + viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600) viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) @@ -1420,7 +1468,7 @@ func setDefaults() { viper.SetDefault("gateway.usage_record.worker_count", 128) viper.SetDefault("gateway.usage_record.queue_size", 16384) viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) - viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) + viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySync) viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) @@ -1493,6 +1541,9 @@ func setDefaults() { viper.SetDefault("gemini.oauth.scopes", "") viper.SetDefault("gemini.quota.policy", "") + // 
Security - proxy fallback + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) + // Subscription Maintenance (bounded queue + worker pool) viper.SetDefault("subscription_maintenance.worker_count", 2) viper.SetDefault("subscription_maintenance.queue_size", 1024) @@ -1971,6 +2022,15 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative") + } + if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative") + } + if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative") + } // 兼容旧键 sticky_previous_response_ttl_seconds if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds @@ -2039,11 +2099,16 @@ func (c *Config) Validate() error { if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 { return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative") } + if c.Gateway.OpenAIWS.ResponsesWebsockets && !c.Gateway.OpenAIWS.ResponsesWebsocketsV2 { + return fmt.Errorf("gateway.openai_ws.responses_websockets (v1) is not supported; enable gateway.openai_ws.responses_websockets_v2") + } if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { switch mode { - case "off", "shared", "dedicated": + case "off", "ctx_pool", "passthrough": + case "shared", "dedicated": + slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", 
"value", mode) default: - return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough") } } if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { @@ -2056,6 +2121,9 @@ func (c *Config) Validate() error { if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 { return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]") } + if c.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.upstream_conn_max_age_seconds must be non-negative") + } if c.Gateway.OpenAIWS.LBTopK <= 0 { return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive") } @@ -2083,6 +2151,22 @@ func (c *Config) Validate() error { if weightSum <= 0 { return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero") } + // Validate new scheduler/sticky-release config ranges. 
+ if c.Gateway.OpenAIWS.SchedulerSoftmaxTemperature < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_softmax_temperature must be non-negative") + } + if c.Gateway.OpenAIWS.StickyReleaseErrorThreshold < 0 || c.Gateway.OpenAIWS.StickyReleaseErrorThreshold > 1 { + return fmt.Errorf("gateway.openai_ws.sticky_release_error_threshold must be within [0,1]") + } + if c.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_fail_threshold must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_cooldown_sec must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_half_open_max must be non-negative") + } if c.Gateway.MaxLineSize < 0 { return fmt.Errorf("gateway.max_line_size must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index e3b592e2c..480a17c16 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -150,11 +150,11 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) { if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" { t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict") } - if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { - t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") + if !cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { + t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = false, want true") } - if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { - t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", 
cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool") } } @@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) { wantErr: "gateway.openai_ws.store_disabled_conn_mode", }, { - name: "ingress_mode_default 必须为 off|shared|dedicated", + name: "ingress_mode_default 必须为 off|ctx_pool|passthrough", mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, wantErr: "gateway.openai_ws.ingress_mode_default", }, @@ -1387,6 +1387,14 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) { mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 }, wantErr: "gateway.openai_ws.retry_total_budget_ms", }, + { + name: "responses_websockets v1-only 配置不允许", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.ResponsesWebsockets = true + c.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + }, + wantErr: "gateway.openai_ws.responses_websockets", + }, { name: "lb_top_k 必须为正数", mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 }, @@ -1637,8 +1645,8 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) { if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 { t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds) } - if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample { - t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample) + if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySync { + t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySync) } if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 { t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent) diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go index 285033a17..c8b04c2ae 100644 --- a/backend/internal/handler/admin/account_data_handler_test.go +++ 
b/backend/internal/handler/admin/account_data_handler_test.go @@ -64,7 +64,6 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) { nil, nil, nil, - nil, ) router.GET("/api/v1/admin/accounts/data", h.ExportData) diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 98ead2841..e4a69032f 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -53,7 +53,6 @@ type AccountHandler struct { concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService sessionLimitCache service.SessionLimitCache - rpmCache service.RPMCache tokenCacheInvalidator service.TokenCacheInvalidator } @@ -70,7 +69,6 @@ func NewAccountHandler( concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, sessionLimitCache service.SessionLimitCache, - rpmCache service.RPMCache, tokenCacheInvalidator service.TokenCacheInvalidator, ) *AccountHandler { return &AccountHandler{ @@ -85,7 +83,6 @@ func NewAccountHandler( concurrencyService: concurrencyService, crsSyncService: crsSyncService, sessionLimitCache: sessionLimitCache, - rpmCache: rpmCache, tokenCacheInvalidator: tokenCacheInvalidator, } } @@ -137,6 +134,7 @@ type BulkUpdateAccountsRequest struct { RateMultiplier *float64 `json:"rate_multiplier"` Status string `json:"status" binding:"omitempty,oneof=active inactive error"` Schedulable *bool `json:"schedulable"` + AutoPauseOnExpired *bool `json:"auto_pause_on_expired"` GroupIDs *[]int64 `json:"group_ids"` Credentials map[string]any `json:"credentials"` Extra map[string]any `json:"extra"` @@ -157,7 +155,6 @@ type AccountWithConcurrency struct { // 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回 CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用 ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数 - CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数 } func 
(h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency { @@ -193,12 +190,6 @@ func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, ac } } } - - if h.rpmCache != nil && account.GetBaseRPM() > 0 { - if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil { - item.CurrentRPM = &rpm - } - } } return item @@ -241,10 +232,9 @@ func (h *AccountHandler) List(c *gin.Context) { concurrencyCounts = make(map[int64]int) } - // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能) + // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能) windowCostAccountIDs := make([]int64, 0) sessionLimitAccountIDs := make([]int64, 0) - rpmAccountIDs := make([]int64, 0) sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置 for i := range accounts { acc := &accounts[i] @@ -256,24 +246,12 @@ func (h *AccountHandler) List(c *gin.Context) { sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID) sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute } - if acc.GetBaseRPM() > 0 { - rpmAccountIDs = append(rpmAccountIDs, acc.ID) - } } } - // 并行获取窗口费用、活跃会话数和 RPM 计数 + // 并行获取窗口费用和活跃会话数 var windowCosts map[int64]float64 var activeSessions map[int64]int - var rpmCounts map[int64]int - - // 获取 RPM 计数(批量查询) - if len(rpmAccountIDs) > 0 && h.rpmCache != nil { - rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs) - if rpmCounts == nil { - rpmCounts = make(map[int64]int) - } - } // 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置) if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil { @@ -334,13 +312,6 @@ func (h *AccountHandler) List(c *gin.Context) { } } - // 添加 RPM 计数(仅当启用时) - if rpmCounts != nil { - if rpm, ok := rpmCounts[acc.ID]; ok { - item.CurrentRPM = &rpm - } - } - result[i] = item } @@ -483,8 +454,6 @@ func (h *AccountHandler) Create(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } 
- // base_rpm 输入校验:负值归零,超过 10000 截断 - sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -554,8 +523,6 @@ func (h *AccountHandler) Update(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } - // base_rpm 输入校验:负值归零,超过 10000 截断 - sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -938,9 +905,6 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) { continue } - // base_rpm 输入校验:负值归零,超过 10000 截断 - sanitizeExtraBaseRPM(item.Extra) - skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{ @@ -1085,8 +1049,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { response.BadRequest(c, "rate_multiplier must be >= 0") return } - // base_rpm 输入校验:负值归零,超过 10000 截断 - sanitizeExtraBaseRPM(req.Extra) // 确定是否跳过混合渠道检查 skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk @@ -1098,6 +1060,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { req.RateMultiplier != nil || req.Status != "" || req.Schedulable != nil || + req.AutoPauseOnExpired != nil || req.GroupIDs != nil || len(req.Credentials) > 0 || len(req.Extra) > 0 @@ -1116,20 +1079,13 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) { RateMultiplier: req.RateMultiplier, Status: req.Status, Schedulable: req.Schedulable, + AutoPauseOnExpired: req.AutoPauseOnExpired, GroupIDs: req.GroupIDs, Credentials: req.Credentials, Extra: req.Extra, SkipMixedChannelCheck: skipCheck, }) if err != nil { - var mixedErr *service.MixedChannelError - if errors.As(err, &mixedErr) { - c.JSON(409, gin.H{ - "error": "mixed_channel_warning", - "message": mixedErr.Error(), - }) - return - } response.ErrorFrom(c, err) return } @@ -1753,22 +1709,3 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) { func (h 
*AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) { response.Success(c, domain.DefaultAntigravityModelMapping) } - -// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。 -// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。 -func sanitizeExtraBaseRPM(extra map[string]any) { - if extra == nil { - return - } - raw, ok := extra["base_rpm"] - if !ok { - return - } - v := service.ParseExtraInt(raw) - if v < 0 { - v = 0 - } else if v > 10000 { - v = 10000 - } - extra["base_rpm"] = v -} diff --git a/backend/internal/handler/admin/account_handler_bulk_update_test.go b/backend/internal/handler/admin/account_handler_bulk_update_test.go new file mode 100644 index 000000000..c2dfdf746 --- /dev/null +++ b/backend/internal/handler/admin/account_handler_bulk_update_test.go @@ -0,0 +1,62 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAccountBulkUpdateRouter(adminSvc *stubAdminService) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate) + return router +} + +func TestAccountHandlerBulkUpdate_ForwardsAutoPauseOnExpired(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountBulkUpdateRouter(adminSvc) + + body, err := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + "auto_pause_on_expired": true, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, adminSvc.lastBulkUpdateInput) + require.NotNil(t, 
adminSvc.lastBulkUpdateInput.AutoPauseOnExpired) + require.True(t, *adminSvc.lastBulkUpdateInput.AutoPauseOnExpired) +} + +func TestAccountHandlerBulkUpdate_RejectsEmptyUpdates(t *testing.T) { + adminSvc := newStubAdminService() + router := setupAccountBulkUpdateRouter(adminSvc) + + body, err := json.Marshal(map[string]any{ + "account_ids": []int64{1}, + }) + require.NoError(t, err) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Contains(t, resp["message"], "No updates provided") +} diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go index 24ec5bcfe..ad004844d 100644 --- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -15,11 +15,10 @@ import ( func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine { gin.SetMode(gin.TestMode) router := gin.New() - accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel) router.POST("/api/v1/admin/accounts", accountHandler.Create) router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update) - router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate) return router } @@ -146,53 +145,3 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T require.False(t, hasDetails) require.False(t, 
hasRequireConfirmation) } - -func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) { - adminSvc := newStubAdminService() - adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{ - GroupID: 27, - GroupName: "claude-max", - CurrentPlatform: "Antigravity", - OtherPlatform: "Anthropic", - } - router := setupAccountMixedChannelRouter(adminSvc) - - body, _ := json.Marshal(map[string]any{ - "account_ids": []int64{1, 2, 3}, - "group_ids": []int64{27}, - }) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(rec, req) - - require.Equal(t, http.StatusConflict, rec.Code) - var resp map[string]any - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - require.Equal(t, "mixed_channel_warning", resp["error"]) - require.Contains(t, resp["message"], "claude-max") -} - -func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) { - adminSvc := newStubAdminService() - router := setupAccountMixedChannelRouter(adminSvc) - - body, _ := json.Marshal(map[string]any{ - "account_ids": []int64{1, 2}, - "group_ids": []int64{27}, - "confirm_mixed_channel_risk": true, - }) - rec := httptest.NewRecorder() - req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body)) - req.Header.Set("Content-Type", "application/json") - router.ServeHTTP(rec, req) - - require.Equal(t, http.StatusOK, rec.Code) - var resp map[string]any - require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) - require.Equal(t, float64(0), resp["code"]) - data, ok := resp["data"].(map[string]any) - require.True(t, ok) - require.Equal(t, float64(2), data["success"]) - require.Equal(t, float64(0), data["failed"]) -} diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go index 
d86501c04..d09cccd6d 100644 --- a/backend/internal/handler/admin/account_handler_passthrough_test.go +++ b/backend/internal/handler/admin/account_handler_passthrough_test.go @@ -28,7 +28,6 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi nil, nil, nil, - nil, ) router := gin.New() diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index f3b99ddbe..a84988617 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -31,7 +31,8 @@ type stubAdminService struct { platform string groupIDs []int64 } - mu sync.Mutex + lastBulkUpdateInput *service.BulkUpdateAccountsInput + mu sync.Mutex } func newStubAdminService() *stubAdminService { @@ -239,6 +240,9 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic if s.bulkUpdateAccountErr != nil { return nil, s.bulkUpdateAccountErr } + s.mu.Lock() + s.lastBulkUpdateInput = input + s.mu.Unlock() return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil } diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go index 8dd245a43..6b7b25158 100644 --- a/backend/internal/handler/admin/apikey_handler.go +++ b/backend/internal/handler/admin/apikey_handler.go @@ -35,6 +35,10 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) { response.BadRequest(c, "Invalid API key ID") return } + if keyID <= 0 { + response.BadRequest(c, "Invalid API key ID") + return + } var req AdminUpdateAPIKeyGroupRequest if err := c.ShouldBindJSON(&req); err != nil { diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go index bf128b18a..be8ba03b5 100644 --- a/backend/internal/handler/admin/apikey_handler_test.go +++ 
b/backend/internal/handler/admin/apikey_handler_test.go @@ -36,6 +36,19 @@ func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) { require.Contains(t, rec.Body.String(), "Invalid API key ID") } +func TestAdminAPIKeyHandler_UpdateGroup_InvalidNegativeID(t *testing.T) { + router := setupAPIKeyHandler(newStubAdminService()) + body := `{"group_id": 2}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/-1", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid API key ID") +} + func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) { router := setupAPIKeyHandler(newStubAdminService()) diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go index 0b1b66917..c8185735c 100644 --- a/backend/internal/handler/admin/batch_update_credentials_test.go +++ b/backend/internal/handler/admin/batch_update_credentials_test.go @@ -36,7 +36,7 @@ func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) { gin.SetMode(gin.TestMode) router := gin.New() - handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials) return router, handler } diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 1d48c653c..ea2572fa4 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -333,71 +333,15 @@ func (h 
*DashboardHandler) GetModelStats(c *gin.Context) { }) } -// GetGroupStats handles getting group usage statistics +// GetGroupStats handles getting group usage statistics. // GET /api/v1/admin/dashboard/groups -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type +// +// NOTE: Group-level aggregation pipeline is not available in this branch baseline. +// Keep this endpoint for API compatibility and return an empty dataset. func (h *DashboardHandler) GetGroupStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) - - var userID, apiKeyID, accountID, groupID int64 - var requestType *int16 - var stream *bool - var billingType *int8 - - if userIDStr := c.Query("user_id"); userIDStr != "" { - if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { - userID = id - } - } - if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" { - if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil { - apiKeyID = id - } - } - if accountIDStr := c.Query("account_id"); accountIDStr != "" { - if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil { - accountID = id - } - } - if groupIDStr := c.Query("group_id"); groupIDStr != "" { - if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil { - groupID = id - } - } - if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { - parsed, err := service.ParseUsageRequestType(requestTypeStr) - if err != nil { - response.BadRequest(c, err.Error()) - return - } - value := int16(parsed) - requestType = &value - } else if streamStr := c.Query("stream"); streamStr != "" { - if streamVal, err := strconv.ParseBool(streamStr); err == nil { - stream = &streamVal - } else { - response.BadRequest(c, "Invalid stream value, use true or false") - return - } - } - if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { - if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { - bt := 
int8(v) - billingType = &bt - } else { - response.BadRequest(c, "Invalid billing_type") - return - } - } - - stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) - if err != nil { - response.Error(c, 500, "Failed to get group statistics") - return - } - response.Success(c, gin.H{ - "groups": stats, + "groups": []any{}, "start_date": startTime.Format("2006-01-02"), "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"), }) diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go index 02fc766f9..2b3ceae72 100644 --- a/backend/internal/handler/admin/data_management_handler.go +++ b/backend/internal/handler/admin/data_management_handler.go @@ -1,7 +1,6 @@ package admin import ( - "context" "strconv" "strings" @@ -14,34 +13,13 @@ import ( ) type DataManagementHandler struct { - dataManagementService dataManagementService + dataManagementService *service.DataManagementService } func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler { return &DataManagementHandler{dataManagementService: dataManagementService} } -type dataManagementService interface { - GetConfig(ctx context.Context) (service.DataManagementConfig, error) - UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error) - ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error) - CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error) - ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error) - CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error) - 
UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error) - DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error - SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error) - ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error) - CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error) - UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error) - DeleteS3Profile(ctx context.Context, profileID string) error - SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error) - ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error) - GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error) - EnsureAgentEnabled(ctx context.Context) error - GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth -} - type TestS3ConnectionRequest struct { Endpoint string `json:"endpoint"` Region string `json:"region" binding:"required"` @@ -123,12 +101,8 @@ func (h *DataManagementHandler) GetConfig(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - cfg, err := h.dataManagementService.GetConfig(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, cfg) + _, err := h.dataManagementService.GetConfig(c.Request.Context()) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) UpdateConfig(c *gin.Context) { @@ -141,12 +115,8 @@ func (h *DataManagementHandler) UpdateConfig(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req) - if err != nil { - 
response.ErrorFrom(c, err) - return - } - response.Success(c, cfg) + _, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) TestS3(c *gin.Context) { @@ -159,7 +129,7 @@ func (h *DataManagementHandler) TestS3(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{ + _, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{ Enabled: true, Endpoint: req.Endpoint, Region: req.Region, @@ -170,11 +140,7 @@ func (h *DataManagementHandler) TestS3(c *gin.Context) { ForcePathStyle: req.ForcePathStyle, UseSSL: req.UseSSL, }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"ok": result.OK, "message": result.Message}) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { @@ -193,7 +159,7 @@ func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10) } - job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{ + _, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{ BackupType: req.BackupType, UploadToS3: req.UploadToS3, S3ProfileID: req.S3ProfileID, @@ -202,11 +168,7 @@ func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { TriggeredBy: triggeredBy, IdempotencyKey: req.IdempotencyKey, }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status}) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) { @@ -223,12 +185,8 @@ func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) { if 
!h.requireAgentEnabled(c) { return } - items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"items": items}) + _, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) { @@ -247,18 +205,14 @@ func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{ + _, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{ SourceType: sourceType, ProfileID: req.ProfileID, Name: req.Name, Config: req.Config, SetActive: req.SetActive, }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, profile) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) { @@ -282,17 +236,13 @@ func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{ + _, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{ SourceType: sourceType, ProfileID: profileID, Name: req.Name, Config: req.Config, }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, profile) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) { @@ -310,11 +260,8 @@ func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), 
sourceType, profileID); err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"deleted": true}) + err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) { @@ -332,12 +279,8 @@ func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, profile) + _, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) { @@ -345,12 +288,8 @@ func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) { return } - items, err := h.dataManagementService.ListS3Profiles(c.Request.Context()) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"items": items}) + _, err := h.dataManagementService.ListS3Profiles(c.Request.Context()) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { @@ -364,7 +303,7 @@ func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { return } - profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{ + _, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{ ProfileID: req.ProfileID, Name: req.Name, SetActive: req.SetActive, @@ -380,11 +319,7 @@ func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { UseSSL: req.UseSSL, }, }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, profile) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) UpdateS3Profile(c 
*gin.Context) { @@ -404,7 +339,7 @@ func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) { return } - profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{ + _, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{ ProfileID: profileID, Name: req.Name, S3: service.DataManagementS3Config{ @@ -419,11 +354,7 @@ func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) { UseSSL: req.UseSSL, }, }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, profile) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) { @@ -436,11 +367,8 @@ func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, gin.H{"deleted": true}) + err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) { @@ -453,12 +381,8 @@ func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, profile) + _, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) { @@ -476,17 +400,13 @@ func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) { pageSize = int32(v) } - result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{ + _, err := 
h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{ PageSize: pageSize, PageToken: c.Query("page_token"), Status: c.Query("status"), BackupType: c.Query("backup_type"), }) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, result) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) GetBackupJob(c *gin.Context) { @@ -499,12 +419,8 @@ func (h *DataManagementHandler) GetBackupJob(c *gin.Context) { if !h.requireAgentEnabled(c) { return } - job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID) - if err != nil { - response.ErrorFrom(c, err) - return - } - response.Success(c, job) + _, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID) + response.ErrorFrom(c, err) } func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool { @@ -517,12 +433,9 @@ func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool { return false } - if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil { - response.ErrorFrom(c, err) - return false - } - - return true + err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()) + response.ErrorFrom(c, err) + return false } func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth { diff --git a/backend/internal/handler/admin/ops_openai_ws_v2_handler.go b/backend/internal/handler/admin/ops_openai_ws_v2_handler.go new file mode 100644 index 000000000..a77df4b0a --- /dev/null +++ b/backend/internal/handler/admin/ops_openai_ws_v2_handler.go @@ -0,0 +1,28 @@ +package admin + +import ( + "net/http" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" + "github.com/gin-gonic/gin" +) + +// GetOpenAIWSV2PassthroughMetrics returns OpenAI WS v2 passthrough runtime metrics. 
+// GET /api/v1/admin/ops/openai-ws-v2/passthrough-metrics +func (h *OpsHandler) GetOpenAIWSV2PassthroughMetrics(c *gin.Context) { + if h.opsService == nil { + response.Error(c, http.StatusServiceUnavailable, "Ops service not available") + return + } + if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "passthrough": openaiwsv2.SnapshotMetrics(), + "timestamp": time.Now().UTC(), + }) +} diff --git a/backend/internal/handler/admin/ops_openai_ws_v2_handler_test.go b/backend/internal/handler/admin/ops_openai_ws_v2_handler_test.go new file mode 100644 index 000000000..89b235466 --- /dev/null +++ b/backend/internal/handler/admin/ops_openai_ws_v2_handler_test.go @@ -0,0 +1,64 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func newOpsOpenAIWSV2TestRouter(handler *OpsHandler) *gin.Engine { + gin.SetMode(gin.TestMode) + r := gin.New() + r.GET("/metrics", handler.GetOpenAIWSV2PassthroughMetrics) + return r +} + +func TestOpsOpenAIWSV2Handler_GetPassthroughMetrics_ServiceUnavailable(t *testing.T) { + r := newOpsOpenAIWSV2TestRouter(NewOpsHandler(nil)) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusServiceUnavailable { + t.Fatalf("status=%d, want %d", w.Code, http.StatusServiceUnavailable) + } +} + +func TestOpsOpenAIWSV2Handler_GetPassthroughMetrics_Success(t *testing.T) { + r := newOpsOpenAIWSV2TestRouter(NewOpsHandler(newRuntimeOpsService(t))) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) + r.ServeHTTP(w, req) + if w.Code != http.StatusOK { + t.Fatalf("status=%d, want %d, body=%s", w.Code, http.StatusOK, w.Body.String()) + } + + var payload map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &payload); err != nil { + t.Fatalf("unmarshal 
body: %v", err) + } + if code, _ := payload["code"].(float64); int(code) != 0 { + t.Fatalf("code=%v, want 0", payload["code"]) + } + data, ok := payload["data"].(map[string]any) + if !ok { + t.Fatalf("missing data field: %v", payload) + } + passthrough, ok := data["passthrough"].(map[string]any) + if !ok { + t.Fatalf("missing passthrough field: %v", data) + } + if _, ok := passthrough["semantic_mutation_total"].(float64); !ok { + t.Fatalf("missing semantic_mutation_total: %v", passthrough) + } + if _, ok := passthrough["usage_parse_failure_total"].(float64); !ok { + t.Fatalf("missing usage_parse_failure_total: %v", passthrough) + } + if _, ok := data["timestamp"].(string); !ok { + t.Fatalf("missing timestamp: %v", data) + } +} diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go index e8ae0ce2d..9fd187fc5 100644 --- a/backend/internal/handler/admin/proxy_handler.go +++ b/backend/internal/handler/admin/proxy_handler.go @@ -64,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) { return } - out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies)) + out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) } response.Paginated(c, out, total, page, pageSize) } @@ -83,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { response.ErrorFrom(c, err) return } - out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies)) + out := make([]dto.ProxyWithAccountCount, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i])) + out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i])) } response.Success(c, out) return @@ -97,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) { return } - out := make([]dto.AdminProxy, 0, len(proxies)) + out := 
make([]dto.Proxy, 0, len(proxies)) for i := range proxies { - out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i])) + out = append(out, *dto.ProxyFromService(&proxies[i])) } response.Success(c, out) } @@ -119,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.ProxyFromServiceAdmin(proxy)) + response.Success(c, dto.ProxyFromService(proxy)) } // Create handles creating a new proxy @@ -143,7 +143,7 @@ func (h *ProxyHandler) Create(c *gin.Context) { if err != nil { return nil, err } - return dto.ProxyFromServiceAdmin(proxy), nil + return dto.ProxyFromService(proxy), nil }) } @@ -176,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) { return } - response.Success(c, dto.ProxyFromServiceAdmin(proxy)) + response.Success(c, dto.ProxyFromService(proxy)) } // Delete handles deleting a proxy diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 0a932ee98..12706ee62 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ b/backend/internal/handler/admin/redeem_handler.go @@ -43,7 +43,7 @@ type GenerateRedeemCodesRequest struct { // CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user. 
type CreateAndRedeemCodeRequest struct { Code string `json:"code" binding:"required,min=3,max=128"` - Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"` + Type string `json:"type" binding:"required,oneof=balance concurrency"` Value float64 `json:"value" binding:"required,gt=0"` UserID int64 `json:"user_id" binding:"required,gt=0"` Notes string `json:"notes"` @@ -136,6 +136,10 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) { return } req.Code = strings.TrimSpace(req.Code) + if len(req.Code) < 3 || len(req.Code) > 128 { + response.BadRequest(c, "Invalid request: code length must be between 3 and 128") + return + } executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { existing, err := h.redeemService.GetByCode(ctx, req.Code) diff --git a/backend/internal/handler/admin/redeem_handler_create_and_redeem_test.go b/backend/internal/handler/admin/redeem_handler_create_and_redeem_test.go new file mode 100644 index 000000000..1c06a366d --- /dev/null +++ b/backend/internal/handler/admin/redeem_handler_create_and_redeem_test.go @@ -0,0 +1,46 @@ +package admin + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupCreateAndRedeemRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + h := NewRedeemHandler(newStubAdminService(), &service.RedeemService{}) + router.POST("/api/v1/admin/redeem-codes/create-and-redeem", h.CreateAndRedeem) + return router +} + +func TestCreateAndRedeem_RejectsUnsupportedType(t *testing.T) { + router := setupCreateAndRedeemRouter() + body := `{"code":"ORDER-123","type":"subscription","value":100,"user_id":1}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", 
bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "Invalid request") +} + +func TestCreateAndRedeem_RejectsTrimmedEmptyCode(t *testing.T) { + router := setupCreateAndRedeemRouter() + body := `{"code":" ","type":"balance","value":100,"user_id":1}` + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewBufferString(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "code length must be between 3 and 128") +} diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index e32c142f4..c7b93497b 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,13 +1,8 @@ package admin import ( - "crypto/rand" - "encoding/hex" - "encoding/json" "fmt" "log" - "net/http" - "regexp" "strings" "time" @@ -20,21 +15,6 @@ import ( "github.com/gin-gonic/gin" ) -// semverPattern 预编译 semver 格式校验正则 -var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`) - -// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only. -var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`) - -// generateMenuItemID generates a short random hex ID for a custom menu item. 
-func generateMenuItemID() (string, error) { - b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - return "", fmt.Errorf("generate menu item ID: %w", err) - } - return hex.EncodeToString(b), nil -} - // SettingHandler 系统设置处理器 type SettingHandler struct { settingService *service.SettingService @@ -66,13 +46,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { // Check if ops monitoring is enabled (respects config.ops.enabled) opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context()) - defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions)) - for _, sub := range settings.DefaultSubscriptions { - defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{ - GroupID: sub.GroupID, - ValidityDays: sub.ValidityDays, - }) - } response.Success(c, dto.SystemSettings{ RegistrationEnabled: settings.RegistrationEnabled, @@ -107,10 +80,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, - CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, - DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, FallbackModelAnthropic: settings.FallbackModelAnthropic, FallbackModelOpenAI: settings.FallbackModelOpenAI, @@ -122,7 +93,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: settings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds, - MinClaudeCodeVersion: settings.MinClaudeCodeVersion, }) } @@ -157,23 +127,21 @@ type UpdateSettingsRequest struct { LinuxDoConnectRedirectURL string 
`json:"linuxdo_connect_redirect_url"` // OEM设置 - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` - HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` - SoraClientEnabled bool `json:"sora_client_enabled"` - CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` // 默认配置 - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -191,8 +159,6 @@ type UpdateSettingsRequest struct { OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"` OpsQueryModeDefault *string `json:"ops_query_mode_default"` OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"` - - MinClaudeCodeVersion string `json:"min_claude_code_version"` } // UpdateSettings 更新系统设置 @@ -220,7 +186,6 @@ func (h *SettingHandler) UpdateSettings(c 
*gin.Context) { if req.SMTPPort <= 0 { req.SMTPPort = 587 } - req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions) // Turnstile 参数验证 if req.TurnstileEnabled { @@ -316,84 +281,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } } - // 自定义菜单项验证 - const ( - maxCustomMenuItems = 20 - maxMenuItemLabelLen = 50 - maxMenuItemURLLen = 2048 - maxMenuItemIconSVGLen = 10 * 1024 // 10KB - maxMenuItemIDLen = 32 - ) - - customMenuJSON := previousSettings.CustomMenuItems - if req.CustomMenuItems != nil { - items := *req.CustomMenuItems - if len(items) > maxCustomMenuItems { - response.BadRequest(c, "Too many custom menu items (max 20)") - return - } - for i, item := range items { - if strings.TrimSpace(item.Label) == "" { - response.BadRequest(c, "Custom menu item label is required") - return - } - if len(item.Label) > maxMenuItemLabelLen { - response.BadRequest(c, "Custom menu item label is too long (max 50 characters)") - return - } - if strings.TrimSpace(item.URL) == "" { - response.BadRequest(c, "Custom menu item URL is required") - return - } - if len(item.URL) > maxMenuItemURLLen { - response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)") - return - } - if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil { - response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL") - return - } - if item.Visibility != "user" && item.Visibility != "admin" { - response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'") - return - } - if len(item.IconSVG) > maxMenuItemIconSVGLen { - response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)") - return - } - // Auto-generate ID if missing - if strings.TrimSpace(item.ID) == "" { - id, err := generateMenuItemID() - if err != nil { - response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID") - return - } - items[i].ID = id - } else if len(item.ID) > maxMenuItemIDLen { - 
response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)") - return - } else if !menuItemIDPattern.MatchString(item.ID) { - response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)") - return - } - } - // ID uniqueness check - seen := make(map[string]struct{}, len(items)) - for _, item := range items { - if _, exists := seen[item.ID]; exists { - response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID) - return - } - seen[item.ID] = struct{}{} - } - menuBytes, err := json.Marshal(items) - if err != nil { - response.BadRequest(c, "Failed to serialize custom menu items") - return - } - customMenuJSON = string(menuBytes) - } - // Ops metrics collector interval validation (seconds). if req.OpsMetricsIntervalSeconds != nil { v := *req.OpsMetricsIntervalSeconds @@ -405,21 +292,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { } req.OpsMetricsIntervalSeconds = &v } - defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions)) - for _, sub := range req.DefaultSubscriptions { - defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{ - GroupID: sub.GroupID, - ValidityDays: sub.ValidityDays, - }) - } - - // 验证最低版本号格式(空字符串=禁用,或合法 semver) - if req.MinClaudeCodeVersion != "" { - if !semverPattern.MatchString(req.MinClaudeCodeVersion) { - response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 
2.1.63)") - return - } - } settings := &service.SystemSettings{ RegistrationEnabled: req.RegistrationEnabled, @@ -453,10 +325,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, SoraClientEnabled: req.SoraClientEnabled, - CustomMenuItems: customMenuJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, - DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, FallbackModelAnthropic: req.FallbackModelAnthropic, FallbackModelOpenAI: req.FallbackModelOpenAI, @@ -464,7 +334,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { FallbackModelAntigravity: req.FallbackModelAntigravity, EnableIdentityPatch: req.EnableIdentityPatch, IdentityPatchPrompt: req.IdentityPatchPrompt, - MinClaudeCodeVersion: req.MinClaudeCodeVersion, OpsMonitoringEnabled: func() bool { if req.OpsMonitoringEnabled != nil { return *req.OpsMonitoringEnabled @@ -504,13 +373,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { response.ErrorFrom(c, err) return } - updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions)) - for _, sub := range updatedSettings.DefaultSubscriptions { - updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{ - GroupID: sub.GroupID, - ValidityDays: sub.ValidityDays, - }) - } response.Success(c, dto.SystemSettings{ RegistrationEnabled: updatedSettings.RegistrationEnabled, @@ -545,10 +407,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, SoraClientEnabled: updatedSettings.SoraClientEnabled, - CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: 
updatedSettings.DefaultBalance, - DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI, @@ -560,7 +420,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled, OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault, OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds, - MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion, }) } @@ -670,9 +529,6 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.DefaultBalance != after.DefaultBalance { changed = append(changed, "default_balance") } - if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) { - changed = append(changed, "default_subscriptions") - } if before.EnableModelFallback != after.EnableModelFallback { changed = append(changed, "enable_model_fallback") } @@ -706,50 +562,15 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds { changed = append(changed, "ops_metrics_interval_seconds") } - if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion { - changed = append(changed, "min_claude_code_version") - } if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled { changed = append(changed, "purchase_subscription_enabled") } if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL { changed = append(changed, "purchase_subscription_url") } - if before.CustomMenuItems != after.CustomMenuItems { - changed = append(changed, "custom_menu_items") - } return changed } -func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting { - if len(input) == 0 { - return nil - } - normalized
:= make([]dto.DefaultSubscriptionSetting, 0, len(input)) - for _, item := range input { - if item.GroupID <= 0 || item.ValidityDays <= 0 { - continue - } - if item.ValidityDays > service.MaxValidityDays { - item.ValidityDays = service.MaxValidityDays - } - normalized = append(normalized, item) - } - return normalized -} - -func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays { - return false - } - } - return true -} - // TestSMTPRequest 测试SMTP连接请求 type TestSMTPRequest struct { SMTPHost string `json:"smtp_host" binding:"required"` diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go new file mode 100644 index 000000000..c63ed9af2 --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go @@ -0,0 +1,228 @@ +package admin + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +type UpsertBulkEditTemplateRequest struct { + ID string `json:"id"` + Name string `json:"name"` + ScopePlatform string `json:"scope_platform"` + ScopeType string `json:"scope_type"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State map[string]any `json:"state"` +} + +type RollbackBulkEditTemplateRequest struct { + VersionID string `json:"version_id"` +} + +// ListBulkEditTemplates 获取批量编辑模板列表 +// GET /api/v1/admin/settings/bulk-edit-templates +func (h *SettingHandler) ListBulkEditTemplates(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + scopeGroupIDs, err := 
parseScopeGroupIDs(c.Query("scope_group_ids")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + items, listErr := h.settingService.ListBulkEditTemplates(c.Request.Context(), service.BulkEditTemplateQuery{ + ScopePlatform: c.Query("scope_platform"), + ScopeType: c.Query("scope_type"), + ScopeGroupIDs: scopeGroupIDs, + RequesterUserID: subject.UserID, + }) + if listErr != nil { + response.ErrorFrom(c, listErr) + return + } + + response.Success(c, gin.H{"items": items}) +} + +// UpsertBulkEditTemplate 创建/更新批量编辑模板 +// POST /api/v1/admin/settings/bulk-edit-templates +func (h *SettingHandler) UpsertBulkEditTemplate(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + var req UpsertBulkEditTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + item, upsertErr := h.settingService.UpsertBulkEditTemplate( + c.Request.Context(), + service.BulkEditTemplateUpsertInput{ + ID: req.ID, + Name: req.Name, + ScopePlatform: req.ScopePlatform, + ScopeType: req.ScopeType, + ShareScope: req.ShareScope, + GroupIDs: req.GroupIDs, + State: req.State, + RequesterUserID: subject.UserID, + }, + ) + if upsertErr != nil { + response.ErrorFrom(c, upsertErr) + return + } + + response.Success(c, item) +} + +// DeleteBulkEditTemplate 删除批量编辑模板 +// DELETE /api/v1/admin/settings/bulk-edit-templates/:template_id +func (h *SettingHandler) DeleteBulkEditTemplate(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + templateID := strings.TrimSpace(c.Param("template_id")) + if templateID == "" { + response.BadRequest(c, "template_id is required") + return + } + + if err := h.settingService.DeleteBulkEditTemplate(c.Request.Context(), templateID, subject.UserID); err != nil { 
+ response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"deleted": true}) +} + +// ListBulkEditTemplateVersions 获取模板版本历史 +// GET /api/v1/admin/settings/bulk-edit-templates/:template_id/versions +func (h *SettingHandler) ListBulkEditTemplateVersions(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + templateID := strings.TrimSpace(c.Param("template_id")) + if templateID == "" { + response.BadRequest(c, "template_id is required") + return + } + + scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + items, listErr := h.settingService.ListBulkEditTemplateVersions( + c.Request.Context(), + service.BulkEditTemplateVersionQuery{ + TemplateID: templateID, + ScopeGroupIDs: scopeGroupIDs, + RequesterUserID: subject.UserID, + }, + ) + if listErr != nil { + response.ErrorFrom(c, listErr) + return + } + + response.Success(c, gin.H{"items": items}) +} + +// RollbackBulkEditTemplate 回滚模板到指定版本 +// POST /api/v1/admin/settings/bulk-edit-templates/:template_id/rollback +func (h *SettingHandler) RollbackBulkEditTemplate(c *gin.Context) { + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + templateID := strings.TrimSpace(c.Param("template_id")) + if templateID == "" { + response.BadRequest(c, "template_id is required") + return + } + + scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids")) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + + var req RollbackBulkEditTemplateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + item, rollbackErr := h.settingService.RollbackBulkEditTemplate( + c.Request.Context(), + service.BulkEditTemplateRollbackInput{ + 
TemplateID: templateID, + VersionID: req.VersionID, + ScopeGroupIDs: scopeGroupIDs, + RequesterUserID: subject.UserID, + }, + ) + if rollbackErr != nil { + response.ErrorFrom(c, rollbackErr) + return + } + + response.Success(c, item) +} + +func parseScopeGroupIDs(raw string) ([]int64, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return nil, nil + } + + parts := strings.Split(trimmed, ",") + if len(parts) == 0 { + return nil, nil + } + + seen := make(map[int64]struct{}, len(parts)) + groupIDs := make([]int64, 0, len(parts)) + for _, part := range parts { + candidate := strings.TrimSpace(part) + if candidate == "" { + continue + } + + groupID, err := strconv.ParseInt(candidate, 10, 64) + if err != nil || groupID <= 0 { + return nil, fmt.Errorf("scope_group_ids must be comma-separated positive integers") + } + if _, exists := seen[groupID]; exists { + continue + } + seen[groupID] = struct{}{} + groupIDs = append(groupIDs, groupID) + } + + return groupIDs, nil +} diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go new file mode 100644 index 000000000..c72d412bf --- /dev/null +++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go @@ -0,0 +1,603 @@ +package admin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strconv" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type settingHandlerTemplateRepoStub struct { + values map[string]string +} + +func newSettingHandlerTemplateRepoStub() *settingHandlerTemplateRepoStub { + return &settingHandlerTemplateRepoStub{values: map[string]string{}} +} + +func (s *settingHandlerTemplateRepoStub) Get(ctx context.Context, key 
string) (*service.Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &service.Setting{Key: key, Value: value}, nil +} + +func (s *settingHandlerTemplateRepoStub) GetValue(ctx context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", service.ErrSettingNotFound + } + return value, nil +} + +func (s *settingHandlerTemplateRepoStub) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} + +func (s *settingHandlerTemplateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *settingHandlerTemplateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + for key, value := range settings { + s.values[key] = value + } + return nil +} + +func (s *settingHandlerTemplateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *settingHandlerTemplateRepoStub) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +type failingSettingRepoStub struct{} + +func (s *failingSettingRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) { + return nil, errors.New("boom") +} +func (s *failingSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", errors.New("boom") +} +func (s *failingSettingRepoStub) Set(ctx context.Context, key, value string) error { + return errors.New("boom") +} +func (s *failingSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *failingSettingRepoStub) SetMultiple(ctx context.Context, settings 
map[string]string) error { + return errors.New("boom") +} +func (s *failingSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *failingSettingRepoStub) Delete(ctx context.Context, key string) error { + return errors.New("boom") +} + +func setupBulkEditTemplateRouter() *gin.Engine { + gin.SetMode(gin.TestMode) + repo := newSettingHandlerTemplateRepoStub() + settingService := service.NewSettingService(repo, nil) + handler := NewSettingHandler(settingService, nil, nil, nil, nil) + + router := gin.New() + router.Use(func(c *gin.Context) { + uid := int64(1) + if header := c.GetHeader("X-User-ID"); header != "" { + if parsed, err := strconv.ParseInt(header, 10, 64); err == nil && parsed > 0 { + uid = parsed + } + } + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: uid}) + c.Next() + }) + + router.GET("/api/v1/admin/settings/bulk-edit-templates", handler.ListBulkEditTemplates) + router.POST("/api/v1/admin/settings/bulk-edit-templates", handler.UpsertBulkEditTemplate) + router.DELETE("/api/v1/admin/settings/bulk-edit-templates/:template_id", handler.DeleteBulkEditTemplate) + router.GET("/api/v1/admin/settings/bulk-edit-templates/:template_id/versions", handler.ListBulkEditTemplateVersions) + router.POST("/api/v1/admin/settings/bulk-edit-templates/:template_id/rollback", handler.RollbackBulkEditTemplate) + + return router +} + +func decodeResponseDataMap(t *testing.T, body []byte) map[string]any { + t.Helper() + var payload response.Response + require.NoError(t, json.Unmarshal(body, &payload)) + if payload.Data == nil { + return map[string]any{} + } + asMap, ok := payload.Data.(map[string]any) + require.True(t, ok) + return asMap +} + +func TestSettingHandlerBulkEditTemplate_CRUDFlow(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "OpenAI OAuth Baseline", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": 
"team", + "state": map[string]any{ + "enableOpenAIPassthrough": true, + }, + } + raw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + createReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + + createData := decodeResponseDataMap(t, createRec.Body.Bytes()) + templateID, ok := createData["id"].(string) + require.True(t, ok) + require.NotEmpty(t, templateID) + + listRec := httptest.NewRecorder() + listReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth", + nil, + ) + router.ServeHTTP(listRec, listReq) + require.Equal(t, http.StatusOK, listRec.Code) + + listData := decodeResponseDataMap(t, listRec.Body.Bytes()) + items, ok := listData["items"].([]any) + require.True(t, ok) + require.Len(t, items, 1) + + deleteRec := httptest.NewRecorder() + deleteReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID, + nil, + ) + router.ServeHTTP(deleteRec, deleteReq) + require.Equal(t, http.StatusOK, deleteRec.Code) + + listAfterDeleteRec := httptest.NewRecorder() + listAfterDeleteReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth", + nil, + ) + router.ServeHTTP(listAfterDeleteRec, listAfterDeleteReq) + require.Equal(t, http.StatusOK, listAfterDeleteRec.Code) + + listAfterDeleteData := decodeResponseDataMap(t, listAfterDeleteRec.Body.Bytes()) + itemsAfterDelete, ok := listAfterDeleteData["items"].([]any) + require.True(t, ok) + require.Len(t, itemsAfterDelete, 0) +} + +func TestSettingHandlerBulkEditTemplate_VersionsAndRollback(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := 
map[string]any{ + "name": "Rollback Target", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "groups", + "group_ids": []int64{2}, + "state": map[string]any{ + "enableOpenAIWSMode": true, + }, + } + createRaw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates", + bytes.NewReader(createRaw), + ) + createReq.Header.Set("Content-Type", "application/json") + createReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + createData := decodeResponseDataMap(t, createRec.Body.Bytes()) + templateID, ok := createData["id"].(string) + require.True(t, ok) + + updateBody := map[string]any{ + "id": templateID, + "name": "Rollback Target", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "team", + "group_ids": []int64{}, + "state": map[string]any{ + "enableOpenAIWSMode": false, + }, + } + updateRaw, err := json.Marshal(updateBody) + require.NoError(t, err) + + updateRec := httptest.NewRecorder() + updateReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates", + bytes.NewReader(updateRaw), + ) + updateReq.Header.Set("Content-Type", "application/json") + updateReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(updateRec, updateReq) + require.Equal(t, http.StatusOK, updateRec.Code) + + versionsRec := httptest.NewRecorder() + versionsReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2", + nil, + ) + versionsReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(versionsRec, versionsReq) + require.Equal(t, http.StatusOK, versionsRec.Code) + versionsData := decodeResponseDataMap(t, versionsRec.Body.Bytes()) + versions, ok := versionsData["items"].([]any) + require.True(t, ok) + require.Len(t, versions, 1) + 
versionData, ok := versions[0].(map[string]any) + require.True(t, ok) + versionID, ok := versionData["version_id"].(string) + require.True(t, ok) + + rollbackBody := map[string]any{"version_id": versionID} + rollbackRaw, err := json.Marshal(rollbackBody) + require.NoError(t, err) + + rollbackRec := httptest.NewRecorder() + rollbackReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/rollback?scope_group_ids=2", + bytes.NewReader(rollbackRaw), + ) + rollbackReq.Header.Set("Content-Type", "application/json") + rollbackReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(rollbackRec, rollbackReq) + require.Equal(t, http.StatusOK, rollbackRec.Code) + rollbackData := decodeResponseDataMap(t, rollbackRec.Body.Bytes()) + require.Equal(t, "groups", rollbackData["share_scope"]) + groupIDs, ok := rollbackData["group_ids"].([]any) + require.True(t, ok) + require.Equal(t, []any{float64(2)}, groupIDs) + state, ok := rollbackData["state"].(map[string]any) + require.True(t, ok) + require.Equal(t, true, state["enableOpenAIWSMode"]) + + versionsAfterRollbackRec := httptest.NewRecorder() + versionsAfterRollbackReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2", + nil, + ) + versionsAfterRollbackReq.Header.Set("X-User-ID", "9") + router.ServeHTTP(versionsAfterRollbackRec, versionsAfterRollbackReq) + require.Equal(t, http.StatusOK, versionsAfterRollbackRec.Code) + versionsAfterRollbackData := decodeResponseDataMap(t, versionsAfterRollbackRec.Body.Bytes()) + versionsAfterRollback, ok := versionsAfterRollbackData["items"].([]any) + require.True(t, ok) + require.Len(t, versionsAfterRollback, 2) +} + +func TestSettingHandlerBulkEditTemplate_Validation(t *testing.T) { + router := setupBulkEditTemplateRouter() + + invalidCreateBody := map[string]any{ + "name": "Groups Template", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": 
"groups", + "group_ids": []int64{}, + "state": map[string]any{}, + } + raw, err := json.Marshal(invalidCreateBody) + require.NoError(t, err) + + invalidCreateRec := httptest.NewRecorder() + invalidCreateReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + invalidCreateReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(invalidCreateRec, invalidCreateReq) + require.Equal(t, http.StatusBadRequest, invalidCreateRec.Code) + + invalidListRec := httptest.NewRecorder() + invalidListReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_group_ids=abc", + nil, + ) + router.ServeHTTP(invalidListRec, invalidListReq) + require.Equal(t, http.StatusBadRequest, invalidListRec.Code) +} + +func TestSettingHandlerBulkEditTemplate_PrivateVisibilityAndDeletePermission(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "Private Template", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "private", + "state": map[string]any{ + "enableBaseUrl": true, + }, + } + raw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + createReq.Header.Set("Content-Type", "application/json") + createReq.Header.Set("X-User-ID", "100") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + + createData := decodeResponseDataMap(t, createRec.Body.Bytes()) + templateID, ok := createData["id"].(string) + require.True(t, ok) + + listByOtherRec := httptest.NewRecorder() + listByOtherReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth", + nil, + ) + listByOtherReq.Header.Set("X-User-ID", "200") + router.ServeHTTP(listByOtherRec, 
listByOtherReq) + require.Equal(t, http.StatusOK, listByOtherRec.Code) + + listByOtherData := decodeResponseDataMap(t, listByOtherRec.Body.Bytes()) + items, ok := listByOtherData["items"].([]any) + require.True(t, ok) + require.Len(t, items, 0) + + deleteByOtherRec := httptest.NewRecorder() + deleteByOtherReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID, + nil, + ) + deleteByOtherReq.Header.Set("X-User-ID", "200") + router.ServeHTTP(deleteByOtherRec, deleteByOtherReq) + require.Equal(t, http.StatusForbidden, deleteByOtherRec.Code) + + deleteByOwnerRec := httptest.NewRecorder() + deleteByOwnerReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/"+templateID, + nil, + ) + deleteByOwnerReq.Header.Set("X-User-ID", "100") + router.ServeHTTP(deleteByOwnerRec, deleteByOwnerReq) + require.Equal(t, http.StatusOK, deleteByOwnerRec.Code) +} + +func TestSettingHandlerBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) { + router := setupBulkEditTemplateRouter() + + createBody := map[string]any{ + "name": "Group Shared", + "scope_platform": "openai", + "scope_type": "oauth", + "share_scope": "groups", + "group_ids": []int64{3, 8}, + "state": map[string]any{"enableOpenAIWSMode": true}, + } + raw, err := json.Marshal(createBody) + require.NoError(t, err) + + createRec := httptest.NewRecorder() + createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw)) + createReq.Header.Set("Content-Type", "application/json") + createReq.Header.Set("X-User-ID", "1") + router.ServeHTTP(createRec, createReq) + require.Equal(t, http.StatusOK, createRec.Code) + + invisibleRec := httptest.NewRecorder() + invisibleReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=9", + nil, + ) + invisibleReq.Header.Set("X-User-ID", "2") + 
router.ServeHTTP(invisibleRec, invisibleReq) + require.Equal(t, http.StatusOK, invisibleRec.Code) + invisibleData := decodeResponseDataMap(t, invisibleRec.Body.Bytes()) + invisibleItems, ok := invisibleData["items"].([]any) + require.True(t, ok) + require.Len(t, invisibleItems, 0) + + visibleRec := httptest.NewRecorder() + visibleReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=8", + nil, + ) + visibleReq.Header.Set("X-User-ID", "2") + router.ServeHTTP(visibleRec, visibleReq) + require.Equal(t, http.StatusOK, visibleRec.Code) + visibleData := decodeResponseDataMap(t, visibleRec.Body.Bytes()) + visibleItems, ok := visibleData["items"].([]any) + require.True(t, ok) + require.Len(t, visibleItems, 1) +} + +func TestSettingHandlerBulkEditTemplate_UnauthorizedAndInvalidRequests(t *testing.T) { + gin.SetMode(gin.TestMode) + repo := newSettingHandlerTemplateRepoStub() + settingService := service.NewSettingService(repo, nil) + handler := NewSettingHandler(settingService, nil, nil, nil, nil) + + router := gin.New() + router.GET("/list", handler.ListBulkEditTemplates) + router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions) + router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate) + router.POST("/upsert", handler.UpsertBulkEditTemplate) + router.DELETE("/delete/:template_id", handler.DeleteBulkEditTemplate) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/list", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/upsert", bytes.NewBufferString("{bad-json")) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/delete/%20", nil) + 
router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/versions/abc", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/rollback/abc", bytes.NewBufferString(`{"version_id":"v1"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestParseScopeGroupIDs(t *testing.T) { + ids, err := parseScopeGroupIDs("") + require.NoError(t, err) + require.Nil(t, ids) + + ids, err = parseScopeGroupIDs("1, 2,2,3") + require.NoError(t, err) + require.Equal(t, []int64{1, 2, 3}, ids) + + _, err = parseScopeGroupIDs("x,2") + require.Error(t, err) +} + +func TestSettingHandlerBulkEditTemplate_BindErrorAndMissingTemplateID(t *testing.T) { + router := setupBulkEditTemplateRouter() + + bindErrRec := httptest.NewRecorder() + bindErrReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates", + bytes.NewBufferString("{bad-json"), + ) + bindErrReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(bindErrRec, bindErrReq) + require.Equal(t, http.StatusBadRequest, bindErrRec.Code) + + missingIDRec := httptest.NewRecorder() + missingIDReq := httptest.NewRequest( + http.MethodDelete, + "/api/v1/admin/settings/bulk-edit-templates/%20", + nil, + ) + router.ServeHTTP(missingIDRec, missingIDReq) + require.Equal(t, http.StatusBadRequest, missingIDRec.Code) + + invalidScopeRec := httptest.NewRecorder() + invalidScopeReq := httptest.NewRequest( + http.MethodGet, + "/api/v1/admin/settings/bulk-edit-templates/abc/versions?scope_group_ids=bad", + nil, + ) + router.ServeHTTP(invalidScopeRec, invalidScopeReq) + require.Equal(t, http.StatusBadRequest, invalidScopeRec.Code) + + rollbackMissingIDRec := httptest.NewRecorder() + rollbackMissingIDReq 
:= httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/%20/rollback", + bytes.NewBufferString(`{"version_id":"v1"}`), + ) + rollbackMissingIDReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rollbackMissingIDRec, rollbackMissingIDReq) + require.Equal(t, http.StatusBadRequest, rollbackMissingIDRec.Code) + + rollbackInvalidScopeRec := httptest.NewRecorder() + rollbackInvalidScopeReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/abc/rollback?scope_group_ids=bad", + bytes.NewBufferString(`{"version_id":"v1"}`), + ) + rollbackInvalidScopeReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rollbackInvalidScopeRec, rollbackInvalidScopeReq) + require.Equal(t, http.StatusBadRequest, rollbackInvalidScopeRec.Code) + + rollbackBindErrRec := httptest.NewRecorder() + rollbackBindErrReq := httptest.NewRequest( + http.MethodPost, + "/api/v1/admin/settings/bulk-edit-templates/abc/rollback", + bytes.NewBufferString("{bad-json"), + ) + rollbackBindErrReq.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rollbackBindErrRec, rollbackBindErrReq) + require.Equal(t, http.StatusBadRequest, rollbackBindErrRec.Code) +} + +func TestSettingHandlerBulkEditTemplate_ListErrorFromService(t *testing.T) { + gin.SetMode(gin.TestMode) + settingService := service.NewSettingService(&failingSettingRepoStub{}, nil) + handler := NewSettingHandler(settingService, nil, nil, nil, nil) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 1}) + c.Next() + }) + router.GET("/list", handler.ListBulkEditTemplates) + router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions) + router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/list", nil) + router.ServeHTTP(rec, req) + require.Equal(t, 
http.StatusInternalServerError, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/versions/tpl-1", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusInternalServerError, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/rollback/tpl-1", bytes.NewBufferString(`{"version_id":"v1"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index fe2a1d773..49c74522a 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -72,31 +72,22 @@ func APIKeyFromService(k *service.APIKey) *APIKey { return nil } return &APIKey{ - ID: k.ID, - UserID: k.UserID, - Key: k.Key, - Name: k.Name, - GroupID: k.GroupID, - Status: k.Status, - IPWhitelist: k.IPWhitelist, - IPBlacklist: k.IPBlacklist, - LastUsedAt: k.LastUsedAt, - Quota: k.Quota, - QuotaUsed: k.QuotaUsed, - ExpiresAt: k.ExpiresAt, - CreatedAt: k.CreatedAt, - UpdatedAt: k.UpdatedAt, - RateLimit5h: k.RateLimit5h, - RateLimit1d: k.RateLimit1d, - RateLimit7d: k.RateLimit7d, - Usage5h: k.Usage5h, - Usage1d: k.Usage1d, - Usage7d: k.Usage7d, - Window5hStart: k.Window5hStart, - Window1dStart: k.Window1dStart, - Window7dStart: k.Window7dStart, - User: UserFromServiceShallow(k.User), - Group: GroupFromServiceShallow(k.Group), + ID: k.ID, + UserID: k.UserID, + Key: k.Key, + Name: k.Name, + GroupID: k.GroupID, + Status: k.Status, + IPWhitelist: k.IPWhitelist, + IPBlacklist: k.IPBlacklist, + LastUsedAt: k.LastUsedAt, + Quota: k.Quota, + QuotaUsed: k.QuotaUsed, + ExpiresAt: k.ExpiresAt, + CreatedAt: k.CreatedAt, + UpdatedAt: k.UpdatedAt, + User: UserFromServiceShallow(k.User), + Group: GroupFromServiceShallow(k.Group), } } @@ -218,17 +209,6 @@ func AccountFromServiceShallow(a *service.Account) *Account { if idleTimeout 
:= a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 { out.SessionIdleTimeoutMin = &idleTimeout } - if rpm := a.GetBaseRPM(); rpm > 0 { - out.BaseRPM = &rpm - strategy := a.GetRPMStrategy() - out.RPMStrategy = &strategy - buffer := a.GetRPMStickyBuffer() - out.RPMStickyBuffer = &buffer - } - // 用户消息队列模式 - if mode := a.GetUserMsgQueueMode(); mode != "" { - out.UserMsgQueueMode = &mode - } // TLS指纹伪装开关 if a.IsTLSFingerprintEnabled() { enabled := true @@ -306,6 +286,7 @@ func ProxyFromService(p *service.Proxy) *Proxy { Host: p.Host, Port: p.Port, Username: p.Username, + Password: p.Password, Status: p.Status, CreatedAt: p.CreatedAt, UpdatedAt: p.UpdatedAt, @@ -335,51 +316,6 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi } } -// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users. -// It includes the password field - user-facing endpoints must not use this. -func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy { - if p == nil { - return nil - } - base := ProxyFromService(p) - if base == nil { - return nil - } - return &AdminProxy{ - Proxy: *base, - Password: p.Password, - } -} - -// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO. -// It includes the password field - user-facing endpoints must not use this. 
-func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount { - if p == nil { - return nil - } - admin := ProxyFromServiceAdmin(&p.Proxy) - if admin == nil { - return nil - } - return &AdminProxyWithAccountCount{ - AdminProxy: *admin, - AccountCount: p.AccountCount, - LatencyMs: p.LatencyMs, - LatencyStatus: p.LatencyStatus, - LatencyMessage: p.LatencyMessage, - IPAddress: p.IPAddress, - Country: p.Country, - CountryCode: p.CountryCode, - Region: p.Region, - City: p.City, - QualityStatus: p.QualityStatus, - QualityScore: p.QualityScore, - QualityGrade: p.QualityGrade, - QualitySummary: p.QualitySummary, - QualityChecked: p.QualityChecked, - } -} - func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary { if a == nil { return nil diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index beb03e679..41676b831 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -55,9 +55,8 @@ type SystemSettings struct { SoraClientEnabled bool `json:"sora_client_enabled"` CustomMenuItems []CustomMenuItem `json:"custom_menu_items"` - DefaultConcurrency int `json:"default_concurrency"` - DefaultBalance float64 `json:"default_balance"` - DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` + DefaultConcurrency int `json:"default_concurrency"` + DefaultBalance float64 `json:"default_balance"` // Model fallback configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -75,13 +74,6 @@ type SystemSettings struct { OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"` OpsQueryModeDefault string `json:"ops_query_mode_default"` OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"` - - MinClaudeCodeVersion string `json:"min_claude_code_version"` -} - -type DefaultSubscriptionSetting struct { - GroupID int64 `json:"group_id"` - ValidityDays int 
`json:"validity_days"` } type PublicSettings struct { diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 920615f70..732433975 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -47,17 +47,6 @@ type APIKey struct { CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` - // Rate limit fields - RateLimit5h float64 `json:"rate_limit_5h"` - RateLimit1d float64 `json:"rate_limit_1d"` - RateLimit7d float64 `json:"rate_limit_7d"` - Usage5h float64 `json:"usage_5h"` - Usage1d float64 `json:"usage_1d"` - Usage7d float64 `json:"usage_7d"` - Window5hStart *time.Time `json:"window_5h_start"` - Window1dStart *time.Time `json:"window_1d_start"` - Window7dStart *time.Time `json:"window_7d_start"` - User *User `json:"user,omitempty"` Group *Group `json:"group,omitempty"` } @@ -164,13 +153,6 @@ type Account struct { MaxSessions *int `json:"max_sessions,omitempty"` SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"` - // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) - // 从 extra 字段提取,方便前端显示和编辑 - BaseRPM *int `json:"base_rpm,omitempty"` - RPMStrategy *string `json:"rpm_strategy,omitempty"` - RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"` - UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"` - // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) // 从 extra 字段提取,方便前端显示和编辑 EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"` @@ -233,32 +215,6 @@ type ProxyWithAccountCount struct { QualityChecked *int64 `json:"quality_checked,omitempty"` } -// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。 -// 注意:普通接口不得使用此 DTO。 -type AdminProxy struct { - Proxy - Password string `json:"password,omitempty"` -} - -// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。 -type AdminProxyWithAccountCount struct { - AdminProxy - AccountCount int64 `json:"account_count"` - LatencyMs *int64 `json:"latency_ms,omitempty"` - LatencyStatus string 
`json:"latency_status,omitempty"` - LatencyMessage string `json:"latency_message,omitempty"` - IPAddress string `json:"ip_address,omitempty"` - Country string `json:"country,omitempty"` - CountryCode string `json:"country_code,omitempty"` - Region string `json:"region,omitempty"` - City string `json:"city,omitempty"` - QualityStatus string `json:"quality_status,omitempty"` - QualityScore *int `json:"quality_score,omitempty"` - QualityGrade string `json:"quality_grade,omitempty"` - QualitySummary string `json:"quality_summary,omitempty"` - QualityChecked *int64 `json:"quality_checked,omitempty"` -} - type ProxyAccountSummary struct { ID int64 `json:"id"` Name string `json:"name"` diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index c47e66df3..191d1d0c9 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -49,7 +49,6 @@ type GatewayHandler struct { maxAccountSwitches int maxAccountSwitchesGemini int cfg *config.Config - settingService *service.SettingService } // NewGatewayHandler creates a new GatewayHandler @@ -66,7 +65,6 @@ func NewGatewayHandler( errorPassthroughService *service.ErrorPassthroughService, userMsgQueueService *service.UserMessageQueueService, cfg *config.Config, - settingService *service.SettingService, ) *GatewayHandler { pingInterval := time.Duration(0) maxAccountSwitches := 10 @@ -102,7 +100,6 @@ func NewGatewayHandler( maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, cfg: cfg, - settingService: settingService, } } @@ -168,11 +165,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { SetClaudeCodeClientContext(c, body, parsedReq) isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) - // 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求 - if !h.checkClaudeCodeVersion(c) { - return - } - // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 c.Request = 
c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) @@ -421,15 +413,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // RPM 计数递增(Forward 成功后) - // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 - // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 - if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { - if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { - reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } - // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) @@ -682,7 +665,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, status, code, message, streamStarted) return } - // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度 + // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度 ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "") c.Request = c.Request.WithContext(ctx) currentAPIKey = fallbackAPIKey @@ -716,15 +699,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) { return } - // RPM 计数递增(Forward 成功后) - // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。 - // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。 - if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 { - if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil { - reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } - // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) @@ -1081,41 +1055,6 @@ func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarte return true } -// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求 -// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外 -func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool { - 
ctx := c.Request.Context() - if !service.IsClaudeCodeClient(ctx) { - return true - } - - // 排除 count_tokens 子路径 - if strings.HasSuffix(c.Request.URL.Path, "/count_tokens") { - return true - } - - minVersion := h.settingService.GetMinClaudeCodeVersion(ctx) - if minVersion == "" { - return true // 未设置,不检查 - } - - clientVersion := service.GetClaudeCodeVersion(ctx) - if clientVersion == "" { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", - "Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code") - return false - } - - if service.CompareVersions(clientVersion, minVersion) < 0 { - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", - fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code", - clientVersion, minVersion)) - return false - } - - return true -} - // errorResponse 返回Claude API格式的错误响应 func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -1493,25 +1432,12 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger } func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - defer func() { - if recovered := recover(); recovered != nil { - logger.L().With( - zap.String("component", "handler.gateway.messages"), - zap.Any("panic", recovered), - ).Error("gateway.usage_record_task_panic_recovered") - } - }() - task(ctx) + submitUsageRecordTaskWithFallback( + "handler.gateway.messages", + h.usageRecordWorkerPool, + h.cfg, + task, + ) } // getUserMsgQueueMode 获取当前请求的 UMQ 模式 diff --git 
a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index c07c568d3..74f0861a9 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -153,7 +153,6 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi nil, // deferredService nil, // claudeTokenProvider nil, // sessionLimitCache - nil, // rpmCache nil, // digestStore ) diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 09e6c09ba..ea8a5f1a9 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -29,10 +29,8 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service. if parsedReq != nil { c.Set(claudeCodeParsedRequestContextKey, parsedReq) } - - ua := c.GetHeader("User-Agent") // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 - if !claudeCodeValidator.ValidateUserAgent(ua) { + if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) { ctx := service.SetClaudeCodeClient(c.Request.Context(), false) c.Request = c.Request.WithContext(ctx) return @@ -56,14 +54,6 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service. 
// 更新 request context ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode) - - // 仅在确认为 Claude Code 客户端时提取版本号写入 context - if isClaudeCode { - if version := claudeCodeValidator.ExtractVersion(ua); version != "" { - ctx = service.SetClaudeCodeVersion(ctx, version) - } - } - c.Request = c.Request.WithContext(ctx) } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 4bbd17bae..770d8ca5b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -9,6 +9,7 @@ import ( "runtime/debug" "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -32,7 +33,31 @@ type OpenAIGatewayHandler struct { usageRecordWorkerPool *service.UsageRecordWorkerPool errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper + cfg *config.Config maxAccountSwitches int + + // Test hooks for websocket ingress path. Production keeps them nil. 
+ wsSelectAccountWithSchedulerFn func( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport service.OpenAIUpstreamTransport, + ) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) + wsBindStickySessionFn func(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error + wsGetAccessTokenFn func(ctx context.Context, account *service.Account) (string, string, error) + wsProxyResponsesWSFn func( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *service.Account, + token string, + firstClientMessageType coderws.MessageType, + firstClientMessage []byte, + hooks *service.OpenAIWSIngressHooks, + ) error } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -60,6 +85,7 @@ func NewOpenAIGatewayHandler( usageRecordWorkerPool: usageRecordWorkerPool, errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + cfg: cfg, maxAccountSwitches: maxAccountSwitches, } } @@ -113,30 +139,29 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - setOpsRequestContext(c, "", false, body) - - // 校验请求体 JSON 合法性 + // 校验请求体 JSON 合法性,避免畸形 JSON 被 gjson 部分解析后继续下游处理。 if !gjson.ValidBytes(body) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } - // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal - modelResult := gjson.GetBytes(body, "model") + // 使用 GetManyBytes 一次扫描提取所有字段,避免多次遍历大请求体 + results := gjson.GetManyBytes(body, "model", "stream", "previous_response_id") + modelResult := results[0] if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required") return } reqModel := modelResult.String() - streamResult := 
gjson.GetBytes(body, "stream") + streamResult := results[1] if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type") return } reqStream := streamResult.Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) - previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()) + previousResponseID := strings.TrimSpace(results[2].String()) if previousResponseID != "" { previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) reqLog = reqLog.With( @@ -155,6 +180,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) + // 缓存已提取的 meta 到 context,供 Service 层复用,避免重复解析请求体。 + // prompt_cache_key 也在此处提前提取,确保 meta 设置后只读不写,避免并发竞态。 + promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + c.Set(service.OpenAIRequestMetaKey, &service.OpenAIRequestMeta{ + Model: reqModel, + Stream: reqStream, + PromptCacheKey: promptCacheKey, + }) + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { return @@ -269,7 +303,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr @@ -286,7 +320,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) continue } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) 
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) fields := []zap.Field{ zap.Int64("account_id", account.ID), @@ -301,26 +335,32 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } if result != nil { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + var ttftMsVal float64 + if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 { + ttftMsVal = float64(*result.FirstTokenMs) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, ttftMsVal) } else { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil, reqModel, 0) } // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") clientIP := ip.GetClientIP(c) + requestID := strings.TrimSpace(c.GetHeader("X-Request-ID")) // 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。 h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + FallbackRequestID: requestID, + APIKeyService: h.apiKeyService, }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -550,9 +590,10 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { reqLog.Info("openai.websocket_ingress_started") clientIP := ip.GetClientIP(c) userAgent := strings.TrimSpace(c.GetHeader("User-Agent")) + requestID := strings.TrimSpace(c.GetHeader("X-Request-ID")) wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{ - CompressionMode: 
coderws.CompressionContextTakeover, + CompressionMode: coderws.CompressionNoContextTakeover, }) if err != nil { reqLog.Warn("openai.websocket_accept_failed", @@ -603,16 +644,21 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { } previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String()) previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) - if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { - closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id") - return - } + service.SetOpenAIWSFirstMessageMeta(c, reqModel, previousResponseID, previousResponseIDKind) reqLog = reqLog.With( zap.Bool("ws_ingress", true), zap.String("model", reqModel), zap.Bool("has_previous_response_id", previousResponseID != ""), zap.String("previous_response_id_kind", previousResponseIDKind), ) + if h.shouldRejectWSMessageIDPreviousResponseIDEarly(previousResponseIDKind) { + reqLog.Warn("openai.websocket_request_validation_failed", + zap.String("reason", "previous_response_id_looks_like_message_id"), + zap.String("openai_ws_ingress_mode_default", h.cfg.Gateway.OpenAIWS.IngressModeDefault), + ) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id") + return + } setOpsRequestContext(c, reqModel, true, firstMessage) var currentUserRelease func() @@ -654,7 +700,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { firstMessage, openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), ) - selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + selection, scheduleDecision, err := h.wsSelectAccountWithScheduler( ctx, apiKey.GroupID, previousResponseID, @@ -674,6 +720,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c 
*gin.Context) { } account := selection.Account + wsIngressMode, wsModeRouterV2Enabled := h.resolveOpenAIWSIngressMode(account) + reqLog = reqLog.With( + zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled), + zap.String("openai_ws_ingress_mode", wsIngressMode), + ) accountMaxConcurrency := account.Concurrency if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { accountMaxConcurrency = selection.WaitPlan.MaxConcurrency @@ -701,11 +752,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { accountReleaseFunc = fastReleaseFunc } currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) - if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + if err := h.wsBindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) } - token, _, err := h.gatewayService.GetAccessToken(ctx, account) + token, _, err := h.wsGetAccessToken(ctx, account) if err != nil { reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") @@ -717,13 +768,112 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { zap.String("account_name", account.Name), zap.String("schedule_layer", scheduleDecision.Layer), zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled), + zap.String("openai_ws_ingress_mode", wsIngressMode), ) + var turnScheduleReported atomic.Bool + afterTurn := func(releaseSlots bool) func(turn int, result *service.OpenAIForwardResult, turnErr error) { + return func(turn int, result *service.OpenAIForwardResult, turnErr error) { + if releaseSlots { + releaseTurnSlots() + } + if turnErr != nil { + 
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + turnScheduleReported.Store(true) + if partialResult, ok := service.OpenAIWSIngressTurnPartialResult(turnErr); ok && partialResult != nil { + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: partialResult, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + FallbackRequestID: requestID, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_partial_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", partialResult.RequestID), + zap.Error(err), + ) + } + }) + } + return + } + if result == nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + turnScheduleReported.Store(true) + return + } + var turnTTFTMs float64 + if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 { + turnTTFTMs = float64(*result.FirstTokenMs) + } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, turnTTFTMs) + turnScheduleReported.Store(true) + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + FallbackRequestID: requestID, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + }) + } + } + handleProxyError := func(proxyErr error) { + if !turnScheduleReported.Load() { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0) + } + 
closeStatus, closeReason := summarizeWSCloseErrorForLog(proxyErr) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(proxyErr), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(proxyErr, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + } + if wsIngressMode == service.OpenAIWSIngressModePassthrough { + passthroughHooks := &service.OpenAIWSIngressHooks{ + AfterTurn: afterTurn(false), + } + if err := h.wsProxyResponses(ctx, c, wsConn, account, token, msgType, firstMessage, passthroughHooks); err != nil { + handleProxyError(err) + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) + return + } + hooks := &service.OpenAIWSIngressHooks{ BeforeTurn: func(turn int) error { if turn == 1 { return nil } + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "billing check failed", err) + } // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 releaseTurnSlots() // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 @@ -751,48 +901,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) return nil }, - AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { - releaseTurnSlots() - if turnErr != nil || result == nil { - return - } - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) - h.submitUsageRecordTask(func(taskCtx context.Context) { - if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ - Result: result, - APIKey: apiKey, - User: apiKey.User, - Account: account, - 
Subscription: subscription, - UserAgent: userAgent, - IPAddress: clientIP, - APIKeyService: h.apiKeyService, - }); err != nil { - reqLog.Error("openai.websocket_record_usage_failed", - zap.Int64("account_id", account.ID), - zap.String("request_id", result.RequestID), - zap.Error(err), - ) - } - }) - }, + AfterTurn: afterTurn(true), } - if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { - h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) - closeStatus, closeReason := summarizeWSCloseErrorForLog(err) - reqLog.Warn("openai.websocket_proxy_failed", - zap.Int64("account_id", account.ID), - zap.Error(err), - zap.String("close_status", closeStatus), - zap.String("close_reason", closeReason), - ) - var closeErr *service.OpenAIWSClientCloseError - if errors.As(err, &closeErr) { - closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) - return - } - closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + if err := h.wsProxyResponses(ctx, c, wsConn, account, token, msgType, firstMessage, hooks); err != nil { + handleProxyError(err) return } reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) @@ -882,25 +995,12 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) { } func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - defer func() { - if recovered := recover(); recovered != nil { - logger.L().With( - zap.String("component", "handler.openai_gateway.responses"), - zap.Any("panic", recovered), - ).Error("openai.usage_record_task_panic_recovered") - } - }() - task(ctx) + 
submitUsageRecordTaskWithFallback( + "handler.openai_gateway.responses", + h.usageRecordWorkerPool, + h.cfg, + task, + ) } // handleConcurrencyError handles concurrency-related errors with proper 429 response @@ -1030,6 +1130,124 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) } +func (h *OpenAIGatewayHandler) wsSelectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport service.OpenAIUpstreamTransport, +) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) { + if h != nil && h.wsSelectAccountWithSchedulerFn != nil { + return h.wsSelectAccountWithSchedulerFn( + ctx, + groupID, + previousResponseID, + sessionHash, + requestedModel, + excludedIDs, + requiredTransport, + ) + } + return h.gatewayService.SelectAccountWithScheduler( + ctx, + groupID, + previousResponseID, + sessionHash, + requestedModel, + excludedIDs, + requiredTransport, + ) +} + +func (h *OpenAIGatewayHandler) wsBindStickySession( + ctx context.Context, + groupID *int64, + sessionHash string, + accountID int64, +) error { + if h != nil && h.wsBindStickySessionFn != nil { + return h.wsBindStickySessionFn(ctx, groupID, sessionHash, accountID) + } + return h.gatewayService.BindStickySession(ctx, groupID, sessionHash, accountID) +} + +func (h *OpenAIGatewayHandler) wsGetAccessToken( + ctx context.Context, + account *service.Account, +) (string, string, error) { + if h != nil && h.wsGetAccessTokenFn != nil { + return h.wsGetAccessTokenFn(ctx, account) + } + return h.gatewayService.GetAccessToken(ctx, account) +} + +func (h *OpenAIGatewayHandler) wsProxyResponses( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *service.Account, + token string, + firstClientMessageType coderws.MessageType, + 
firstClientMessage []byte, + hooks *service.OpenAIWSIngressHooks, +) error { + if h != nil && h.wsProxyResponsesWSFn != nil { + return h.wsProxyResponsesWSFn( + ctx, + c, + clientConn, + account, + token, + firstClientMessageType, + firstClientMessage, + hooks, + ) + } + return h.gatewayService.ProxyResponsesWebSocketFromClient( + ctx, + c, + clientConn, + account, + token, + firstClientMessageType, + firstClientMessage, + hooks, + ) +} + +func (h *OpenAIGatewayHandler) resolveOpenAIWSIngressMode(account *service.Account) (mode string, modeRouterV2Enabled bool) { + if account == nil { + return "account_missing", false + } + if h == nil || h.cfg == nil { + return "config_missing", false + } + modeRouterV2Enabled = h.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + if !modeRouterV2Enabled { + return "legacy", false + } + resolvedMode := account.ResolveOpenAIResponsesWebSocketV2Mode(h.cfg.Gateway.OpenAIWS.IngressModeDefault) + if resolvedMode == "" { + resolvedMode = service.OpenAIWSIngressModeOff + } + return resolvedMode, true +} + +func (h *OpenAIGatewayHandler) shouldRejectWSMessageIDPreviousResponseIDEarly(previousResponseIDKind string) bool { + if previousResponseIDKind != service.OpenAIPreviousResponseIDKindMessageID { + return false + } + if h == nil || h.cfg == nil { + return false + } + if !h.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { + return false + } + return strings.TrimSpace(h.cfg.Gateway.OpenAIWS.IngressModeDefault) == service.OpenAIWSIngressModeCtxPool +} + func isOpenAIWSUpgradeRequest(r *http.Request) bool { if r == nil { return false diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index a26b3a0c3..9ea62044e 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -7,9 +7,11 @@ import ( "net/http" "net/http/httptest" "strings" + "sync/atomic" "testing" "time" + 
"github.com/Wei-Shaw/sub2api/internal/config" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" @@ -431,6 +433,34 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") } +func TestOpenAIResponses_InvalidJSONBodyReturnsBadRequest(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,invalid}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 201, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "Failed to parse request body") +} + func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { gin.SetMode(gin.TestMode) @@ -461,10 +491,86 @@ func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c)) } -func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) { +func TestOpenAIResponsesWebSocket_DoesNotEarlyRejectMessageIDPreviousResponseID(t *testing.T) { gin.SetMode(gin.TestMode) - h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return false, nil + }, 
+ } + h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusTryAgainLater, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "too many concurrent requests") +} + +func TestOpenAIResponsesWebSocket_CtxPoolRejectsMessageIDBeforeScheduling(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{RunMode: config.RunModeSimple} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = service.OpenAIWSIngressModeCtxPool + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + 
concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + cfg: cfg, + } + + var scheduleCalls atomic.Int32 + var stickyBindCalls atomic.Int32 + h.wsSelectAccountWithSchedulerFn = func( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport service.OpenAIUpstreamTransport, + ) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) { + scheduleCalls.Add(1) + return nil, service.OpenAIAccountScheduleDecision{}, errors.New("should not be called") + } + h.wsBindStickySessionFn = func(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { + stickyBindCalls.Add(1) + return nil + } + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) defer wsServer.Close() @@ -487,10 +593,244 @@ func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testin _, _, err = clientConn.Read(readCtx) cancelRead() require.Error(t, err) + var closeErr coderws.CloseError require.ErrorAs(t, err, &closeErr) require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) - require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id") + require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id must be a response.id") + require.Equal(t, int32(0), scheduleCalls.Load()) + require.Equal(t, int32(0), stickyBindCalls.Load()) +} + +func TestOpenAIResponsesWebSocket_RejectsEmptyModelInFirstPayload(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, 
"http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","stream":false,"input":[]}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "model is required") +} + +func TestOpenAIResponsesWebSocket_RejectsEmptyModelWithoutCallingProxy(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + var proxyCalls atomic.Int32 + h.wsProxyResponsesWSFn = func( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *service.Account, + token string, + firstClientMessageType coderws.MessageType, + firstClientMessage []byte, + hooks *service.OpenAIWSIngressHooks, + ) error { + proxyCalls.Add(1) + return nil + } + + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","stream":false,"input":[]}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := 
context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "model is required") + require.Equal(t, int32(0), proxyCalls.Load()) +} + +func TestOpenAIResponsesWebSocket_PassthroughAndCtxPoolShareSchedulerInputsAndStickyBind(t *testing.T) { + gin.SetMode(gin.TestMode) + + type wsScheduleCapture struct { + groupID int64 + previousResponseID string + sessionHash string + model string + transport service.OpenAIUpstreamTransport + stickyBindCount int + stickyGroupID int64 + stickySessionHash string + stickyAccountID int64 + } + runCase := func(t *testing.T, mode string) wsScheduleCapture { + t.Helper() + + cfg := &config.Config{RunMode: config.RunModeSimple} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = mode + billingSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + t.Cleanup(func() { + billingSvc.Stop() + }) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: billingSvc, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + cfg: cfg, + } + + account := &service.Account{ + ID: 901, + Name: "ws-mode-test", + Platform: service.PlatformOpenAI, + Type: service.AccountTypeAPIKey, + Status: service.StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: 
map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": mode, + }, + } + + var capture wsScheduleCapture + h.wsSelectAccountWithSchedulerFn = func( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport service.OpenAIUpstreamTransport, + ) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) { + if groupID != nil { + capture.groupID = *groupID + } + capture.previousResponseID = previousResponseID + capture.sessionHash = sessionHash + capture.model = requestedModel + capture.transport = requiredTransport + return &service.AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: func() {}, + }, service.OpenAIAccountScheduleDecision{Layer: "unit"}, nil + } + h.wsBindStickySessionFn = func(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error { + capture.stickyBindCount++ + if groupID != nil { + capture.stickyGroupID = *groupID + } + capture.stickySessionHash = sessionHash + capture.stickyAccountID = accountID + return nil + } + h.wsGetAccessTokenFn = func(ctx context.Context, account *service.Account) (string, string, error) { + return "sk-test", "apikey", nil + } + h.wsProxyResponsesWSFn = func( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *service.Account, + token string, + firstClientMessageType coderws.MessageType, + firstClientMessage []byte, + hooks *service.OpenAIWSIngressHooks, + ) error { + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(1, nil, nil) + } + return nil + } + + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + t.Cleanup(wsServer.Close) + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, 
"http")+"/openai/v1/responses", &coderws.DialOptions{ + HTTPHeader: http.Header{ + "Session_ID": []string{"session-fixed-123"}, + }, + }) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_fixed"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, _ = clientConn.Read(readCtx) + cancelRead() + return capture + } + + passthroughCapture := runCase(t, service.OpenAIWSIngressModePassthrough) + ctxPoolCapture := runCase(t, service.OpenAIWSIngressModeCtxPool) + + require.Equal(t, passthroughCapture.groupID, ctxPoolCapture.groupID) + require.Equal(t, passthroughCapture.previousResponseID, ctxPoolCapture.previousResponseID) + require.Equal(t, passthroughCapture.sessionHash, ctxPoolCapture.sessionHash) + require.Equal(t, passthroughCapture.model, ctxPoolCapture.model) + require.Equal(t, service.OpenAIUpstreamTransportResponsesWebsocketV2, passthroughCapture.transport) + require.Equal(t, passthroughCapture.transport, ctxPoolCapture.transport) + + require.Equal(t, 1, passthroughCapture.stickyBindCount) + require.Equal(t, 1, ctxPoolCapture.stickyBindCount) + require.Equal(t, passthroughCapture.stickyGroupID, ctxPoolCapture.stickyGroupID) + require.Equal(t, passthroughCapture.stickySessionHash, ctxPoolCapture.stickySessionHash) + require.Equal(t, passthroughCapture.stickyAccountID, ctxPoolCapture.stickyAccountID) } func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) { diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 2f53d655e..6fbf79527 100644 --- a/backend/internal/handler/ops_error_logger.go +++ 
b/backend/internal/handler/ops_error_logger.go @@ -662,10 +662,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { requestID = c.Writer.Header().Get("x-request-id") } - normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code) - - phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code) - isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message) + phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code) + isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message) errorOwner := classifyOpsErrorOwner(phase, parsed.Message) errorSource := classifyOpsErrorSource(phase, parsed.Message) @@ -687,8 +685,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { UserAgent: c.GetHeader("User-Agent"), ErrorPhase: phase, - ErrorType: normalizedType, - Severity: classifyOpsSeverity(normalizedType, status), + ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code), + Severity: classifyOpsSeverity(parsed.ErrorType, status), StatusCode: status, IsBusinessLimited: isBusinessLimited, IsCountTokens: isCountTokensRequest(c), @@ -700,7 +698,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { ErrorSource: errorSource, ErrorOwner: errorOwner, - IsRetryable: classifyOpsIsRetryable(normalizedType, status), + IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status), RetryCount: 0, CreatedAt: time.Now(), } @@ -941,29 +939,8 @@ func guessPlatformFromPath(path string) string { } } -// isKnownOpsErrorType returns true if t is a recognized error type used by the -// ops classification pipeline. Upstream proxies sometimes return garbage values -// (e.g. the Go-serialized literal "") which would pollute phase/severity -// classification if accepted blindly. 
-func isKnownOpsErrorType(t string) bool { - switch t { - case "invalid_request_error", - "authentication_error", - "rate_limit_error", - "billing_error", - "subscription_error", - "upstream_error", - "overloaded_error", - "api_error", - "not_found_error", - "forbidden_error": - return true - } - return false -} - func normalizeOpsErrorType(errType string, code string) string { - if errType != "" && isKnownOpsErrorType(errType) { + if errType != "" { return errType } switch strings.TrimSpace(code) { diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index 679dd4cef..731b36ab9 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -214,63 +214,3 @@ func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) { }) require.Equal(t, http.StatusNoContent, rec.Code) } - -func TestIsKnownOpsErrorType(t *testing.T) { - known := []string{ - "invalid_request_error", - "authentication_error", - "rate_limit_error", - "billing_error", - "subscription_error", - "upstream_error", - "overloaded_error", - "api_error", - "not_found_error", - "forbidden_error", - } - for _, k := range known { - require.True(t, isKnownOpsErrorType(k), "expected known: %s", k) - } - - unknown := []string{"", "null", "", "random_error", "some_new_type", "\u003e"} - for _, u := range unknown { - require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u) - } -} - -func TestNormalizeOpsErrorType(t *testing.T) { - tests := []struct { - name string - errType string - code string - want string - }{ - // Known types pass through. - {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"}, - {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"}, - {"known upstream_error", "upstream_error", "", "upstream_error"}, - - // Unknown/garbage types are rejected and fall through to code-based or default. 
- {"nil literal from upstream", "", "", "api_error"}, - {"null string", "null", "", "api_error"}, - {"random string", "something_weird", "", "api_error"}, - - // Unknown type but known code still maps correctly. - {"nil with INSUFFICIENT_BALANCE code", "", "INSUFFICIENT_BALANCE", "billing_error"}, - {"nil with USAGE_LIMIT_EXCEEDED code", "", "USAGE_LIMIT_EXCEEDED", "subscription_error"}, - - // Empty type falls through to code-based mapping. - {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"}, - {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"}, - {"empty type no code", "", "", "api_error"}, - - // Known type overrides conflicting code-based mapping. - {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := normalizeOpsErrorType(tt.errType, tt.code) - require.Equal(t, tt.want, got) - }) - } -} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go index d933abd7d..890a523db 100644 --- a/backend/internal/handler/sora_client_handler_test.go +++ b/backend/internal/handler/sora_client_handler_test.go @@ -942,12 +942,12 @@ func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, e func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } -func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil } 
+func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil } // ==================== NewSoraClientHandler ==================== @@ -2059,13 +2059,19 @@ func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination. return nil, nil, nil } func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { - return nil, nil + return r.accounts, nil } func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { return nil, nil } -func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) { - return nil, nil +func (r *stubAccountRepoForHandler) ListByPlatform(_ context.Context, platform string) ([]service.Account, error) { + filtered := make([]service.Account, 0, len(r.accounts)) + for _, account := range r.accounts { + if account.Platform == platform { + filtered = append(filtered, account) + } + } + return filtered, nil } func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { @@ -2199,7 +2205,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { return service.NewGatewayService( accountRepo, nil, nil, nil, nil, nil, nil, nil, - nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, ) } @@ -2233,7 +2239,6 @@ func TestProcessGeneration_SelectAccountError(t *testing.T) { } func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} 
genService := service.NewSoraGenerationService(repo, nil, nil) @@ -2253,7 +2258,6 @@ func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { } func TestProcessGeneration_ForwardError(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) @@ -2312,7 +2316,6 @@ func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { } func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) @@ -2374,7 +2377,6 @@ func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { } func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} genService := service.NewSoraGenerationService(repo, nil, nil) @@ -2405,7 +2407,6 @@ func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { } func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") sourceServer := newFakeSourceServer() defer sourceServer.Close() fakeS3 := newFakeS3Server("ok") @@ -2453,7 +2454,6 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { } func TestProcessGeneration_MarkCompletedFails(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") repo := newStubSoraGenRepo() repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 @@ -2621,7 +2621,6 @@ func TestDeleteGeneration_DeleteError(t *testing.T) { // ==================== 
fetchUpstreamModels 测试 ==================== func TestFetchUpstreamModels_NilGateway(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") h := &SoraClientHandler{} _, err := h.fetchUpstreamModels(context.Background()) require.Error(t, err) @@ -2629,7 +2628,6 @@ func TestFetchUpstreamModels_NilGateway(t *testing.T) { } func TestFetchUpstreamModels_NoAccounts(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") accountRepo := &stubAccountRepoForHandler{accounts: nil} gatewayService := newMinimalGatewayService(accountRepo) h := &SoraClientHandler{gatewayService: gatewayService} @@ -2639,7 +2637,6 @@ func TestFetchUpstreamModels_NoAccounts(t *testing.T) { } func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, @@ -2653,7 +2650,6 @@ func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { } func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") accountRepo := &stubAccountRepoForHandler{ accounts: []service.Account{ {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, @@ -2668,7 +2664,6 @@ func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { } func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 accountRepo := &stubAccountRepoForHandler{ @@ -2684,7 +2679,6 @@ func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { } func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusInternalServerError) })) @@ -2704,7 +2698,6 @@ func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { } func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte("not json")) @@ -2725,7 +2718,6 @@ func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { } func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"data":[]}`)) @@ -2746,7 +2738,6 @@ func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { } func TestFetchUpstreamModels_Success(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // 验证请求头 require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) @@ -2770,7 +2761,6 @@ func TestFetchUpstreamModels_Success(t *testing.T) { } func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { - t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) @@ -2805,7 +2795,6 @@ func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { } func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { - t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) _, _ = 
w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 48c1e451b..a0045aa53 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -41,6 +41,7 @@ type SoraGatewayHandler struct { soraTLSEnabled bool soraMediaSigningKey string soraMediaRoot string + cfg *config.Config } // NewSoraGatewayHandler creates a new SoraGatewayHandler @@ -83,6 +84,7 @@ func NewSoraGatewayHandler( soraTLSEnabled: soraTLSEnabled, soraMediaSigningKey: signKey, soraMediaRoot: mediaRoot, + cfg: cfg, } } @@ -451,25 +453,12 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string { } func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { - if task == nil { - return - } - if h.usageRecordWorkerPool != nil { - h.usageRecordWorkerPool.Submit(task) - return - } - // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - defer func() { - if recovered := recover(); recovered != nil { - logger.L().With( - zap.String("component", "handler.sora_gateway.chat_completions"), - zap.Any("panic", recovered), - ).Error("sora.usage_record_task_panic_recovered") - } - }() - task(ctx) + submitUsageRecordTaskWithFallback( + "handler.sora_gateway.chat_completions", + h.usageRecordWorkerPool, + h.cfg, + task, + ) } func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index b76ab67da..66877e6ca 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -326,9 +326,6 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, 
startTi func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, nil } -func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { - return nil, nil -} func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { return nil, nil } @@ -435,8 +432,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) { deferredService, nil, testutil.StubSessionLimitCache{}, - nil, // rpmCache - nil, // digestStore + nil, ) soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}} diff --git a/backend/internal/handler/usage_record_submit_helper.go b/backend/internal/handler/usage_record_submit_helper.go new file mode 100644 index 000000000..5c1987b1e --- /dev/null +++ b/backend/internal/handler/usage_record_submit_helper.go @@ -0,0 +1,60 @@ +package handler + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/service" + "go.uber.org/zap" +) + +func submitUsageRecordTaskWithFallback( + component string, + pool *service.UsageRecordWorkerPool, + cfg *config.Config, + task service.UsageRecordTask, +) { + if task == nil { + return + } + if pool != nil { + mode := pool.Submit(task) + if mode != service.UsageRecordSubmitModeDropped { + return + } + // 队列溢出导致 submit 丢弃时,同步兜底执行,避免 usage 漏记费。 + logger.L().With( + zap.String("component", component), + zap.String("submit_mode", mode.String()), + ).Warn("usage_record.task_submit_dropped_sync_fallback") + } + + ctx, cancel := 
context.WithTimeout(context.Background(), usageRecordSyncFallbackTimeout(cfg)) + defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", component), + zap.Any("panic", recovered), + ).Error("usage_record.task_panic_recovered") + } + }() + task(ctx) +} + +func usageRecordSyncFallbackTimeout(cfg *config.Config) time.Duration { + timeout := 10 * time.Second + if cfg != nil && cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 { + timeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second + } + // keep a sane bound on synchronous fallback to limit request-path blocking. + if timeout < time.Second { + return time.Second + } + if timeout > 10*time.Second { + return 10 * time.Second + } + return timeout +} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index c7c48e14b..20e8e87c3 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" ) @@ -54,6 +55,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing. 
require.True(t, called.Load()) } +func TestGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) { + pool := newUsageRecordTestPool(t) + pool.Stop() + h := &GatewayHandler{usageRecordWorkerPool: pool} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in dropped sync fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load(), "dropped task should run via sync fallback") +} + func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h := &GatewayHandler{} require.NotPanics(t, func() { @@ -93,6 +110,40 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { } } +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) { + pool := newUsageRecordTestPool(t) + pool.Stop() + h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in dropped sync fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load(), "dropped task should run via sync fallback") +} + +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithConfigFallbackTimeout(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 2 + h := &OpenAIGatewayHandler{cfg: cfg} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + deadline, ok := ctx.Deadline() + require.True(t, ok, "expected deadline in fallback context") + remaining := time.Until(deadline) + require.LessOrEqual(t, remaining, 2200*time.Millisecond) + require.GreaterOrEqual(t, remaining, 1200*time.Millisecond) + called.Store(true) + }) + + require.True(t, called.Load()) +} + func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) { h := &OpenAIGatewayHandler{} var called atomic.Bool @@ 
-160,6 +211,22 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *test require.True(t, called.Load()) } +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) { + pool := newUsageRecordTestPool(t) + pool.Stop() + h := &SoraGatewayHandler{usageRecordWorkerPool: pool} + var called atomic.Bool + + h.submitUsageRecordTask(func(ctx context.Context) { + if _, ok := ctx.Deadline(); !ok { + t.Fatal("expected deadline in dropped sync fallback context") + } + called.Store(true) + }) + + require.True(t, called.Load(), "dropped task should run via sync fallback") +} + func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h := &SoraGatewayHandler{} require.NotPanics(t, func() { diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index d46bbc454..1998221a3 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -14,9 +14,6 @@ import ( "net/url" "strings" "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" ) // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) @@ -152,26 +149,22 @@ type Client struct { httpClient *http.Client } -func NewClient(proxyURL string) (*Client, error) { +func NewClient(proxyURL string) *Client { client := &http.Client{ Timeout: 30 * time.Second, } - _, parsed, err := proxyurl.Parse(proxyURL) - if err != nil { - return nil, err - } - if parsed != nil { - transport := &http.Transport{} - if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { - return nil, fmt.Errorf("configure proxy: %w", err) + if strings.TrimSpace(proxyURL) != "" { + if proxyURLParsed, err := url.Parse(proxyURL); err == nil { + client.Transport = &http.Transport{ + Proxy: http.ProxyURL(proxyURLParsed), + } } - client.Transport = transport } return &Client{ httpClient: client, - }, nil + } } // 
isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝) diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go index 20b57833a..394b6128c 100644 --- a/backend/internal/pkg/antigravity/client_test.go +++ b/backend/internal/pkg/antigravity/client_test.go @@ -228,20 +228,8 @@ func TestGetTier_两者都为nil(t *testing.T) { // NewClient // --------------------------------------------------------------------------- -func mustNewClient(t *testing.T, proxyURL string) *Client { - t.Helper() - client, err := NewClient(proxyURL) - if err != nil { - t.Fatalf("NewClient(%q) failed: %v", proxyURL, err) - } - return client -} - func TestNewClient_无代理(t *testing.T) { - client, err := NewClient("") - if err != nil { - t.Fatalf("NewClient 返回错误: %v", err) - } + client := NewClient("") if client == nil { t.Fatal("NewClient 返回 nil") } @@ -258,10 +246,7 @@ func TestNewClient_无代理(t *testing.T) { } func TestNewClient_有代理(t *testing.T) { - client, err := NewClient("http://proxy.example.com:8080") - if err != nil { - t.Fatalf("NewClient 返回错误: %v", err) - } + client := NewClient("http://proxy.example.com:8080") if client == nil { t.Fatal("NewClient 返回 nil") } @@ -271,10 +256,7 @@ func TestNewClient_有代理(t *testing.T) { } func TestNewClient_空格代理(t *testing.T) { - client, err := NewClient(" ") - if err != nil { - t.Fatalf("NewClient 返回错误: %v", err) - } + client := NewClient(" ") if client == nil { t.Fatal("NewClient 返回 nil") } @@ -285,13 +267,15 @@ func TestNewClient_空格代理(t *testing.T) { } func TestNewClient_无效代理URL(t *testing.T) { - // 无效 URL 应返回 error - _, err := NewClient("://invalid") - if err == nil { - t.Fatal("无效代理 URL 应返回错误") + // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容), + // 但 ://invalid 会导致解析错误 + client := NewClient("://invalid") + if client == nil { + t.Fatal("NewClient 返回 nil") } - if !strings.Contains(err.Error(), "invalid proxy URL") { - t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error()) + // 无效 URL 解析失败时,Transport 应保持 nil + 
if client.httpClient.Transport != nil { + t.Error("无效代理 URL 时 Transport 应为 nil") } } @@ -515,7 +499,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := mustNewClient(t, "") + client := NewClient("") _, err := client.ExchangeCode(context.Background(), "code", "verifier") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -618,7 +602,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) { defaultClientSecret = "" t.Cleanup(func() { defaultClientSecret = old }) - client := mustNewClient(t, "") + client := NewClient("") _, err := client.RefreshToken(context.Background(), "refresh-tok") if err == nil { t.Fatal("缺少 client_secret 时应返回错误") @@ -1258,7 +1242,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token") if err != nil { t.Fatalf("LoadCodeAssist 失败: %v", err) @@ -1293,7 +1277,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") _, _, err := client.LoadCodeAssist(context.Background(), "bad-token") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1316,7 +1300,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1349,7 +1333,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 fallback 
后成功: %v", err) @@ -1377,7 +1361,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := mustNewClient(t, "") + client := NewClient("") _, _, err := client.LoadCodeAssist(context.Background(), "token") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1393,7 +1377,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1457,7 +1441,7 @@ func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1512,7 +1496,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") _, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj") if err == nil { t.Fatal("服务器返回 403 时应返回错误") @@ -1532,7 +1516,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("无效 JSON 响应应返回错误") @@ -1562,7 +1546,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err) @@ -1590,7 +1574,7 @@ func 
TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := mustNewClient(t, "") + client := NewClient("") _, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err == nil { t.Fatal("所有 URL 都失败时应返回错误") @@ -1606,7 +1590,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -1626,7 +1610,7 @@ func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 失败: %v", err) @@ -1662,7 +1646,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, _, err := client.LoadCodeAssist(context.Background(), "token") if err != nil { t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err) @@ -1688,7 +1672,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) { withMockBaseURLs(t, []string{server1.URL, server2.URL}) - client := mustNewClient(t, "") + client := NewClient("") resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj") if err != nil { t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err) diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index 25782c551..b13d66cb4 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -52,7 +52,4 @@ const ( // PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。 // Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 
sticky。 PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id" - - // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22") - ClaudeCodeVersion Key = "ctx_claude_code_version" ) diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 32e4bc5b2..6ef3d7141 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -18,11 +18,11 @@ package httpclient import ( "fmt" "net/http" + "net/url" "strings" "sync" "time" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -41,6 +41,7 @@ type Options struct { Timeout time.Duration // 请求总超时时间 ResponseHeaderTimeout time.Duration // 等待响应头超时时间 InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true) + ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退 ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding) AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用) @@ -119,13 +120,15 @@ func buildTransport(opts Options) (*http.Transport, error) { return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead") } - _, parsed, err := proxyurl.Parse(opts.ProxyURL) + proxyURL := strings.TrimSpace(opts.ProxyURL) + if proxyURL == "" { + return transport, nil + } + + parsed, err := url.Parse(proxyURL) if err != nil { return nil, err } - if parsed == nil { - return transport, nil - } if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil { return nil, err @@ -135,11 +138,12 @@ func buildTransport(opts Options) (*http.Transport, error) { } func buildClientKey(opts Options) string { - return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d", + return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d", strings.TrimSpace(opts.ProxyURL), opts.Timeout.String(), opts.ResponseHeaderTimeout.String(), opts.InsecureSkipVerify, + opts.ProxyStrict, opts.ValidateResolvedIP, 
opts.AllowPrivateHosts, opts.MaxIdleConns, diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go index e437cae34..91b224a28 100644 --- a/backend/internal/pkg/proxyutil/dialer.go +++ b/backend/internal/pkg/proxyutil/dialer.go @@ -2,11 +2,7 @@ // // 支持的代理协议: // - HTTP/HTTPS: 通过 Transport.Proxy 设置 -// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS) -// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐) -// -// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://, -// 确保 DNS 也由代理端解析,防止 DNS 泄漏。 +// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS) package proxyutil import ( @@ -24,8 +20,7 @@ import ( // // 支持的协议: // - http/https: 设置 transport.Proxy -// - socks5: 设置 transport.DialContext(客户端本地解析 DNS) -// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐) +// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS) // // 参数: // - transport: 需要配置的 http.Transport diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index 0519c2cc1..f09bee8dd 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -2,6 +2,8 @@ package response import ( + "context" + "errors" "log" "math" "net/http" @@ -75,17 +77,45 @@ func ErrorFrom(c *gin.Context, err error) bool { return false } - statusCode, status := infraerrors.ToHTTP(err) + normalizedErr := normalizeHTTPError(c, err) + statusCode, status := infraerrors.ToHTTP(normalizedErr) // Log internal errors with full details for debugging if statusCode >= 500 && c.Request != nil { - log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error())) + log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(normalizedErr.Error())) } ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata) return true } +func normalizeHTTPError(c *gin.Context, err error) error { + if err == nil { + return 
nil + } + if c == nil || c.Request == nil { + return err + } + if isClientCanceledError(c.Request.Context(), err) { + return infraerrors.ClientClosed("CLIENT_CLOSED", "client closed request").WithCause(err) + } + return err +} + +func isClientCanceledError(reqCtx context.Context, err error) bool { + if reqCtx == nil { + return false + } + // 只有请求上下文本身被取消时,才认为是客户端断开; + // 避免将服务端主动 cancel 导致的 context.Canceled 误归为 499。 + if errors.Is(err, context.Canceled) && errors.Is(reqCtx.Err(), context.Canceled) { + return true + } + + // Some drivers can surface deadline errors after the request context was already canceled. + return errors.Is(err, context.DeadlineExceeded) && errors.Is(reqCtx.Err(), context.Canceled) +} + // BadRequest 返回400错误 func BadRequest(c *gin.Context, message string) { Error(c, http.StatusBadRequest, message) diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index 0debce5fd..64918ca8b 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -3,8 +3,10 @@ package response import ( + "context" "encoding/json" "errors" + "fmt" "net/http" "net/http/httptest" "testing" @@ -107,11 +109,12 @@ func TestErrorFrom(t *testing.T) { gin.SetMode(gin.TestMode) tests := []struct { - name string - err error - wantWritten bool - wantHTTPCode int - wantBody Response + name string + err error + cancelRequestContext bool + wantWritten bool + wantHTTPCode int + wantBody Response }{ { name: "nil_error", @@ -184,12 +187,75 @@ func TestErrorFrom(t *testing.T) { Message: errors2.UnknownMessage, }, }, + { + name: "context_canceled_without_request_cancel_remains_500", + err: context.Canceled, + wantWritten: true, + wantHTTPCode: http.StatusInternalServerError, + wantBody: Response{ + Code: http.StatusInternalServerError, + Message: errors2.UnknownMessage, + }, + }, + { + name: "context_canceled_maps_to_499", + err: context.Canceled, + cancelRequestContext: true, + 
wantWritten: true, + wantHTTPCode: 499, + wantBody: Response{ + Code: 499, + Message: "client closed request", + Reason: "CLIENT_CLOSED", + }, + }, + { + name: "wrapped_context_canceled_maps_to_499", + err: fmt.Errorf("query aborted: %w", context.Canceled), + cancelRequestContext: true, + wantWritten: true, + wantHTTPCode: 499, + wantBody: Response{ + Code: 499, + Message: "client closed request", + Reason: "CLIENT_CLOSED", + }, + }, + { + name: "deadline_exceeded_without_request_cancel_remains_500", + err: context.DeadlineExceeded, + wantWritten: true, + wantHTTPCode: http.StatusInternalServerError, + wantBody: Response{ + Code: http.StatusInternalServerError, + Message: errors2.UnknownMessage, + }, + }, + { + name: "deadline_exceeded_with_request_canceled_maps_to_499", + err: context.DeadlineExceeded, + cancelRequestContext: true, + wantWritten: true, + wantHTTPCode: 499, + wantBody: Response{ + Code: 499, + Message: "client closed request", + Reason: "CLIENT_CLOSED", + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { w := httptest.NewRecorder() c, _ := gin.CreateTestContext(w) + req := httptest.NewRequest(http.MethodGet, "/test", nil) + if tt.cancelRequestContext { + ctx, cancel := context.WithCancel(req.Context()) + cancel() + req = req.WithContext(ctx) + } + c.Request = req written := ErrorFrom(c, tt.err) require.Equal(t, tt.wantWritten, written) diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 314a6d3c9..5f4e13f54 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -78,16 +78,6 @@ type ModelStat struct { ActualCost float64 `json:"actual_cost"` // 实际扣除 } -// GroupStat represents usage statistics for a single group -type GroupStat struct { - GroupID int64 `json:"group_id"` - GroupName string `json:"group_name"` - Requests int64 `json:"requests"` - TotalTokens int64 `json:"total_tokens"` - Cost 
float64 `json:"cost"` // 标准计费 - ActualCost float64 `json:"actual_cost"` // 实际扣除 -} - // UserUsageTrendPoint represents user usage trend data point type UserUsageTrendPoint struct { Date string `json:"date"` diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 0669cbbdd..3e922e7c3 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -1215,6 +1215,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates args = append(args, *updates.Schedulable) idx++ } + if updates.AutoPauseOnExpired != nil { + setClauses = append(setClauses, "auto_pause_on_expired = $"+itoa(idx)) + args = append(args, *updates.AutoPauseOnExpired) + idx++ + } // JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。 if len(updates.Credentials) > 0 { payload, err := json.Marshal(updates.Credentials) diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go index 4fbdae14f..447415ace 100644 --- a/backend/internal/repository/billing_cache.go +++ b/backend/internal/repository/billing_cache.go @@ -69,9 +69,20 @@ var ( deductBalanceScript = redis.NewScript(` local current = redis.call('GET', KEYS[1]) if current == false then + return 2 + end + local cur = tonumber(current) + local delta = tonumber(ARGV[1]) + if cur == nil or delta == nil then + return -1 + end + if delta < 0 then + return -2 + end + if cur < delta then return 0 end - local newVal = tonumber(current) - tonumber(ARGV[1]) + local newVal = cur - delta redis.call('SET', KEYS[1], newVal) redis.call('EXPIRE', KEYS[1], ARGV[2]) return 1 @@ -130,12 +141,26 @@ func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { key := billingBalanceKey(userID) - _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result() + 
result, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Int64() if err != nil && !errors.Is(err, redis.Nil) { log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err) return err } - return nil + switch result { + case 1: + return nil + case 2: + // 缓存 key 不存在(已过期),返回特定错误让调用方区分处理 + return service.ErrBalanceCacheNotFound + case 0: + return service.ErrInsufficientBalance + case -1: + return fmt.Errorf("invalid cached balance for user %d", userID) + case -2: + return fmt.Errorf("invalid deduct amount for user %d", userID) + default: + return fmt.Errorf("unexpected deduct balance cache result for user %d: %d", userID, result) + } } func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error { diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go index 4b7377b12..dffe0b6e4 100644 --- a/backend/internal/repository/billing_cache_integration_test.go +++ b/backend/internal/repository/billing_cache_integration_test.go @@ -31,14 +31,15 @@ func (s *BillingCacheSuite) TestUserBalance() { }, }, { - name: "deduct_on_nonexistent_is_noop", + name: "deduct_on_nonexistent_returns_not_found", fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) { userID := int64(1) balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID) - require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error") + err := cache.DeductUserBalance(ctx, userID, 1) + require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound") - _, err := rdb.Get(ctx, balanceKey).Result() + _, err = rdb.Get(ctx, balanceKey).Result() require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent") }, }, @@ -278,8 +279,8 @@ func (s *BillingCacheSuite) TestSubscriptionCache() { } } 
-// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复: -// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。 +// TestDeductUserBalance_ErrorPropagation 验证修复: +// Redis 真实错误应传播,key 不存在应返回 ErrBalanceCacheNotFound。 func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { tests := []struct { name string @@ -287,11 +288,11 @@ func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() { expectErr bool }{ { - name: "key_not_exists_returns_nil", + name: "key_not_exists_returns_ErrBalanceCacheNotFound", fn: func(ctx context.Context, cache service.BillingCache) { - // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误 + // key 不存在时,Lua 脚本返回 2,应返回 ErrBalanceCacheNotFound err := cache.DeductUserBalance(ctx, 99999, 1.0) - require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil") + require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound") }, }, { diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index b754bd55e..77764881e 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -11,7 +11,6 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/oauth" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/logredact" @@ -29,14 +28,11 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient { type claudeOAuthService struct { baseURL string tokenURL string - clientFactory func(proxyURL string) (*req.Client, error) + clientFactory func(proxyURL string) *req.Client } func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { - client, err := s.clientFactory(proxyURL) - if err != nil { - return "", fmt.Errorf("create HTTP client: %w", err) 
- } + client := s.clientFactory(proxyURL) var orgs []struct { UUID string `json:"uuid"` @@ -92,10 +88,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey } func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) { - client, err := s.clientFactory(proxyURL) - if err != nil { - return "", fmt.Errorf("create HTTP client: %w", err) - } + client := s.clientFactory(proxyURL) authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID) @@ -172,10 +165,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe } func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) { - client, err := s.clientFactory(proxyURL) - if err != nil { - return nil, fmt.Errorf("create HTTP client: %w", err) - } + client := s.clientFactory(proxyURL) // Parse code which may contain state in format "authCode#state" authCode := code @@ -233,10 +223,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod } func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) { - client, err := s.clientFactory(proxyURL) - if err != nil { - return nil, fmt.Errorf("create HTTP client: %w", err) - } + client := s.clientFactory(proxyURL) reqBody := map[string]any{ "grant_type": "refresh_token", @@ -266,20 +253,16 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro return &tokenResp, nil } -func createReqClient(proxyURL string) (*req.Client, error) { +func createReqClient(proxyURL string) *req.Client { // 禁用 CookieJar,确保每次授权都是干净的会话 client := req.C(). SetTimeout(60 * time.Second). ImpersonateChrome(). 
SetCookieJar(nil) // 禁用 CookieJar - trimmed, _, err := proxyurl.Parse(proxyURL) - if err != nil { - return nil, err - } - if trimmed != "" { - client.SetProxyURL(trimmed) + if strings.TrimSpace(proxyURL) != "" { + client.SetProxyURL(strings.TrimSpace(proxyURL)) } - return client, nil + return client } diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index c63830338..7395c6d82 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "") @@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.baseURL = "http://in-process" - s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") @@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", 
tt.isSetupToken) @@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() { require.True(s.T(), ok, "type assertion failed") s.client = client s.client.tokenURL = "http://in-process/token" - s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil } + s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } resp, err := s.client.RefreshToken(context.Background(), "rt", "") diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index f6054828c..1198f4725 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se AllowPrivateHosts: s.allowPrivateHosts, }) if err != nil { - return nil, fmt.Errorf("create http client failed: %w", err) + client = &http.Client{Timeout: 30 * time.Second} } resp, err = client.Do(req) diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go index cbd0b6d3e..2e10f3e5b 100644 --- a/backend/internal/repository/claude_usage_service_test.go +++ b/backend/internal/repository/claude_usage_service_test.go @@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() { allowPrivateHosts: true, } - resp, err := s.fetcher.FetchUsage(context.Background(), "at", "") + resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") require.NoError(s.T(), err, "FetchUsage") require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch") require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch") @@ -112,17 +112,6 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() { require.Error(s.T(), err, "expected error for cancelled context") } -func (s *ClaudeUsageServiceSuite) 
TestFetchUsage_InvalidProxyReturnsError() { - s.fetcher = &claudeUsageService{ - usageURL: "http://example.com", - allowPrivateHosts: true, - } - - _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url") - require.Error(s.T(), err) - require.ErrorContains(s.T(), err, "create http client failed") -} - func TestClaudeUsageServiceSuite(t *testing.T) { suite.Run(t, new(ClaudeUsageServiceSuite)) } diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 58291b665..0f193c7dc 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -2,14 +2,20 @@ package repository import ( "context" + "encoding/json" "fmt" + "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/cespare/xxhash/v2" "github.com/redis/go-redis/v9" ) const stickySessionPrefix = "sticky_session:" +const openAIWSSessionLastResponsePrefix = "openai_ws_session_last_response:" +const openAIWSResponsePendingToolCallsPrefix = "openai_ws_response_pending_tool_calls:" type gatewayCache struct { rdb *redis.Client @@ -25,6 +31,20 @@ func buildSessionKey(groupID int64, sessionHash string) string { return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash) } +func buildOpenAIWSSessionLastResponseKey(groupID int64, sessionHash string) string { + return fmt.Sprintf("%s%d:%s", openAIWSSessionLastResponsePrefix, groupID, sessionHash) +} + +func buildOpenAIWSResponsePendingToolCallsKey(groupID int64, responseID string) string { + id := strings.TrimSpace(responseID) + if id == "" { + return "" + } + return openAIWSResponsePendingToolCallsPrefix + + strconv.FormatInt(groupID, 10) + ":" + + strconv.FormatUint(xxhash.Sum64String(id), 16) +} + func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) { key := buildSessionKey(groupID, sessionHash) return c.rdb.Get(ctx, key).Int64() @@ -51,3 +71,78 
@@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64 key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() } + +func (c *gatewayCache) SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error { + key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash) + return c.rdb.Set(ctx, key, responseID, ttl).Err() +} + +func (c *gatewayCache) GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error) { + key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash) + return c.rdb.Get(ctx, key).Result() +} + +func (c *gatewayCache) DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error { + key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash) + return c.rdb.Del(ctx, key).Err() +} + +func (c *gatewayCache) SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error { + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + if key == "" { + return nil + } + normalizedCallIDs := normalizeOpenAIWSResponsePendingToolCallIDs(callIDs) + if len(normalizedCallIDs) == 0 { + return c.rdb.Del(ctx, key).Err() + } + raw, err := json.Marshal(normalizedCallIDs) + if err != nil { + return err + } + return c.rdb.Set(ctx, key, raw, ttl).Err() +} + +func (c *gatewayCache) GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) { + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + if key == "" { + return nil, nil + } + raw, err := c.rdb.Get(ctx, key).Bytes() + if err != nil { + return nil, err + } + var callIDs []string + if err := json.Unmarshal(raw, &callIDs); err != nil { + return nil, err + } + return normalizeOpenAIWSResponsePendingToolCallIDs(callIDs), nil +} + +func (c *gatewayCache) 
DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error { + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + if key == "" { + return nil + } + return c.rdb.Del(ctx, key).Err() +} + +func normalizeOpenAIWSResponsePendingToolCallIDs(callIDs []string) []string { + if len(callIDs) == 0 { + return nil + } + seen := make(map[string]struct{}, len(callIDs)) + normalized := make([]string, 0, len(callIDs)) + for _, callID := range callIDs { + id := strings.TrimSpace(callID) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + normalized = append(normalized, id) + } + return normalized +} diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index 0eebc33f6..093b45ca4 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -3,6 +3,7 @@ package repository import ( + "context" "errors" "testing" "time" @@ -104,6 +105,49 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } +func (s *GatewayCacheSuite) TestSetAndGetOpenAIWSResponsePendingToolCalls() { + type responsePendingToolCallsCache interface { + SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error + GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) + } + cache, ok := s.cache.(responsePendingToolCallsCache) + require.True(s.T(), ok, "gateway cache should implement pending tool calls cache") + + responseID := "resp_pending_integration_1" + groupID := int64(1) + ttl := 2 * time.Minute + require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_1", 
"call_2", "call_1", " "}, ttl)) + + callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID) + require.NoError(s.T(), err) + require.ElementsMatch(s.T(), []string{"call_1", "call_2"}, callIDs) + _, err = cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID+1, responseID) + require.True(s.T(), errors.Is(err, redis.Nil), "pending tool calls should be isolated by group") + + key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID) + remainingTTL, ttlErr := s.rdb.TTL(s.ctx, key).Result() + require.NoError(s.T(), ttlErr) + s.AssertTTLWithin(remainingTTL, 1*time.Second, ttl) +} + +func (s *GatewayCacheSuite) TestDeleteOpenAIWSResponsePendingToolCalls() { + type responsePendingToolCallsCache interface { + SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error + GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) + DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error + } + cache, ok := s.cache.(responsePendingToolCallsCache) + require.True(s.T(), ok, "gateway cache should implement pending tool calls cache") + + responseID := "resp_pending_integration_2" + groupID := int64(1) + require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_3"}, time.Minute)) + require.NoError(s.T(), cache.DeleteOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID)) + + _, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID) + require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete") +} + func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) } diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go index eb14f3134..8b7fe625c 100644 --- 
a/backend/internal/repository/gemini_oauth_client.go +++ b/backend/internal/repository/gemini_oauth_client.go @@ -26,10 +26,7 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient { } func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) { - client, err := createGeminiReqClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create HTTP client: %w", err) - } + client := createGeminiReqClient(proxyURL) // Use different OAuth clients based on oauthType: // - code_assist: always use built-in Gemini CLI OAuth client (public) @@ -75,10 +72,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c } func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) { - client, err := createGeminiReqClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create HTTP client: %w", err) - } + client := createGeminiReqClient(proxyURL) oauthCfgInput := geminicli.OAuthConfig{ ClientID: c.cfg.Gemini.OAuth.ClientID, @@ -117,7 +111,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh return &tokenResp, nil } -func createGeminiReqClient(proxyURL string) (*req.Client, error) { +func createGeminiReqClient(proxyURL string) *req.Client { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 60 * time.Second, diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index b5bc64972..4f63280d5 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -26,11 +26,7 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo } var out geminicli.LoadCodeAssistResponse - client, err := createGeminiCliReqClient(proxyURL) - if err != 
nil { - return nil, fmt.Errorf("create HTTP client: %w", err) - } - resp, err := client.R(). + resp, err := createGeminiCliReqClient(proxyURL).R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -70,11 +66,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody) var out geminicli.OnboardUserResponse - client, err := createGeminiCliReqClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create HTTP client: %w", err) - } - resp, err := client.R(). + resp, err := createGeminiCliReqClient(proxyURL).R(). SetContext(ctx). SetHeader("Authorization", "Bearer "+accessToken). SetHeader("Content-Type", "application/json"). @@ -106,7 +98,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return &out, nil } -func createGeminiCliReqClient(proxyURL string) (*req.Client, error) { +func createGeminiCliReqClient(proxyURL string) *req.Client { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 30 * time.Second, diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go index ad1f22e39..28efe914a 100644 --- a/backend/internal/repository/github_release_service.go +++ b/backend/internal/repository/github_release_service.go @@ -5,10 +5,8 @@ import ( "encoding/json" "fmt" "io" - "log/slog" "net/http" "os" - "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" @@ -26,19 +24,13 @@ type githubReleaseClientError struct { // NewGitHubReleaseClient 创建 GitHub Release 客户端 // proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议 -// 代理配置失败时行为由 allowDirectOnProxyError 控制: -// - false(默认):返回错误占位客户端,禁止回退到直连 -// - true:回退到直连(仅限管理员显式开启) func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient { - // 安全说明:httpclient.GetClient 的错误链(url.Parse / 
proxyutil)不含明文代理凭据, - // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { - if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { - slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err) + if proxyURL != "" && !allowDirectOnProxyError { return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } sharedClient = &http.Client{Timeout: 30 * time.Second} @@ -50,8 +42,7 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi ProxyURL: proxyURL, }) if err != nil { - if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { - slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err) + if proxyURL != "" && !allowDirectOnProxyError { return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} } downloadClient = &http.Client{Timeout: 10 * time.Minute} diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 4edc85340..e9b4902ac 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -127,38 +127,6 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetMcpXMLInject(groupIn.MCPXMLInject). 
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) - // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 - if groupIn.DailyLimitUSD != nil { - builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD) - } else { - builder = builder.ClearDailyLimitUsd() - } - if groupIn.WeeklyLimitUSD != nil { - builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD) - } else { - builder = builder.ClearWeeklyLimitUsd() - } - if groupIn.MonthlyLimitUSD != nil { - builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD) - } else { - builder = builder.ClearMonthlyLimitUsd() - } - if groupIn.ImagePrice1K != nil { - builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K) - } else { - builder = builder.ClearImagePrice1k() - } - if groupIn.ImagePrice2K != nil { - builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K) - } else { - builder = builder.ClearImagePrice2k() - } - if groupIn.ImagePrice4K != nil { - builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K) - } else { - builder = builder.ClearImagePrice4k() - } - // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID) diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index a4674c1a3..a9df13229 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -1,6 +1,7 @@ package repository import ( + "crypto/tls" "errors" "fmt" "io" @@ -14,7 +15,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" @@ -45,6 +45,16 @@ const ( defaultMaxUpstreamClients = 5000 // defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟) defaultClientIdleTTLSeconds = 900 + // OpenAI HTTP/2 代理回退策略默认值 + defaultOpenAIHTTP2FallbackErrorThreshold = 2 + 
defaultOpenAIHTTP2FallbackWindow = 60 * time.Second + defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute +) + +const ( + upstreamProtocolModeDefault = "default" + upstreamProtocolModeOpenAIH2 = "openai_h2" + upstreamProtocolModeOpenAIH1Fallback = "openai_h1_fallback" ) var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached") @@ -59,14 +69,30 @@ type poolSettings struct { responseHeaderTimeout time.Duration // 等待响应头超时时间 } +type openAIHTTP2Settings struct { + enabled bool + allowProxyFallbackToHTTP1 bool + fallbackErrorThreshold int + fallbackWindow time.Duration + fallbackTTL time.Duration +} + // upstreamClientEntry 上游客户端缓存条目 // 记录客户端实例及其元数据,用于连接池管理和淘汰策略 type upstreamClientEntry struct { - client *http.Client // HTTP 客户端实例 - proxyKey string // 代理标识(用于检测代理变更) - poolKey string // 连接池配置标识(用于检测配置变更) - lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 - inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 + client *http.Client // HTTP 客户端实例 + proxyKey string // 代理标识(用于检测代理变更) + poolKey string // 连接池配置标识(用于检测配置变更) + protocolMode string // 协议模式(default/openai_h2/openai_h1_fallback) + lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰 + inFlight int64 // 当前进行中的请求数,>0 时不可淘汰 +} + +type openAIHTTP2FallbackState struct { + mu sync.Mutex + windowStart time.Time + errorCount int + fallbackUntil time.Time } // httpUpstreamService 通用 HTTP 上游服务 @@ -90,6 +116,8 @@ type httpUpstreamService struct { cfg *config.Config // 全局配置 mu sync.RWMutex // 保护 clients map 的读写锁 clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定 + // OpenAI 走 HTTP 代理时的 H2->H1 回退状态(key=标准化 proxyKey) + openAIHTTP2Fallbacks sync.Map } // NewHTTPUpstream 创建通用 HTTP 上游服务 @@ -127,9 +155,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i if err := s.validateRequestHost(req); err != nil { return nil, err } + profile := service.HTTPUpstreamProfileDefault + if req != nil { + profile = service.HTTPUpstreamProfileFromContext(req.Context()) + } // 获取或创建对应的客户端,并标记请求占用 - entry, err := 
s.acquireClient(proxyURL, accountID, accountConcurrency) + entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile) if err != nil { return nil, err } @@ -137,11 +169,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i // 执行请求 resp, err := entry.client.Do(req) if err != nil { + s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err) // 请求失败,立即减少计数 atomic.AddInt64(&entry.inFlight, -1) atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) return nil, err } + s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey) // 包装响应体,在关闭时自动减少计数并更新时间戳 // 这确保了流式响应(如 SSE)在完全读取前不会被淘汰 @@ -236,13 +270,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in // TLS 指纹客户端使用独立的缓存键,与普通客户端隔离 func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { isolation := s.getIsolationMode() - proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) - if err != nil { - return nil, err - } + proxyKey, parsedProxy := normalizeProxyURL(proxyURL) // TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀 - cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) - poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" + cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault) + poolKey := s.buildPoolKey(isolation, accountConcurrency, upstreamProtocolModeDefault) + ":tls" now := time.Now() nowUnix := now.UnixNano() @@ -359,7 +390,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req // acquireClient 获取或创建客户端,并标记为进行中请求 // 用于请求路径,避免在获取后被淘汰 func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true) + 
return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault) +} + +// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。 +func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) { + return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true) } // getOrCreateClient 获取或创建客户端 @@ -377,25 +413,25 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac // - proxy: 按代理地址隔离,同一代理共享客户端 // - account: 按账户隔离,同一账户共享客户端(代理变更时重建) // - account_proxy: 按账户+代理组合隔离,最细粒度 -func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) { - return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false) +func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry { + entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false) + return entry } // getClientEntry 获取或创建客户端条目 // markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰 // enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误 -func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { +func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { // 获取隔离模式 isolation := s.getIsolationMode() // 标准化代理 URL 并解析 - proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL) - if err != nil { - return nil, err - } + proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + // 根据请求 profile(例如 OpenAI)选择协议模式 + protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy) 
// 构建缓存键(根据隔离策略不同) - cacheKey := buildCacheKey(isolation, proxyKey, accountID) + cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode) // 构建连接池配置键(用于检测配置变更) - poolKey := s.buildPoolKey(isolation, accountConcurrency) + poolKey := s.buildPoolKey(isolation, accountConcurrency, protocolMode) now := time.Now() nowUnix := now.UnixNano() @@ -439,7 +475,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a // 缓存未命中或需要重建,创建新客户端 settings := s.resolvePoolSettings(isolation, accountConcurrency) - transport, err := buildUpstreamTransport(settings, parsedProxy) + transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode) if err != nil { s.mu.Unlock() return nil, fmt.Errorf("build transport: %w", err) @@ -449,9 +485,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a client.CheckRedirect = s.redirectChecker } entry := &upstreamClientEntry{ - client: client, - proxyKey: proxyKey, - poolKey: poolKey, + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + protocolMode: protocolMode, } atomic.StoreInt64(&entry.lastUsed, nowUnix) if markInFlight { @@ -644,13 +681,17 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu // // 返回: // - string: 配置键 -func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string { +func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int, protocolMode string) string { + base := "default" if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy { if accountConcurrency > 0 { - return fmt.Sprintf("account:%d", accountConcurrency) + base = fmt.Sprintf("account:%d", accountConcurrency) } } - return "default" + if protocolMode == "" || protocolMode == upstreamProtocolModeDefault { + return base + } + return base + "|proto:" + protocolMode } // buildCacheKey 构建客户端缓存键 @@ -668,15 +709,20 @@ func (s 
*httpUpstreamService) buildPoolKey(isolation string, accountConcurrency // - proxy 模式: "proxy:{proxyKey}" // - account 模式: "account:{accountID}" // - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}" -func buildCacheKey(isolation, proxyKey string, accountID int64) string { +func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string { + var base string switch isolation { case config.ConnectionPoolIsolationAccount: - return fmt.Sprintf("account:%d", accountID) + base = fmt.Sprintf("account:%d", accountID) case config.ConnectionPoolIsolationAccountProxy: - return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) + base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey) default: - return fmt.Sprintf("proxy:%s", proxyKey) + base = fmt.Sprintf("proxy:%s", proxyKey) + } + if protocolMode != "" && protocolMode != upstreamProtocolModeDefault { + base += "|proto:" + protocolMode } + return base } // normalizeProxyURL 标准化代理 URL @@ -686,18 +732,17 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string { // - raw: 原始代理 URL 字符串 // // 返回: -// - string: 标准化的代理键(空返回 "direct") -// - *url.URL: 解析后的 URL(空返回 nil) -// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连) -func normalizeProxyURL(raw string) (string, *url.URL, error) { - _, parsed, err := proxyurl.Parse(raw) +// - string: 标准化的代理键(空或解析失败返回 "direct") +// - *url.URL: 解析后的 URL(空或解析失败返回 nil) +func normalizeProxyURL(raw string) (string, *url.URL) { + proxyURL := strings.TrimSpace(raw) + if proxyURL == "" { + return directProxyKey, nil + } + parsed, err := url.Parse(proxyURL) if err != nil { - return "", nil, err - } - if parsed == nil { - return directProxyKey, nil, nil + return directProxyKey, nil } - // 规范化:小写 scheme/host,去除路径和查询参数 parsed.Scheme = strings.ToLower(parsed.Scheme) parsed.Host = strings.ToLower(parsed.Host) parsed.Path = "" @@ -717,7 +762,200 @@ func normalizeProxyURL(raw string) (string, *url.URL, error) { parsed.Host = hostname } } - return parsed.String(), 
parsed, nil + return parsed.String(), parsed +} + +func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings { + settings := openAIHTTP2Settings{ + enabled: true, + allowProxyFallbackToHTTP1: true, + fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold, + fallbackWindow: defaultOpenAIHTTP2FallbackWindow, + fallbackTTL: defaultOpenAIHTTP2FallbackTTL, + } + if s == nil || s.cfg == nil { + return settings + } + cfg := s.cfg.Gateway.OpenAIHTTP2 + settings.enabled = cfg.Enabled + settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1 + if cfg.FallbackErrorThreshold > 0 { + settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold + } + if cfg.FallbackWindowSeconds > 0 { + settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second + } + if cfg.FallbackTTLSeconds > 0 { + settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second + } + return settings +} + +func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string { + if profile != service.HTTPUpstreamProfileOpenAI { + return upstreamProtocolModeDefault + } + settings := s.resolveOpenAIHTTP2Settings() + if !settings.enabled { + return upstreamProtocolModeDefault + } + if parsedProxy == nil { + return upstreamProtocolModeOpenAIH2 + } + scheme := strings.ToLower(parsedProxy.Scheme) + if scheme != "http" && scheme != "https" { + return upstreamProtocolModeOpenAIH2 + } + if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) { + return upstreamProtocolModeOpenAIH1Fallback + } + return upstreamProtocolModeOpenAIH2 +} + +func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool { + raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey) + if !ok { + return false + } + state, ok := raw.(*openAIHTTP2FallbackState) + if !ok || state == nil { + return false + } + return state.isFallbackActive(time.Now()) +} + +func (s 
*httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState { + state := &openAIHTTP2FallbackState{} + actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state) + cached, ok := actual.(*openAIHTTP2FallbackState) + if !ok || cached == nil { + return state + } + return cached +} + +func isHTTPProxyKey(proxyKey string) bool { + return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://") +} + +func isOpenAIHTTP2CompatibilityError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + if msg == "" { + return false + } + markers := []string{ + "http2", + "alpn", + "no application protocol", + "protocol error", + "stream error", + "goaway", + "refused_stream", + "frame too large", + } + for _, marker := range markers { + if strings.Contains(msg, marker) { + return true + } + } + return false +} + +func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) { + if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { + return + } + settings := s.resolveOpenAIHTTP2Settings() + if !settings.enabled || !settings.allowProxyFallbackToHTTP1 { + return + } + if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) { + return + } + state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey) + activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL) + if activated { + slog.Warn("openai_http2_proxy_fallback_activated", + "proxy", proxyKey, + "fallback_until", until.Format(time.RFC3339)) + } +} + +func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) { + if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { + return + } + if !isHTTPProxyKey(proxyKey) { + return 
+ } + raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey) + if !ok { + return + } + state, ok := raw.(*openAIHTTP2FallbackState) + if !ok || state == nil { + return + } + state.resetErrorWindow() +} + +func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.fallbackUntil.IsZero() { + return false + } + if now.Before(s.fallbackUntil) { + return true + } + s.fallbackUntil = time.Time{} + return false +} + +func (s *openAIHTTP2FallbackState) resetErrorWindow() { + s.mu.Lock() + defer s.mu.Unlock() + s.windowStart = time.Time{} + s.errorCount = 0 +} + +func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) { + if threshold <= 0 { + threshold = defaultOpenAIHTTP2FallbackErrorThreshold + } + if window <= 0 { + window = defaultOpenAIHTTP2FallbackWindow + } + if ttl <= 0 { + ttl = defaultOpenAIHTTP2FallbackTTL + } + + s.mu.Lock() + defer s.mu.Unlock() + + if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) { + return false, s.fallbackUntil + } + if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) { + s.fallbackUntil = time.Time{} + } + + if s.windowStart.IsZero() || now.Sub(s.windowStart) > window { + s.windowStart = now + s.errorCount = 0 + } + s.errorCount++ + if s.errorCount < threshold { + return false, time.Time{} + } + + s.fallbackUntil = now.Add(ttl) + s.windowStart = time.Time{} + s.errorCount = 0 + return true, s.fallbackUntil } // defaultPoolSettings 获取默认连接池配置 @@ -779,7 +1017,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings { // - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待) // - IdleConnTimeout: 空闲连接超时(超时后关闭) // - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输) -func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) { +func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) { transport := &http.Transport{ MaxIdleConns: 
settings.maxIdleConns, MaxIdleConnsPerHost: settings.maxIdleConnsPerHost, @@ -787,6 +1025,14 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra IdleConnTimeout: settings.idleConnTimeout, ResponseHeaderTimeout: settings.responseHeaderTimeout, } + switch protocolMode { + case upstreamProtocolModeOpenAIH2: + transport.ForceAttemptHTTP2 = true + case upstreamProtocolModeOpenAIH1Fallback: + // 显式禁用 HTTP/2,确保代理不兼容场景回退到 HTTP/1.1。 + transport.ForceAttemptHTTP2 = false + transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper) + } if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil { return nil, err } diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go index 89892b3b6..ebeee640e 100644 --- a/backend/internal/repository/http_upstream_benchmark_test.go +++ b/backend/internal/repository/http_upstream_benchmark_test.go @@ -45,7 +45,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { settings := defaultPoolSettings(cfg) for i := 0; i < b.N; i++ { // 每次迭代都创建新客户端,包含 Transport 分配 - transport, err := buildUpstreamTransport(settings, parsedProxy) + transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeDefault) if err != nil { b.Fatalf("创建 Transport 失败: %v", err) } @@ -59,10 +59,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) { // 模拟优化后的行为,从缓存获取客户端 b.Run("复用", func(b *testing.B) { // 预热:确保客户端已缓存 - entry, err := svc.getOrCreateClient(proxyURL, 1, 1) - if err != nil { - b.Fatalf("getOrCreateClient: %v", err) - } + entry := svc.getOrCreateClient(proxyURL, 1, 1) client := entry.client b.ResetTimer() // 重置计时器,排除预热时间 for i := 0; i < b.N; i++ { diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go index b3268463a..35ee9adde 100644 --- a/backend/internal/repository/http_upstream_test.go +++ 
b/backend/internal/repository/http_upstream_test.go @@ -1,13 +1,19 @@ package repository import ( + "context" + "errors" "io" "net/http" + "net/url" + "strings" "sync/atomic" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" + "github.com/Wei-Shaw/sub2api/internal/service" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" ) @@ -44,7 +50,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService { // 验证未配置时使用 300 秒默认值 func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { svc := s.newService() - entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) + entry := svc.getOrCreateClient("", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") @@ -55,27 +61,25 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() { func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() { s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7} svc := s.newService() - entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0) + entry := svc.getOrCreateClient("", 0, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch") } -// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误 -// 验证解析失败时拒绝回退到直连模式 -func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() { +// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退 +// 验证解析失败时回退到直连模式 +func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() { svc := s.newService() - _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false) - require.Error(s.T(), err, "expected error for invalid proxy URL") + entry := 
svc.getOrCreateClient("://bad-proxy-url", 1, 1) + require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback") } // TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化 // 验证等价地址能够映射到同一缓存键 func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() { - key1, _, err1 := normalizeProxyURL("http://proxy.local:8080") - require.NoError(s.T(), err1) - key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/") - require.NoError(s.T(), err2) + key1, _ := normalizeProxyURL("http://proxy.local:8080") + key2, _ := normalizeProxyURL("http://proxy.local:8080/") require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match") } @@ -116,6 +120,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() { require.Equal(s.T(), "direct", string(b), "unexpected body") } +func (s *HTTPUpstreamSuite) TestDo_RequestErrorPath() { + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/unreachable", nil) + require.NoError(s.T(), err) + + resp, doErr := svc.Do(req, "", 1, 1) + require.Nil(s.T(), resp) + require.Error(s.T(), doErr) +} + // TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能 // 验证请求通过代理服务器转发,使用绝对 URI 格式 func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() { @@ -173,8 +187,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一代理,不同账户 - entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3) - entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3) + entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3) + entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3) require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池") require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端") } @@ -185,8 +199,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() { s.cfg.Gateway = 
config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy} svc := s.newService() // 同一账户,不同代理 - entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) - entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) + entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) + entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理") require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端") } @@ -197,8 +211,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 同一账户,先后使用不同代理 - entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3) - entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3) + entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3) + entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3) require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池") require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池") require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理") @@ -210,7 +224,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() { s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount} svc := s.newService() // 账户并发数为 12 - entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12) + entry := svc.getOrCreateClient("", 1, 12) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") // 连接池参数应与并发数一致 @@ -230,7 +244,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() { } svc := s.newService() // 账户并发数为 0,应使用全局配置 - entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0) + entry := svc.getOrCreateClient("", 1, 0) transport, ok := entry.client.Transport.(*http.Transport) require.True(s.T(), ok, "expected *http.Transport") 
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch") @@ -247,12 +261,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() { } svc := s.newService() // 创建两个客户端,设置不同的最后使用时间 - entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1) - entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1) + entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1) + entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1) atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久 atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano()) // 创建第三个客户端,触发淘汰 - _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1) + _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1) require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内") require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理") @@ -266,29 +280,446 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() { ClientIdleTTLSeconds: 1, // 1 秒空闲超时 } svc := s.newService() - entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1) + entry1 := svc.getOrCreateClient("", 1, 1) // 设置为很久之前使用,但有活跃请求 atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano()) atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求 // 创建新客户端,触发淘汰检查 - _, _ = svc.getOrCreateClient("", 2, 1) + _ = svc.getOrCreateClient("", 2, 1) require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收") } +func (s *HTTPUpstreamSuite) TestOpenAIProfile_UsesHTTP2TransportForHTTPProxy() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + + entry, err := svc.getClientEntry("http://proxy.local:8080", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + require.Equal(s.T(), 
upstreamProtocolModeOpenAIH2, entry.protocolMode) + + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.True(s.T(), transport.ForceAttemptHTTP2, "OpenAI profile should prefer HTTP/2") + require.Nil(s.T(), transport.TLSNextProto, "HTTP/2 mode should not force-disable TLSNextProto") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfile_FallbackToHTTP11WhenProxyMarkedIncompatible() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + + state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL) + state.mu.Lock() + state.fallbackUntil = time.Now().Add(3 * time.Minute) + state.mu.Unlock() + + entry, err := svc.getClientEntry(proxyURL, 1, 1, service.HTTPUpstreamProfileOpenAI, false, false) + require.NoError(s.T(), err) + require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, entry.protocolMode) + + transport, ok := entry.client.Transport.(*http.Transport) + require.True(s.T(), ok, "expected *http.Transport") + require.False(s.T(), transport.ForceAttemptHTTP2, "fallback mode must disable HTTP/2 force-attempt") + require.NotNil(s.T(), transport.TLSNextProto, "fallback mode must disable HTTP/2 negotiation") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordHTTP2ErrorActivatesFallback() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + h2Err := errors.New("http2: stream error") + + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err) + 
require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "first error should not activate fallback") + + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err) + require.True(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "second error in window should activate fallback") +} + +func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordNonHTTP2ErrorDoesNotActivateFallback() { + s.cfg.Gateway = config.GatewayConfig{ + OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 1, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + }, + } + svc := s.newService() + proxyURL := "http://proxy.local:8080" + + svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, errors.New("dial tcp: i/o timeout")) + require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL)) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_DisabledDelegatesToDo() { + upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "ok") + })) + s.T().Cleanup(upstream.Close) + + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-disabled", nil) + require.NoError(s.T(), err) + + resp, err := svc.DoWithTLS(req, "", 1, 1, false) + require.NoError(s.T(), err) + defer func() { _ = resp.Body.Close() }() + body, readErr := io.ReadAll(resp.Body) + require.NoError(s.T(), readErr) + require.Equal(s.T(), "ok", string(body)) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledHTTPRequestSuccess() { + upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, "tls-enabled") + })) + s.T().Cleanup(upstream.Close) + + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-enabled", nil) + require.NoError(s.T(), err) + + resp, err := 
svc.DoWithTLS(req, "", 9, 1, true) + require.NoError(s.T(), err) + defer func() { _ = resp.Body.Close() }() + body, readErr := io.ReadAll(resp.Body) + require.NoError(s.T(), readErr) + require.Equal(s.T(), "tls-enabled", string(body)) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledRequestError() { + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/tls-error", nil) + require.NoError(s.T(), err) + + resp, doErr := svc.DoWithTLS(req, "", 9, 1, true) + require.Nil(s.T(), resp) + require.Error(s.T(), doErr) +} + +func (s *HTTPUpstreamSuite) TestDoWithTLS_ValidateRequestHostFailure() { + s.cfg.Security.URLAllowlist.Enabled = true + s.cfg.Security.URLAllowlist.AllowPrivateHosts = false + svc := s.newService() + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/test", nil) + require.NoError(s.T(), err) + + resp, doErr := svc.DoWithTLS(req, "", 1, 1, true) + require.Nil(s.T(), resp) + require.Error(s.T(), doErr) +} + +func (s *HTTPUpstreamSuite) TestShouldValidateResolvedIPAndValidateRequestHost() { + svc := s.newService() + require.False(s.T(), svc.shouldValidateResolvedIP()) + require.NoError(s.T(), svc.validateRequestHost(nil)) + + s.cfg.Security.URLAllowlist.Enabled = true + s.cfg.Security.URLAllowlist.AllowPrivateHosts = false + require.True(s.T(), svc.shouldValidateResolvedIP()) + require.Error(s.T(), svc.validateRequestHost(nil)) + + req, err := http.NewRequest(http.MethodGet, "http:///nohost", nil) + require.NoError(s.T(), err) + require.Error(s.T(), svc.validateRequestHost(req)) +} + +func (s *HTTPUpstreamSuite) TestRedirectCheckerStopsAfterLimit() { + svc := s.newService() + req, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + require.NoError(s.T(), err) + + via := make([]*http.Request, 10) + require.Error(s.T(), svc.redirectChecker(req, via)) +} + +func (s *HTTPUpstreamSuite) TestRedirectCheckerValidatesRequestHost() { + s.cfg.Security.URLAllowlist.Enabled = true + 
s.cfg.Security.URLAllowlist.AllowPrivateHosts = false + svc := s.newService() + + req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil) + require.NoError(s.T(), err) + require.Error(s.T(), svc.redirectChecker(req, nil)) +} + +func (s *HTTPUpstreamSuite) TestShouldReuseEntryAndEvictBranches() { + svc := s.newService() + entry := &upstreamClientEntry{ + proxyKey: "proxy-a", + poolKey: "pool-a", + } + require.False(s.T(), svc.shouldReuseEntry(nil, config.ConnectionPoolIsolationAccount, "proxy-a", "pool-a")) + require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationAccount, "proxy-b", "pool-a")) + require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-a", "pool-b")) + require.True(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-x", "pool-a")) + + s.cfg.Gateway.MaxUpstreamClients = 2 + svc.clients["k1"] = &upstreamClientEntry{inFlight: 1} + svc.clients["k2"] = &upstreamClientEntry{inFlight: 1} + require.False(s.T(), svc.evictOldestIdleLocked()) + require.False(s.T(), svc.evictOverLimitLocked()) +} + +func (s *HTTPUpstreamSuite) TestBuildCacheKeyAndIsolationMode() { + svc := s.newService() + require.Equal(s.T(), "account:1", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, "")) + require.Equal(s.T(), "account:2|proxy:px", buildCacheKey(config.ConnectionPoolIsolationAccountProxy, "px", 2, "")) + require.Equal(s.T(), "proxy:direct", buildCacheKey(config.ConnectionPoolIsolationProxy, "direct", 3, "")) + require.Equal(s.T(), "account:1|proto:openai_h2", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, "openai_h2")) + + s.cfg.Gateway.ConnectionPoolIsolation = "invalid" + require.Equal(s.T(), config.ConnectionPoolIsolationAccountProxy, svc.getIsolationMode()) + s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationProxy + require.Equal(s.T(), config.ConnectionPoolIsolationProxy, svc.getIsolationMode()) +} + +func (s 
*HTTPUpstreamSuite) TestResolveProtocolModeAndSettingsBranches() { + svc := s.newService() + s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + } + parsedHTTPProxy, err := url.Parse("http://proxy.local:8080") + require.NoError(s.T(), err) + parsedSOCKSProxy, err := url.Parse("socks5://proxy.local:1080") + require.NoError(s.T(), err) + + require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileDefault, "direct", nil)) + require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "direct", nil)) + require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "socks5://proxy.local:1080", parsedSOCKSProxy)) + + state := svc.getOrCreateOpenAIHTTP2FallbackState("http://proxy.local:8080") + state.mu.Lock() + state.fallbackUntil = time.Now().Add(10 * time.Second) + state.mu.Unlock() + require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy)) + + s.cfg.Gateway.OpenAIHTTP2.Enabled = false + require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy)) +} + +func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_ReusesAndRebuildsOnProxyChange() { + s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccount + svc := s.newService() + profile := &tlsfingerprint.Profile{Name: "tls-profile"} + + entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false) + require.NoError(s.T(), err) + entry2, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false) + require.NoError(s.T(), err) + require.Same(s.T(), entry1, 
entry2) + + entry3, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 1, 1, profile, false, false) + require.NoError(s.T(), err) + require.NotSame(s.T(), entry1, entry3) +} + +func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_OverLimitReturnsError() { + s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccountProxy + s.cfg.Gateway.MaxUpstreamClients = 1 + svc := s.newService() + profile := &tlsfingerprint.Profile{Name: "tls-profile"} + + entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, true, true) + require.NoError(s.T(), err) + require.NotNil(s.T(), entry1) + + entry2, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 2, 1, profile, true, true) + require.ErrorIs(s.T(), err, errUpstreamClientLimitReached) + require.Nil(s.T(), entry2) +} + +func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateHelpers() { + var state openAIHTTP2FallbackState + now := time.Now() + + active, until := state.recordFailure(now, 1, time.Minute, time.Minute) + require.True(s.T(), active) + require.False(s.T(), until.IsZero()) + require.True(s.T(), state.isFallbackActive(now)) + require.False(s.T(), state.isFallbackActive(now.Add(2*time.Minute))) + + state.recordFailure(now, 3, time.Minute, time.Minute) + state.recordFailure(now.Add(10*time.Second), 3, time.Minute, time.Minute) + state.resetErrorWindow() + require.Equal(s.T(), 0, state.errorCount) + require.True(s.T(), state.windowStart.IsZero()) + + // 在 fallback 活跃期间再次失败,不应重复激活。 + state.fallbackUntil = now.Add(time.Minute) + activated, _ := state.recordFailure(now.Add(5*time.Second), 1, time.Minute, time.Minute) + require.False(s.T(), activated) +} + +func (s *HTTPUpstreamSuite) TestRecordOpenAIHTTP2SuccessResetsWindow() { + svc := s.newService() + proxyURL := "http://proxy.local:8080" + state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL) + state.mu.Lock() + state.errorCount = 5 + state.windowStart = time.Now() + state.mu.Unlock() + + 
svc.recordOpenAIHTTP2Success(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL) + + state.mu.Lock() + defer state.mu.Unlock() + require.Equal(s.T(), 0, state.errorCount) + require.True(s.T(), state.windowStart.IsZero()) +} + +func (s *HTTPUpstreamSuite) TestDo_OpenAIProxySuccessResetsHTTP2ErrorWindow() { + seen := make(chan struct{}, 1) + proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case seen <- struct{}{}: + default: + } + _, _ = io.WriteString(w, "proxied") + })) + s.T().Cleanup(proxySrv.Close) + + s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{ + Enabled: true, + AllowProxyFallbackToHTTP1: true, + FallbackErrorThreshold: 2, + FallbackWindowSeconds: 60, + FallbackTTLSeconds: 600, + } + svc := s.newService() + proxyKey, _ := normalizeProxyURL(proxySrv.URL) + state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyKey) + state.mu.Lock() + state.windowStart = time.Now() + state.errorCount = 3 + state.fallbackUntil = time.Time{} + state.mu.Unlock() + + req, err := http.NewRequest(http.MethodGet, "http://example.com/reset-window", nil) + require.NoError(s.T(), err) + req = req.WithContext(service.WithHTTPUpstreamProfile(context.Background(), service.HTTPUpstreamProfileOpenAI)) + resp, doErr := svc.Do(req, proxySrv.URL, 1, 1) + require.NoError(s.T(), doErr) + defer func() { _ = resp.Body.Close() }() + _, _ = io.ReadAll(resp.Body) + + select { + case <-seen: + default: + require.Fail(s.T(), "expected proxy to receive request") + } + + state.mu.Lock() + defer state.mu.Unlock() + require.Equal(s.T(), 0, state.errorCount) + require.True(s.T(), state.windowStart.IsZero()) +} + +func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateMapTypeSafety() { + svc := s.newService() + svc.openAIHTTP2Fallbacks.Store("x", "bad-type") + require.False(s.T(), svc.isOpenAIHTTP2FallbackActive("x")) + state := svc.getOrCreateOpenAIHTTP2FallbackState("x") + require.NotNil(s.T(), state) +} + 
+func (s *HTTPUpstreamSuite) TestBuildUpstreamTransport_ModeSwitchingAndProxyErrors() { + settings := defaultPoolSettings(s.cfg) + parsedProxy, err := url.Parse("http://proxy.local:8080") + require.NoError(s.T(), err) + + h2Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH2) + require.NoError(s.T(), err) + require.True(s.T(), h2Transport.ForceAttemptHTTP2) + + h1Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH1Fallback) + require.NoError(s.T(), err) + require.False(s.T(), h1Transport.ForceAttemptHTTP2) + require.NotNil(s.T(), h1Transport.TLSNextProto) + + badProxy, err := url.Parse("ftp://proxy.local:21") + require.NoError(s.T(), err) + _, badErr := buildUpstreamTransport(settings, badProxy, upstreamProtocolModeDefault) + require.Error(s.T(), badErr) +} + +func (s *HTTPUpstreamSuite) TestBuildUpstreamTransportWithTLSFingerprintBranches() { + settings := defaultPoolSettings(s.cfg) + profile := &tlsfingerprint.Profile{Name: "test-profile"} + + transportDirect, err := buildUpstreamTransportWithTLSFingerprint(settings, nil, profile) + require.NoError(s.T(), err) + require.NotNil(s.T(), transportDirect.DialTLSContext) + + httpProxy, err := url.Parse("http://proxy.local:8080") + require.NoError(s.T(), err) + transportHTTPProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, httpProxy, profile) + require.NoError(s.T(), err) + require.NotNil(s.T(), transportHTTPProxy.DialTLSContext) + + socksProxy, err := url.Parse("socks5://proxy.local:1080") + require.NoError(s.T(), err) + transportSOCKSProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, socksProxy, profile) + require.NoError(s.T(), err) + require.NotNil(s.T(), transportSOCKSProxy.DialTLSContext) + + unsupportedProxy, err := url.Parse("ftp://proxy.local:21") + require.NoError(s.T(), err) + _, unsupportedErr := buildUpstreamTransportWithTLSFingerprint(settings, unsupportedProxy, profile) + require.Error(s.T(), 
unsupportedErr) +} + +func (s *HTTPUpstreamSuite) TestWrapTrackedBody_NilAndCloseOnce() { + require.Nil(s.T(), wrapTrackedBody(nil, nil)) + + closed := int32(0) + readCloser := io.NopCloser(strings.NewReader("x")) + wrapped := wrapTrackedBody(readCloser, func() { + atomic.AddInt32(&closed, 1) + }) + require.NotNil(s.T(), wrapped) + _ = wrapped.Close() + _ = wrapped.Close() + require.Equal(s.T(), int32(1), atomic.LoadInt32(&closed)) +} + // TestHTTPUpstreamSuite 运行测试套件 func TestHTTPUpstreamSuite(t *testing.T) { suite.Run(t, new(HTTPUpstreamSuite)) } -// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误 -func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry { - t.Helper() - entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency) - require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency) - return entry -} - // hasEntry 检查客户端是否存在于缓存中 // 辅助函数,用于验证淘汰逻辑 func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool { diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go index 6152dd7a5..c49865479 100644 --- a/backend/internal/repository/identity_cache.go +++ b/backend/internal/repository/identity_cache.go @@ -12,7 +12,7 @@ import ( const ( fingerprintKeyPrefix = "fingerprint:" - fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期 + fingerprintTTL = 24 * time.Hour maskedSessionKeyPrefix = "masked_session:" maskedSessionTTL = 15 * time.Minute ) diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index dca0b612f..3e155971b 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -23,10 +23,7 @@ type openaiOAuthService struct { } func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, 
proxyURL, clientID string) (*openai.TokenResponse, error) { - client, err := createOpenAIReqClient(proxyURL) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) - } + client := createOpenAIReqClient(proxyURL) if redirectURI == "" { redirectURI = openai.DefaultRedirectURI @@ -77,10 +74,7 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre } func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { - client, err := createOpenAIReqClient(proxyURL) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err) - } + client := createOpenAIReqClient(proxyURL) formData := url.Values{} formData.Set("grant_type", "refresh_token") @@ -108,7 +102,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre return &tokenResp, nil } -func createOpenAIReqClient(proxyURL string) (*req.Client, error) { +func createOpenAIReqClient(proxyURL string) *req.Client { return getSharedReqClient(reqClientOptions{ ProxyURL: proxyURL, Timeout: 120 * time.Second, diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go index ee8e1749f..07d796b8c 100644 --- a/backend/internal/repository/pricing_service.go +++ b/backend/internal/repository/pricing_service.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "io" - "log/slog" "net/http" "strings" "time" @@ -17,37 +16,14 @@ type pricingRemoteClient struct { httpClient *http.Client } -// pricingRemoteClientError 代理初始化失败时的错误占位客户端 -// 所有请求直接返回初始化错误,禁止回退到直连 -type pricingRemoteClientError struct { - err error -} - -func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) { - return nil, c.err -} - -func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ 
string) (string, error) { - return "", c.err -} - // NewPricingRemoteClient 创建定价数据远程客户端 // proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议 -// 代理配置失败时行为由 allowDirectOnProxyError 控制: -// - false(默认):返回错误占位客户端,禁止回退到直连 -// - true:回退到直连(仅限管理员显式开启) -func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient { - // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据, - // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。 +func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient { sharedClient, err := httpclient.GetClient(httpclient.Options{ Timeout: 30 * time.Second, ProxyURL: proxyURL, }) if err != nil { - if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError { - slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err) - return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)} - } sharedClient = &http.Client{Timeout: 30 * time.Second} } return &pricingRemoteClient{ diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go index ef2f214b0..6ea112117 100644 --- a/backend/internal/repository/pricing_service_test.go +++ b/backend/internal/repository/pricing_service_test.go @@ -19,7 +19,7 @@ type PricingServiceSuite struct { func (s *PricingServiceSuite) SetupTest() { s.ctx = context.Background() - client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient) + client, ok := NewPricingRemoteClient("").(*pricingRemoteClient) require.True(s.T(), ok, "type assertion failed") s.client = client } @@ -140,22 +140,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() { require.Error(s.T(), err) } -func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) { - client := NewPricingRemoteClient("://bad", false) - _, ok := 
client.(*pricingRemoteClientError) - require.True(t, ok, "should return error client when proxy is invalid and fallback disabled") - - _, err := client.FetchPricingJSON(context.Background(), "http://example.com") - require.Error(t, err) - require.Contains(t, err.Error(), "proxy client init failed") -} - -func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) { - client := NewPricingRemoteClient("://bad", true) - _, ok := client.(*pricingRemoteClient) - require.True(t, ok, "should fallback to direct client when allowed") -} - func TestPricingServiceSuite(t *testing.T) { suite.Run(t, new(PricingServiceSuite)) } diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go index b4aeab718..54de28972 100644 --- a/backend/internal/repository/proxy_probe_service.go +++ b/backend/internal/repository/proxy_probe_service.go @@ -66,6 +66,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s ProxyURL: proxyURL, Timeout: defaultProxyProbeTimeout, InsecureSkipVerify: s.insecureSkipVerify, + ProxyStrict: true, ValidateResolvedIP: s.validateResolvedIP, AllowPrivateHosts: s.allowPrivateHosts, }) diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go index 79b24396d..af71a7ee3 100644 --- a/backend/internal/repository/req_client_pool.go +++ b/backend/internal/repository/req_client_pool.go @@ -6,8 +6,6 @@ import ( "sync" "time" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" - "github.com/imroc/req/v3" ) @@ -35,11 +33,11 @@ var sharedReqClients sync.Map // getSharedReqClient 获取共享的 req 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 -func getSharedReqClient(opts reqClientOptions) (*req.Client, error) { +func getSharedReqClient(opts reqClientOptions) *req.Client { key := buildReqClientKey(opts) if cached, ok := sharedReqClients.Load(key); ok { if c, ok := cached.(*req.Client); ok { - return c, nil + return c } } @@ -50,19 +48,15 @@ func 
getSharedReqClient(opts reqClientOptions) (*req.Client, error) { if opts.Impersonate { client = client.ImpersonateChrome() } - trimmed, _, err := proxyurl.Parse(opts.ProxyURL) - if err != nil { - return nil, err - } - if trimmed != "" { - client.SetProxyURL(trimmed) + if strings.TrimSpace(opts.ProxyURL) != "" { + client.SetProxyURL(strings.TrimSpace(opts.ProxyURL)) } actual, _ := sharedReqClients.LoadOrStore(key, client) if c, ok := actual.(*req.Client); ok { - return c, nil + return c } - return client, nil + return client } func buildReqClientKey(opts reqClientOptions) string { diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go index 9067d0129..904ed4d6e 100644 --- a/backend/internal/repository/req_client_pool_test.go +++ b/backend/internal/repository/req_client_pool_test.go @@ -26,13 +26,11 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: time.Second, } - clientDefault, err := getSharedReqClient(base) - require.NoError(t, err) + clientDefault := getSharedReqClient(base) force := base force.ForceHTTP2 = true - clientForce, err := getSharedReqClient(force) - require.NoError(t, err) + clientForce := getSharedReqClient(force) require.NotSame(t, clientDefault, clientForce) require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force)) @@ -44,10 +42,8 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) { ProxyURL: "http://proxy.local:8080", Timeout: 2 * time.Second, } - first, err := getSharedReqClient(opts) - require.NoError(t, err) - second, err := getSharedReqClient(opts) - require.NoError(t, err) + first := getSharedReqClient(opts) + second := getSharedReqClient(opts) require.Same(t, first, second) } @@ -60,8 +56,7 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) { key := buildReqClientKey(opts) sharedReqClients.Store(key, "invalid") - client, err := getSharedReqClient(opts) - 
require.NoError(t, err) + client := getSharedReqClient(opts) require.NotNil(t, client) loaded, ok := sharedReqClients.Load(key) @@ -76,45 +71,20 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) { Timeout: 4 * time.Second, Impersonate: true, } - client, err := getSharedReqClient(opts) - require.NoError(t, err) + client := getSharedReqClient(opts) require.NotNil(t, client) require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts)) } -func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) { - sharedReqClients = sync.Map{} - opts := reqClientOptions{ - ProxyURL: "://missing-scheme", - Timeout: time.Second, - } - _, err := getSharedReqClient(opts) - require.Error(t, err) - require.Contains(t, err.Error(), "invalid proxy URL") -} - -func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) { - sharedReqClients = sync.Map{} - opts := reqClientOptions{ - ProxyURL: "http://", - Timeout: time.Second, - } - _, err := getSharedReqClient(opts) - require.Error(t, err) - require.Contains(t, err.Error(), "proxy URL missing host") -} - func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) { sharedReqClients = sync.Map{} - client, err := createOpenAIReqClient("http://proxy.local:8080") - require.NoError(t, err) + client := createOpenAIReqClient("http://proxy.local:8080") require.Equal(t, 120*time.Second, client.GetClient().Timeout) } func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) { sharedReqClients = sync.Map{} - client, err := createGeminiReqClient("http://proxy.local:8080") - require.NoError(t, err) + client := createGeminiReqClient("http://proxy.local:8080") require.Equal(t, "", forceHTTPVersion(t, client)) } diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ff40e97d6..7306b0672 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -218,6 +218,275 @@ func (r *usageLogRepository) 
Create(ctx context.Context, log *service.UsageLog) return true, nil } +func (r *usageLogRepository) usageSQLExecutor(ctx context.Context) sqlExecutor { + if tx := dbent.TxFromContext(ctx); tx != nil { + return tx.Client() + } + return r.sql +} + +func (r *usageLogRepository) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if fn == nil { + return nil + } + if tx := dbent.TxFromContext(ctx); tx != nil { + return fn(ctx) + } + tx, err := r.client.Tx(ctx) + if err != nil { + return err + } + txCtx := dbent.NewTxContext(ctx, tx) + if err := fn(txCtx); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() +} + +func (r *usageLogRepository) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*service.UsageBillingEntry, error) { + query := ` + SELECT + id, + usage_log_id, + user_id, + api_key_id, + subscription_id, + billing_type, + applied, + delta_usd, + status, + attempt_count, + next_retry_at, + updated_at, + created_at, + last_error + FROM billing_usage_entries + WHERE usage_log_id = $1 + ` + rows, err := r.usageSQLExecutor(ctx).QueryContext(ctx, query, usageLogID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + if !rows.Next() { + if err = rows.Err(); err != nil { + return nil, err + } + return nil, service.ErrUsageBillingEntryNotFound + } + entry, err := scanUsageBillingEntry(rows) + if err != nil { + return nil, err + } + if err = rows.Err(); err != nil { + return nil, err + } + return entry, nil +} + +func (r *usageLogRepository) UpsertUsageBillingEntry(ctx context.Context, entry *service.UsageBillingEntry) (*service.UsageBillingEntry, bool, error) { + if entry == nil { + return nil, false, nil + } + + insertQuery := ` + INSERT INTO billing_usage_entries ( + usage_log_id, + user_id, + api_key_id, + subscription_id, + billing_type, + applied, + delta_usd, + status, + attempt_count, + next_retry_at, + updated_at + ) VALUES ( + $1, $2, $3, $4, $5, FALSE, 
$6, $7, 0, NOW(), NOW() + ) + ON CONFLICT (usage_log_id) DO NOTHING + RETURNING + id, + usage_log_id, + user_id, + api_key_id, + subscription_id, + billing_type, + applied, + delta_usd, + status, + attempt_count, + next_retry_at, + updated_at, + created_at, + last_error + ` + + exec := r.usageSQLExecutor(ctx) + rows, err := exec.QueryContext( + ctx, + insertQuery, + entry.UsageLogID, + entry.UserID, + entry.APIKeyID, + nullInt64(entry.SubscriptionID), + entry.BillingType, + entry.DeltaUSD, + service.UsageBillingEntryStatusPending, + ) + if err != nil { + return nil, false, err + } + defer func() { _ = rows.Close() }() + + if rows.Next() { + created, scanErr := scanUsageBillingEntry(rows) + if scanErr != nil { + return nil, false, scanErr + } + return created, true, nil + } + if err = rows.Err(); err != nil { + return nil, false, err + } + + existing, err := r.GetUsageBillingEntryByUsageLogID(ctx, entry.UsageLogID) + if err != nil { + return nil, false, err + } + return existing, false, nil +} + +func (r *usageLogRepository) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error { + query := ` + UPDATE billing_usage_entries + SET + applied = TRUE, + status = $2, + last_error = NULL, + next_retry_at = NOW(), + updated_at = NOW() + WHERE id = $1 + ` + res, err := r.usageSQLExecutor(ctx).ExecContext(ctx, query, entryID, service.UsageBillingEntryStatusApplied) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrUsageBillingEntryNotFound + } + return nil +} + +func (r *usageLogRepository) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error { + query := ` + UPDATE billing_usage_entries + SET + applied = FALSE, + status = $2, + next_retry_at = $3, + last_error = $4, + updated_at = NOW() + WHERE id = $1 + ` + res, err := r.usageSQLExecutor(ctx).ExecContext( + ctx, + query, + entryID, + 
service.UsageBillingEntryStatusPending, + nextRetryAt, + nullString(&lastError), + ) + if err != nil { + return err + } + affected, err := res.RowsAffected() + if err != nil { + return err + } + if affected == 0 { + return service.ErrUsageBillingEntryNotFound + } + return nil +} + +func (r *usageLogRepository) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]service.UsageBillingEntry, error) { + if limit <= 0 { + return nil, nil + } + staleAt := time.Now().Add(-processingStaleAfter) + query := ` + WITH candidates AS ( + SELECT id + FROM billing_usage_entries + WHERE applied = FALSE + AND ( + (status = $1 AND next_retry_at <= NOW()) + OR (status = $2 AND updated_at <= $3) + ) + ORDER BY id + LIMIT $4 + FOR UPDATE SKIP LOCKED + ) + UPDATE billing_usage_entries b + SET + status = $2, + attempt_count = b.attempt_count + 1, + updated_at = NOW(), + last_error = NULL + FROM candidates c + WHERE b.id = c.id + RETURNING + b.id, + b.usage_log_id, + b.user_id, + b.api_key_id, + b.subscription_id, + b.billing_type, + b.applied, + b.delta_usd, + b.status, + b.attempt_count, + b.next_retry_at, + b.updated_at, + b.created_at, + b.last_error + ` + + rows, err := r.usageSQLExecutor(ctx).QueryContext( + ctx, + query, + service.UsageBillingEntryStatusPending, + service.UsageBillingEntryStatusProcessing, + staleAt, + limit, + ) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + entries := make([]service.UsageBillingEntry, 0, limit) + for rows.Next() { + item, scanErr := scanUsageBillingEntry(rows) + if scanErr != nil { + return nil, scanErr + } + entries = append(entries, *item) + } + if err := rows.Err(); err != nil { + return nil, err + } + return entries, nil +} + func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) { query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1" rows, err := r.sql.QueryContext(ctx, query, id) @@ -1863,77 
+2132,6 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start return results, nil } -// GetGroupStatsWithFilters returns group usage statistics with optional filters -func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) { - query := ` - SELECT - COALESCE(ul.group_id, 0) as group_id, - COALESCE(g.name, '') as group_name, - COUNT(*) as requests, - COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, - COALESCE(SUM(ul.total_cost), 0) as cost, - COALESCE(SUM(ul.actual_cost), 0) as actual_cost - FROM usage_logs ul - LEFT JOIN groups g ON g.id = ul.group_id - WHERE ul.created_at >= $1 AND ul.created_at < $2 - ` - - args := []any{startTime, endTime} - if userID > 0 { - query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1) - args = append(args, userID) - } - if apiKeyID > 0 { - query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1) - args = append(args, apiKeyID) - } - if accountID > 0 { - query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1) - args = append(args, accountID) - } - if groupID > 0 { - query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1) - args = append(args, groupID) - } - query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) - if billingType != nil { - query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1) - args = append(args, int16(*billingType)) - } - query += " GROUP BY ul.group_id, g.name ORDER BY total_tokens DESC" - - rows, err := r.sql.QueryContext(ctx, query, args...) 
- if err != nil { - return nil, err - } - defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = closeErr - results = nil - } - }() - - results = make([]usagestats.GroupStat, 0) - for rows.Next() { - var row usagestats.GroupStat - if err := rows.Scan( - &row.GroupID, - &row.GroupName, - &row.Requests, - &row.TotalTokens, - &row.Cost, - &row.ActualCost, - ); err != nil { - return nil, err - } - results = append(results, row) - } - if err := rows.Err(); err != nil { - return nil, err - } - return results, nil -} - // GetGlobalStats gets usage statistics for all users within a time range func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) { query := ` @@ -2590,6 +2788,52 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e return log, nil } +func scanUsageBillingEntry(scanner interface{ Scan(...any) error }) (*service.UsageBillingEntry, error) { + var ( + subscriptionID sql.NullInt64 + nextRetryAt time.Time + updatedAt time.Time + createdAt time.Time + lastError sql.NullString + status int16 + entry service.UsageBillingEntry + ) + + if err := scanner.Scan( + &entry.ID, + &entry.UsageLogID, + &entry.UserID, + &entry.APIKeyID, + &subscriptionID, + &entry.BillingType, + &entry.Applied, + &entry.DeltaUSD, + &status, + &entry.AttemptCount, + &nextRetryAt, + &updatedAt, + &createdAt, + &lastError, + ); err != nil { + return nil, err + } + + if subscriptionID.Valid { + v := subscriptionID.Int64 + entry.SubscriptionID = &v + } + entry.Status = service.UsageBillingEntryStatus(status) + entry.NextRetryAt = nextRetryAt + entry.UpdatedAt = updatedAt + entry.CreatedAt = createdAt + if lastError.Valid { + msg := lastError.String + entry.LastError = &msg + } + + return &entry, nil +} + func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) { results := make([]TrendDataPoint, 0) for rows.Next() { diff --git a/backend/internal/repository/wire.go 
b/backend/internal/repository/wire.go index 2e35e0a00..eb8ce3fba 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient // ProvidePricingRemoteClient 创建定价数据远程客户端 // 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据 func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient { - return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError) + return NewPricingRemoteClient(cfg.Update.ProxyURL) } // ProvideSessionLimitCache 创建会话限制缓存 @@ -79,8 +79,6 @@ var ProviderSet = wire.NewSet( NewTimeoutCounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, - NewRPMCache, - NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, NewIdentityCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 63b6cf282..f5efd97fa 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -86,15 +86,6 @@ func TestAPIContracts(t *testing.T) { "last_used_at": null, "quota": 0, "quota_used": 0, - "rate_limit_5h": 0, - "rate_limit_1d": 0, - "rate_limit_7d": 0, - "usage_5h": 0, - "usage_1d": 0, - "usage_7d": 0, - "window_5h_start": null, - "window_1d_start": null, - "window_7d_start": null, "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" @@ -135,15 +126,6 @@ func TestAPIContracts(t *testing.T) { "last_used_at": null, "quota": 0, "quota_used": 0, - "rate_limit_5h": 0, - "rate_limit_1d": 0, - "rate_limit_7d": 0, - "usage_5h": 0, - "usage_1d": 0, - "usage_7d": 0, - "window_5h_start": null, - "window_1d_start": null, - "window_7d_start": null, "expires_at": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" @@ -517,7 +499,6 @@ func TestAPIContracts(t *testing.T) { "doc_url": "https://docs.example.com", "default_concurrency": 5, 
"default_balance": 1.25, - "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_antigravity": "gemini-2.5-pro", @@ -531,8 +512,7 @@ func TestAPIContracts(t *testing.T) { "hide_ccs_import_button": false, "purchase_subscription_enabled": false, "purchase_subscription_url": "", - "min_claude_code_version": "", - "custom_menu_items": [] + "custom_menu_items": null } }`, }, @@ -640,12 +620,12 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -1510,6 +1490,18 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun return 0, errors.New("not implemented") } +func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { + return errors.New("not implemented") +} + +func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { + return 
errors.New("not implemented") +} + +func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { + return nil, errors.New("not implemented") +} + func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error { key, ok := r.byID[id] if !ok { @@ -1524,16 +1516,6 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti return nil } -func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error { - return nil -} -func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error { - return nil -} -func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) { - return nil, nil -} - type stubUsageLogRepo struct { userLogs map[int64][]service.UsageLog } @@ -1610,10 +1592,6 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { - return nil, errors.New("not implemented") -} - func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go index 033a5b778..7640ab2ae 100644 --- a/backend/internal/server/middleware/admin_auth_test.go +++ b/backend/internal/server/middleware/admin_auth_test.go @@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) { gin.SetMode(gin.TestMode) cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}} - 
authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil) admin := &service.User{ ID: 1, diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go index f8839cfe8..bc3209584 100644 --- a/backend/internal/server/middleware/jwt_auth_test.go +++ b/backend/internal/server/middleware/jwt_auth_test.go @@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer cfg.JWT.AccessTokenExpireMinutes = 60 userRepo := &stubJWTUserRepo{users: users} - authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil) + authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) userSvc := service.NewUserService(userRepo, nil, nil) mw := NewJWTAuthMiddleware(authSvc, userSvc) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index c36c36a0a..4d991ea47 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -26,6 +26,9 @@ func RegisterAdminRoutes( // 分组管理 registerGroupRoutes(admin, h) + // API Key 管理 + registerAdminAPIKeyRoutes(admin, h) + // 账号管理 registerAccountRoutes(admin, h) @@ -75,16 +78,6 @@ func RegisterAdminRoutes( // 错误透传规则管理 registerErrorPassthroughRoutes(admin, h) - - // API Key 管理 - registerAdminAPIKeyRoutes(admin, h) - } -} - -func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { - apiKeys := admin.Group("/api-keys") - { - apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup) } } @@ -96,6 +89,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { ops.GET("/user-concurrency", h.Admin.Ops.GetUserConcurrencyStats) ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability) ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary) + 
ops.GET("/openai-ws-v2/passthrough-metrics", h.Admin.Ops.GetOpenAIWSV2PassthroughMetrics) // Alerts (rules + events) ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules) @@ -227,6 +221,13 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { } } +func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + apiKeys := admin.Group("/api-keys") + { + apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup) + } +} + func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts := admin.Group("/accounts") { @@ -386,6 +387,18 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // 批量编辑模板库(服务端共享) + adminSettings.GET("/bulk-edit-templates", h.Admin.Setting.ListBulkEditTemplates) + adminSettings.POST("/bulk-edit-templates", h.Admin.Setting.UpsertBulkEditTemplate) + adminSettings.DELETE("/bulk-edit-templates/:template_id", h.Admin.Setting.DeleteBulkEditTemplate) + adminSettings.GET( + "/bulk-edit-templates/:template_id/versions", + h.Admin.Setting.ListBulkEditTemplateVersions, + ) + adminSettings.POST( + "/bulk-edit-templates/:template_id/rollback", + h.Admin.Setting.RollbackBulkEditTemplate, + ) // Sora S3 存储配置 adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 81e91aeb0..8d5b281b0 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -853,19 +853,24 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { } const ( - OpenAIWSIngressModeOff = "off" - OpenAIWSIngressModeShared = "shared" - OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + 
OpenAIWSIngressModeDedicated = "dedicated" + OpenAIWSIngressModeCtxPool = "ctx_pool" + OpenAIWSIngressModePassthrough = "passthrough" ) func normalizeOpenAIWSIngressMode(mode string) string { switch strings.ToLower(strings.TrimSpace(mode)) { case OpenAIWSIngressModeOff: return OpenAIWSIngressModeOff - case OpenAIWSIngressModeShared: - return OpenAIWSIngressModeShared - case OpenAIWSIngressModeDedicated: - return OpenAIWSIngressModeDedicated + case OpenAIWSIngressModeCtxPool: + return OpenAIWSIngressModeCtxPool + case OpenAIWSIngressModePassthrough: + return OpenAIWSIngressModePassthrough + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // Deprecated: shared/dedicated 已废弃,平滑迁移到 ctx_pool + return OpenAIWSIngressModeCtxPool default: return "" } @@ -875,16 +880,16 @@ func normalizeOpenAIWSIngressDefaultMode(mode string) string { if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { return normalized } - return OpenAIWSIngressModeShared + return OpenAIWSIngressModeOff } -// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。 // // 优先级: // 1. 分类型 mode 新字段(string) // 2. 分类型 enabled 旧字段(bool) // 3. 兼容 enabled 旧字段(bool) -// 4. defaultMode(非法时回退 shared) +// 4. 
defaultMode(非法时回退 off) func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) if a == nil || !a.IsOpenAI() { @@ -919,7 +924,8 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri return "", false } if enabled { - return OpenAIWSIngressModeShared, true + // 兼容旧 enabled 字段:开启时至少落到 ctx_pool。 + return OpenAIWSIngressModeCtxPool, true } return OpenAIWSIngressModeOff, true } @@ -1295,12 +1301,6 @@ func parseExtraFloat64(value any) float64 { } // parseExtraInt 从 extra 字段解析 int 值 -// ParseExtraInt 从 extra 字段的 any 值解析为 int。 -// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。 -func ParseExtraInt(value any) int { - return parseExtraInt(value) -} - func parseExtraInt(value any) int { switch v := value.(type) { case int: diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index a85c68ec5..2fe62483c 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -206,30 +206,63 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { } func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { - t.Run("default fallback to shared", func(t *testing.T) { + t.Run("default fallback to off", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("")) - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) }) - t.Run("oauth mode field has highest priority", func(t 
*testing.T) { + t.Run("unsupported mode field falls back to enabled flag", func(t *testing.T) { account := &Account{ Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, - "openai_oauth_responses_websockets_v2_enabled": false, + "openai_oauth_responses_websockets_v2_enabled": true, "responses_websockets_v2_enabled": false, }, } - require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("ctx_pool mode field is recognized", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, + }, + } + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("passthrough mode field is recognized", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("legacy enabled maps to ctx_pool when default is off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) }) - t.Run("legacy enabled maps to shared", func(t *testing.T) { + t.Run("legacy enabled ignores unsupported default and maps to ctx_pool", func(t *testing.T) { account := &Account{ Platform: 
PlatformOpenAI, Type: AccountTypeAPIKey, @@ -237,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated)) }) t.Run("legacy disabled maps to off", func(t *testing.T) { @@ -249,7 +282,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "responses_websockets_v2_enabled": true, }, } - require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) t.Run("non openai always off", func(t *testing.T) { @@ -260,7 +293,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, }, } - require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated)) + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool)) }) } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index 18a70c5cc..964261ddc 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -73,15 +73,16 @@ type AccountRepository interface { // AccountBulkUpdate describes the fields that can be updated in a bulk operation. // Nil pointers mean "do not change". 
type AccountBulkUpdate struct { - Name *string - ProxyID *int64 - Concurrency *int - Priority *int - RateMultiplier *float64 - Status *string - Schedulable *bool - Credentials map[string]any - Extra map[string]any + Name *string + ProxyID *int64 + Concurrency *int + Priority *int + RateMultiplier *float64 + Status *string + Schedulable *bool + AutoPauseOnExpired *bool + Credentials map[string]any + Extra map[string]any } // CreateAccountRequest 创建账号请求 diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 6dee6c133..13a138567 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -37,7 +37,6 @@ type UsageLogRepository interface { GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) - GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) diff --git a/backend/internal/service/admin_service.go 
b/backend/internal/service/admin_service.go index 7e6982d36..6f7c3c05f 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -43,8 +43,6 @@ type AdminService interface { DeleteGroup(ctx context.Context, id int64) error GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error - - // API Key management (admin) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) // Account management @@ -224,17 +222,18 @@ type UpdateAccountInput struct { // BulkUpdateAccountsInput describes the payload for bulk updating accounts. type BulkUpdateAccountsInput struct { - AccountIDs []int64 - Name string - ProxyID *int64 - Concurrency *int - Priority *int - RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) - Status string - Schedulable *bool - GroupIDs *[]int64 - Credentials map[string]any - Extra map[string]any + AccountIDs []int64 + Name string + ProxyID *int64 + Concurrency *int + Priority *int + RateMultiplier *float64 // 账号计费倍率(>=0,允许 0) + Status string + Schedulable *bool + AutoPauseOnExpired *bool + GroupIDs *[]int64 + Credentials map[string]any + Extra map[string]any // SkipMixedChannelCheck skips the mixed channel risk check when binding groups. // This should only be set when the caller has explicitly confirmed the risk. SkipMixedChannelCheck bool @@ -250,9 +249,9 @@ type BulkUpdateAccountResult struct { // AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID. 
type AdminUpdateAPIKeyGroupIDResult struct { APIKey *APIKey - AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added - GrantedGroupID *int64 // the group ID that was auto-granted - GrantedGroupName string // the group name that was auto-granted + AutoGrantedGroupAccess bool + GrantedGroupID *int64 + GrantedGroupName string } // BulkUpdateAccountsResult is the aggregated response for bulk updates. @@ -420,8 +419,6 @@ type adminServiceImpl struct { proxyLatencyCache ProxyLatencyCache authCacheInvalidator APIKeyAuthCacheInvalidator entClient *dbent.Client // 用于开启数据库事务 - settingService *SettingService - defaultSubAssigner DefaultSubscriptionAssigner } type userGroupRateBatchReader interface { @@ -447,8 +444,6 @@ func NewAdminService( proxyLatencyCache ProxyLatencyCache, authCacheInvalidator APIKeyAuthCacheInvalidator, entClient *dbent.Client, - settingService *SettingService, - defaultSubAssigner DefaultSubscriptionAssigner, ) AdminService { return &adminServiceImpl{ userRepo: userRepo, @@ -464,8 +459,6 @@ func NewAdminService( proxyLatencyCache: proxyLatencyCache, authCacheInvalidator: authCacheInvalidator, entClient: entClient, - settingService: settingService, - defaultSubAssigner: defaultSubAssigner, } } @@ -550,27 +543,9 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu if err := s.userRepo.Create(ctx, user); err != nil { return nil, err } - s.assignDefaultSubscriptions(ctx, user.ID) return user, nil } -func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) { - if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { - return - } - items := s.settingService.GetDefaultSubscriptions(ctx) - for _, item := range items { - if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ - UserID: userID, - GroupID: item.GroupID, - ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions 
setting", - }); err != nil { - logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) - } - } -} - func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err != nil { @@ -1225,8 +1200,8 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates [] return s.groupRepo.UpdateSortOrders(ctx, updates) } -// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定 -// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组 +// AdminUpdateAPIKeyGroupID allows admins to update API key-group binding. +// groupID: nil=unchanged, 0=unbind, >0=bind to target group. func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) { apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID) if err != nil { @@ -1234,22 +1209,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i } if groupID == nil { - // nil 表示不修改,直接返回 return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil } - if *groupID < 0 { return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative") } result := &AdminUpdateAPIKeyGroupIDResult{} - if *groupID == 0 { - // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key) apiKey.GroupID = nil apiKey.Group = nil } else { - // 验证目标分组存在且状态为 active group, err := s.groupRepo.GetByID(ctx, *groupID) if err != nil { return nil, err @@ -1257,7 +1227,6 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i if group.Status != StatusActive { return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active") } - // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程 if group.IsSubscriptionType() { return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow") } @@ -1266,7 +1235,6 @@ func 
(s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i apiKey.GroupID = &gid apiKey.Group = group - // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性 if group.IsExclusive { opCtx := ctx var tx *dbent.Tx @@ -1297,8 +1265,6 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i result.AutoGrantedGroupAccess = true result.GrantedGroupID = &gid result.GrantedGroupName = group.Name - - // 失效认证缓存(在事务提交后执行) if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) } @@ -1308,12 +1274,9 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i } } - // 非专属分组 / 解绑:无需事务,单步更新即可 if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil { return nil, fmt.Errorf("update api key: %w", err) } - - // 失效认证缓存 if s.authCacheInvalidator != nil { s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key) } @@ -1384,6 +1347,11 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") } } + if len(input.Extra) > 0 { + if err := validateOpenAIWSModeExtraValues(input.Extra); err != nil { + return nil, err + } + } account := &Account{ Name: input.Name, @@ -1458,6 +1426,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.Credentials = input.Credentials } if len(input.Extra) > 0 { + if err := validateOpenAIWSModeExtraValues(input.Extra); err != nil { + return nil, err + } account.Extra = input.Extra } if input.ProxyID != nil { @@ -1543,6 +1514,104 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U return updated, nil } +var openAIBulkScopedExtraKeys = map[string]struct{}{ + "openai_passthrough": {}, + "openai_oauth_passthrough": {}, + "openai_oauth_responses_websockets_v2_mode": {}, + "openai_oauth_responses_websockets_v2_enabled": {}, + "openai_apikey_responses_websockets_v2_mode": {}, + 
"openai_apikey_responses_websockets_v2_enabled": {}, + "codex_cli_only": {}, +} + +func hasOpenAIBulkScopedExtraField(extra map[string]any) bool { + if len(extra) == 0 { + return false + } + for key := range extra { + if _, ok := openAIBulkScopedExtraKeys[key]; ok { + return true + } + } + return false +} + +// openaiWSModeExtraKeys 列出所有控制 OpenAI WS v2 mode 的 extra 键名。 +var openaiWSModeExtraKeys = []string{ + "openai_oauth_responses_websockets_v2_mode", + "openai_apikey_responses_websockets_v2_mode", +} + +// validateOpenAIWSModeExtraValues 校验 extra 中 WS v2 mode 字段的值域, +// 只允许 off / ctx_pool / passthrough。 +func validateOpenAIWSModeExtraValues(extra map[string]any) error { + for _, key := range openaiWSModeExtraKeys { + raw, exists := extra[key] + if !exists { + continue + } + val, ok := raw.(string) + if !ok { + return infraerrors.BadRequest( + "INVALID_OPENAI_WS_MODE", + fmt.Sprintf("%s must be a string, got %T", key, raw), + ) + } + switch strings.TrimSpace(strings.ToLower(val)) { + case OpenAIWSIngressModeOff, OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough: + // valid + default: + return infraerrors.BadRequest( + "INVALID_OPENAI_WS_MODE", + fmt.Sprintf("%s must be one of off, ctx_pool, passthrough; got %q", key, val), + ) + } + } + return nil +} + +func validateOpenAIBulkScopedAccounts(accountsByID map[int64]*Account, accountIDs []int64) error { + var expectedType string + + for _, accountID := range accountIDs { + account := accountsByID[accountID] + if account == nil { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_ACCOUNT_MISSING", + fmt.Sprintf("account %d not found for OpenAI scoped bulk update", accountID), + ) + } + + if account.Platform != PlatformOpenAI { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_PLATFORM_MISMATCH", + "OpenAI scoped bulk fields require all selected accounts to be OpenAI", + ) + } + + if account.Type != AccountTypeOAuth && account.Type != AccountTypeAPIKey { + return infraerrors.BadRequest( + 
"BULK_OPENAI_SCOPE_TYPE_UNSUPPORTED", + "OpenAI scoped bulk fields only support oauth or apikey account types", + ) + } + + if expectedType == "" { + expectedType = account.Type + continue + } + + if account.Type != expectedType { + return infraerrors.BadRequest( + "BULK_OPENAI_SCOPE_TYPE_MISMATCH", + "OpenAI scoped bulk fields require all selected accounts to have the same type", + ) + } + } + + return nil +} + // BulkUpdateAccounts updates multiple accounts in one request. // It merges credentials/extra keys instead of overwriting the whole object. func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) { @@ -1562,32 +1631,48 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck + needOpenAIScopeCheck := hasOpenAIBulkScopedExtraField(input.Extra) + needAccountSnapshot := needMixedChannelCheck || needOpenAIScopeCheck - // 预加载账号平台信息(混合渠道检查需要)。 + // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} - if needMixedChannelCheck { + accountByID := map[int64]*Account{} + groupAccountsByID := map[int64][]Account{} + groupNameByID := map[int64]string{} + if needAccountSnapshot { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { return nil, err - } - for _, account := range accounts { - if account != nil { - platformByID[account.ID] = account.Platform + } else { + for _, account := range accounts { + if account != nil { + accountByID[account.ID] = account + platformByID[account.ID] = account.Platform + } } } } - // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。 + if needOpenAIScopeCheck { + if err := validateOpenAIBulkScopedAccounts(accountByID, input.AccountIDs); err != nil { + return nil, err + } + } + + // 校验 WS v2 mode 字段值域 + if len(input.Extra) > 0 { + if err := validateOpenAIWSModeExtraValues(input.Extra); err != nil { + return nil, err + } + } + if 
needMixedChannelCheck { - for _, accountID := range input.AccountIDs { - platform := platformByID[accountID] - if platform == "" { - continue - } - if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { - return nil, err - } + loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs) + if err != nil { + return nil, err } + groupAccountsByID = loadedAccounts + groupNameByID = loadedNames } if input.RateMultiplier != nil { @@ -1622,6 +1707,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if input.Schedulable != nil { repoUpdates.Schedulable = input.Schedulable } + if input.AutoPauseOnExpired != nil { + repoUpdates.AutoPauseOnExpired = input.AutoPauseOnExpired + } // Run bulk update for column/jsonb fields first. if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil { @@ -1631,8 +1719,34 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp // Handle group bindings per account (requires individual operations). 
for _, accountID := range input.AccountIDs { entry := BulkUpdateAccountResult{AccountID: accountID} + platform := "" if input.GroupIDs != nil { + // 检查混合渠道风险(除非用户已确认) + if !input.SkipMixedChannelCheck { + platform = platformByID[accountID] + if platform == "" { + account, err := s.accountRepo.GetByID(ctx, accountID) + if err != nil { + entry.Success = false + entry.Error = err.Error() + result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) + result.Results = append(result.Results, entry) + continue + } + platform = account.Platform + } + if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil { + entry.Success = false + entry.Error = err.Error() + result.Failed++ + result.FailedIDs = append(result.FailedIDs, accountID) + result.Results = append(result.Results, entry) + continue + } + } + if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil { entry.Success = false entry.Error = err.Error() @@ -1641,6 +1755,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Results = append(result.Results, entry) continue } + if !input.SkipMixedChannelCheck && platform != "" { + updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform) + } } entry.Success = true @@ -2028,6 +2145,7 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr ProxyURL: proxyURL, Timeout: proxyQualityRequestTimeout, ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout, + ProxyStrict: true, }) if err != nil { result.Items = append(result.Items, ProxyQualityCheckItem{ @@ -2311,6 +2429,41 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } +func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) { + accountsByGroup := make(map[int64][]Account) + 
groupNameByID := make(map[int64]string) + if len(groupIDs) == 0 { + return accountsByGroup, groupNameByID, nil + } + + seen := make(map[int64]struct{}, len(groupIDs)) + for _, groupID := range groupIDs { + if groupID <= 0 { + continue + } + if _, ok := seen[groupID]; ok { + continue + } + seen[groupID] = struct{}{} + + accounts, err := s.accountRepo.ListByGroup(ctx, groupID) + if err != nil { + return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err) + } + accountsByGroup[groupID] = accounts + + group, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + continue + } + if group != nil { + groupNameByID[groupID] = group.Name + } + } + + return accountsByGroup, groupNameByID, nil +} + func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { if len(groupIDs) == 0 { return nil @@ -2340,6 +2493,71 @@ func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs [ return nil } +func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error { + currentPlatform := getAccountPlatform(currentAccountPlatform) + if currentPlatform == "" { + return nil + } + + for _, groupID := range groupIDs { + accounts := accountsByGroup[groupID] + for _, account := range accounts { + if currentAccountID > 0 && account.ID == currentAccountID { + continue + } + + otherPlatform := getAccountPlatform(account.Platform) + if otherPlatform == "" { + continue + } + + if currentPlatform != otherPlatform { + groupName := fmt.Sprintf("Group %d", groupID) + if name := strings.TrimSpace(groupNameByID[groupID]); name != "" { + groupName = name + } + + return &MixedChannelError{ + GroupID: groupID, + GroupName: groupName, + CurrentPlatform: currentPlatform, + OtherPlatform: otherPlatform, + } + } + } + } + + return nil +} + +func updateMixedChannelPreloadedAccounts(accountsByGroup 
map[int64][]Account, groupIDs []int64, accountID int64, platform string) { + if len(groupIDs) == 0 || accountID <= 0 || platform == "" { + return + } + for _, groupID := range groupIDs { + if groupID <= 0 { + continue + } + accounts := accountsByGroup[groupID] + found := false + for i := range accounts { + if accounts[i].ID != accountID { + continue + } + accounts[i].Platform = platform + found = true + break + } + if !found { + accounts = append(accounts, Account{ + ID: accountID, + Platform: platform, + }) + } + accountsByGroup[groupID] = accounts + } +} + // CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 4845d87c1..3fb14cae7 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -12,23 +12,25 @@ import ( type accountRepoStubForBulkUpdate struct { accountRepoStub - bulkUpdateErr error - bulkUpdateIDs []int64 - bindGroupErrByID map[int64]error - bindGroupsCalls []int64 - getByIDsAccounts []*Account - getByIDsErr error - getByIDsCalled bool - getByIDsIDs []int64 - getByIDAccounts map[int64]*Account - getByIDErrByID map[int64]error - getByIDCalled []int64 - listByGroupData map[int64][]Account - listByGroupErr map[int64]error + bulkUpdateErr error + bulkUpdateIDs []int64 + bulkUpdatePayload AccountBulkUpdate + bindGroupErrByID map[int64]error + bindGroupsCalls []int64 + getByIDsAccounts []*Account + getByIDsErr error + getByIDsCalled bool + getByIDsIDs []int64 + getByIDAccounts map[int64]*Account + getByIDErrByID map[int64]error + getByIDCalled []int64 
+ listByGroupData map[int64][]Account + listByGroupErr map[int64]error } -func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { +func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) { s.bulkUpdateIDs = append([]int64{}, ids...) + s.bulkUpdatePayload = updates if s.bulkUpdateErr != nil { return 0, s.bulkUpdateErr } @@ -139,34 +141,134 @@ func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) require.Contains(t, err.Error(), "group repository not configured") } -// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies -// that the global pre-check detects a conflict with existing group members and returns an -// error before any DB write is performed. -func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) { +func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) { repo := &accountRepoStubForBulkUpdate{ getByIDsAccounts: []*Account{ - {ID: 1, Platform: PlatformAntigravity}, + {ID: 1, Platform: PlatformAnthropic}, + {ID: 2, Platform: PlatformAntigravity}, }, - // Group 10 already contains an Anthropic account. 
listByGroupData: map[int64][]Account{ - 10: {{ID: 99, Platform: PlatformAnthropic}}, + 10: {}, }, } svc := &adminServiceImpl{ accountRepo: repo, - groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}}, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}}, } groupIDs := []int64{10} input := &BulkUpdateAccountsInput{ - AccountIDs: []int64{1}, + AccountIDs: []int64{1, 2}, GroupIDs: &groupIDs, } + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 1, result.Success) + require.Equal(t, 1, result.Failed) + require.ElementsMatch(t, []int64{1}, result.SuccessIDs) + require.ElementsMatch(t, []int64{2}, result.FailedIDs) + require.Len(t, result.Results, 2) + require.Contains(t, result.Results[1].Error, "mixed channel") + require.Equal(t, []int64{1}, repo.bindGroupsCalls) +} + +func TestAdminService_BulkUpdateAccounts_ForwardsAutoPauseOnExpired(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + autoPause := true + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{101}, + AutoPauseOnExpired: &autoPause, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 1, result.Success) + require.NotNil(t, repo.bulkUpdatePayload.AutoPauseOnExpired) + require.True(t, *repo.bulkUpdatePayload.AutoPauseOnExpired) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMixedTypes(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"openai_passthrough": true}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + 
require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "same type") + require.Empty(t, repo.bulkUpdateIDs) + require.True(t, repo.getByIDsCalled) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsNonOpenAIPlatform(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 2, Platform: PlatformAnthropic, Type: AccountTypeOAuth}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"openai_oauth_responses_websockets_v2_mode": "shared"}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "OpenAI") + require.Empty(t, repo.bulkUpdateIDs) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraAllowsSameTypeOpenAI(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"codex_cli_only": true}, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 2, result.Success) + require.ElementsMatch(t, []int64{1, 2}, repo.bulkUpdateIDs) +} + +func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMissingAccount(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth}, + }, + } + svc := &adminServiceImpl{accountRepo: repo} + + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + Extra: map[string]any{"openai_passthrough": true}, + } + result, err := 
svc.BulkUpdateAccounts(context.Background(), input) require.Nil(t, result) require.Error(t, err) - require.Contains(t, err.Error(), "mixed channel") - // No BindGroups should have been called since the check runs before any write. - require.Empty(t, repo.bindGroupsCalls) + require.Contains(t, err.Error(), "not found") + require.Empty(t, repo.bulkUpdateIDs) } diff --git a/backend/internal/service/admin_service_create_openai_ws_mode_test.go b/backend/internal/service/admin_service_create_openai_ws_mode_test.go new file mode 100644 index 000000000..d8d170998 --- /dev/null +++ b/backend/internal/service/admin_service_create_openai_ws_mode_test.go @@ -0,0 +1,61 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type accountRepoStubForCreateModeValidation struct { + AccountRepository + createCalled bool + createErr error +} + +func (s *accountRepoStubForCreateModeValidation) Create(_ context.Context, account *Account) error { + s.createCalled = true + if s.createErr != nil { + return s.createErr + } + if account != nil && account.ID == 0 { + account.ID = 1 + } + return nil +} + +func TestAdminService_CreateAccount_RejectsInvalidOpenAIWSMode(t *testing.T) { + repo := &accountRepoStubForCreateModeValidation{} + svc := &adminServiceImpl{accountRepo: repo} + + account, err := svc.CreateAccount(context.Background(), &CreateAccountInput{ + Name: "ws-mode-invalid", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{"openai_apikey_responses_websockets_v2_mode": "shared"}, + SkipDefaultGroupBind: true, + }) + require.Nil(t, account) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_OPENAI_WS_MODE") + require.False(t, repo.createCalled) +} + +func TestAdminService_CreateAccount_AcceptsValidOpenAIWSMode(t *testing.T) { + repo := &accountRepoStubForCreateModeValidation{} + svc := &adminServiceImpl{accountRepo: repo} + + account, err 
:= svc.CreateAccount(context.Background(), &CreateAccountInput{ + Name: "ws-mode-valid", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{"api_key": "sk-test"}, + Extra: map[string]any{"openai_apikey_responses_websockets_v2_mode": "passthrough"}, + SkipDefaultGroupBind: true, + }) + require.NoError(t, err) + require.NotNil(t, account) + require.True(t, repo.createCalled) + require.Equal(t, int64(1), account.ID) +} diff --git a/backend/internal/service/admin_service_create_user_test.go b/backend/internal/service/admin_service_create_user_test.go index c5b1e38d3..a0fe4d87b 100644 --- a/backend/internal/service/admin_service_create_user_test.go +++ b/backend/internal/service/admin_service_create_user_test.go @@ -7,7 +7,6 @@ import ( "errors" "testing" - "github.com/Wei-Shaw/sub2api/internal/config" "github.com/stretchr/testify/require" ) @@ -66,32 +65,3 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) { require.ErrorIs(t, err, createErr) require.Empty(t, repo.created) } - -func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) { - repo := &userRepoStub{nextID: 21} - assigner := &defaultSubscriptionAssignerStub{} - cfg := &config.Config{ - Default: config.DefaultConfig{ - UserBalance: 0, - UserConcurrency: 1, - }, - } - settingService := NewSettingService(&settingRepoStub{values: map[string]string{ - SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`, - }}, cfg) - svc := &adminServiceImpl{ - userRepo: repo, - settingService: settingService, - defaultSubAssigner: assigner, - } - - _, err := svc.CreateUser(context.Background(), &CreateUserInput{ - Email: "new-user@test.com", - Password: "password", - }) - require.NoError(t, err) - require.Len(t, assigner.calls, 1) - require.Equal(t, int64(21), assigner.calls[0].UserID) - require.Equal(t, int64(5), assigner.calls[0].GroupID) - require.Equal(t, 30, assigner.calls[0].ValidityDays) -} diff --git 
a/backend/internal/service/admin_service_openai_ws_mode_validation_test.go b/backend/internal/service/admin_service_openai_ws_mode_validation_test.go new file mode 100644 index 000000000..820342f3b --- /dev/null +++ b/backend/internal/service/admin_service_openai_ws_mode_validation_test.go @@ -0,0 +1,52 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestValidateOpenAIWSModeExtraValues(t *testing.T) { + t.Parallel() + + t.Run("accepts supported mode values", func(t *testing.T) { + t.Parallel() + + err := validateOpenAIWSModeExtraValues(map[string]any{ + "openai_oauth_responses_websockets_v2_mode": " passthrough ", + "openai_apikey_responses_websockets_v2_mode": "CTX_POOL", + }) + require.NoError(t, err) + }) + + t.Run("accepts missing mode keys", func(t *testing.T) { + t.Parallel() + + err := validateOpenAIWSModeExtraValues(map[string]any{ + "codex_cli_only": true, + }) + require.NoError(t, err) + }) + + t.Run("rejects invalid mode value", func(t *testing.T) { + t.Parallel() + + err := validateOpenAIWSModeExtraValues(map[string]any{ + "openai_oauth_responses_websockets_v2_mode": "shared", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_OPENAI_WS_MODE") + require.Contains(t, err.Error(), "off, ctx_pool, passthrough") + }) + + t.Run("rejects non-string mode value", func(t *testing.T) { + t.Parallel() + + err := validateOpenAIWSModeExtraValues(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": true, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "INVALID_OPENAI_WS_MODE") + require.Contains(t, err.Error(), "must be a string") + }) +} diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index 5f6691be2..b67c7fafa 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -112,10 +112,7 @@ func (s *AntigravityOAuthService) 
ExchangeCode(ctx context.Context, input *Antig } } - client, err := antigravity.NewClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create antigravity client failed: %w", err) - } + client := antigravity.NewClient(proxyURL) // 交换 token tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier) @@ -170,10 +167,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken time.Sleep(backoff) } - client, err := antigravity.NewClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create antigravity client failed: %w", err) - } + client := antigravity.NewClient(proxyURL) tokenResp, err := client.RefreshToken(ctx, refreshToken) if err == nil { now := time.Now() @@ -215,10 +209,7 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr } // 获取用户信息(email) - client, err := antigravity.NewClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create antigravity client failed: %w", err) - } + client := antigravity.NewClient(proxyURL) userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken) if err != nil { fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err) @@ -318,10 +309,7 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac time.Sleep(backoff) } - client, err := antigravity.NewClient(proxyURL) - if err != nil { - return "", fmt.Errorf("create antigravity client failed: %w", err) - } + client := antigravity.NewClient(proxyURL) loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken) if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" { diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index e950ec1d9..07eb563d0 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -2,7 +2,6 @@ package service import ( "context" - "fmt" "time" 
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" @@ -32,10 +31,7 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - client, err := antigravity.NewClient(proxyURL) - if err != nil { - return nil, fmt.Errorf("create antigravity client failed: %w", err) - } + client := antigravity.NewClient(proxyURL) // 调用 API 获取配额 modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index fe3a0f258..eae7bd539 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -56,20 +56,15 @@ type JWTClaims struct { // AuthService 认证服务 type AuthService struct { - userRepo UserRepository - redeemRepo RedeemCodeRepository - refreshTokenCache RefreshTokenCache - cfg *config.Config - settingService *SettingService - emailService *EmailService - turnstileService *TurnstileService - emailQueueService *EmailQueueService - promoService *PromoService - defaultSubAssigner DefaultSubscriptionAssigner -} - -type DefaultSubscriptionAssigner interface { - AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) + userRepo UserRepository + redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache + cfg *config.Config + settingService *SettingService + emailService *EmailService + turnstileService *TurnstileService + emailQueueService *EmailQueueService + promoService *PromoService } // NewAuthService 创建认证服务实例 @@ -83,19 +78,17 @@ func NewAuthService( turnstileService *TurnstileService, emailQueueService *EmailQueueService, promoService *PromoService, - defaultSubAssigner DefaultSubscriptionAssigner, ) *AuthService { return &AuthService{ - userRepo: userRepo, - redeemRepo: redeemRepo, - refreshTokenCache: refreshTokenCache, - 
cfg: cfg, - settingService: settingService, - emailService: emailService, - turnstileService: turnstileService, - emailQueueService: emailQueueService, - promoService: promoService, - defaultSubAssigner: defaultSubAssigner, + userRepo: userRepo, + redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, + cfg: cfg, + settingService: settingService, + emailService: emailService, + turnstileService: turnstileService, + emailQueueService: emailQueueService, + promoService: promoService, } } @@ -195,7 +188,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err) return "", nil, ErrServiceUnavailable } - s.assignDefaultSubscriptions(ctx, user.ID) // 标记邀请码为已使用(如果使用了邀请码) if invitationRedeemCode != nil { @@ -485,7 +477,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username } } else { user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -581,7 +572,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema } } else { user = newUser - s.assignDefaultSubscriptions(ctx, user.ID) } } else { logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err) @@ -607,23 +597,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema return tokenPair, user, nil } -func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) { - if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 { - return - } - items := s.settingService.GetDefaultSubscriptions(ctx) - for _, item := range items { - if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{ - UserID: userID, - GroupID: item.GroupID, - ValidityDays: item.ValidityDays, - Notes: "auto assigned by default user subscriptions 
setting", - }); err != nil { - logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err) - } - } -} - // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 1999e759e..93659743f 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -56,21 +56,6 @@ type emailCacheStub struct { err error } -type defaultSubscriptionAssignerStub struct { - calls []AssignSubscriptionInput - err error -} - -func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) { - if input != nil { - s.calls = append(s.calls, *input) - } - if s.err != nil { - return nil, false, s.err - } - return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil -} - func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) { if s.err != nil { return nil, s.err @@ -138,7 +123,6 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E nil, nil, nil, // promoService - nil, // defaultSubAssigner ) } @@ -397,23 +381,3 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) { require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second) } - -func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) { - repo := &userRepoStub{nextID: 42} - assigner := &defaultSubscriptionAssignerStub{} - service := newAuthService(repo, map[string]string{ - SettingKeyRegistrationEnabled: "true", - SettingKeyDefaultSubscriptions: 
`[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`, - }, nil) - service.defaultSubAssigner = assigner - - _, user, err := service.Register(context.Background(), "default-sub@test.com", "password") - require.NoError(t, err) - require.NotNil(t, user) - require.Len(t, assigner.calls, 2) - require.Equal(t, int64(42), assigner.calls[0].UserID) - require.Equal(t, int64(11), assigner.calls[0].GroupID) - require.Equal(t, 30, assigner.calls[0].ValidityDays) - require.Equal(t, int64(12), assigner.calls[1].GroupID) - require.Equal(t, 7, assigner.calls[1].ValidityDays) -} diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go index 36cb1e065..7dd9edca8 100644 --- a/backend/internal/service/auth_service_turnstile_register_test.go +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -52,7 +52,6 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier turnstileService, nil, // emailQueueService nil, // promoService - nil, // defaultSubAssigner ) } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index e055c0f78..34038f5b2 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -2,6 +2,7 @@ package service import ( "context" + "errors" "fmt" "strconv" "sync" @@ -193,7 +194,7 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { } case cacheWriteDeductBalance: if s.cache != nil { - if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil { + if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil && !errors.Is(err, ErrBalanceCacheNotFound) { logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err) } } @@ -335,7 +336,13 @@ func (s *BillingCacheService) 
DeductBalanceCache(ctx context.Context, userID int if s.cache == nil { return nil } - return s.cache.DeductUserBalance(ctx, userID, amount) + err := s.cache.DeductUserBalance(ctx, userID, amount) + if errors.Is(err, ErrBalanceCacheNotFound) { + // 缓存 key 不存在(已过期),无法原子扣减,不阻塞主流程。 + // 下次 GetUserBalance 将从数据库回源重建缓存。 + return nil + } + return err } // QueueDeductBalance 异步扣减余额缓存 diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index f71098b16..d3a4d119b 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "regexp" - "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" @@ -18,9 +17,6 @@ var ( // User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感) claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`) - // 带捕获组的版本提取正则 - claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`) - // metadata.user_id 格式: user_{64位hex}_account__session_{uuid} userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`) @@ -274,55 +270,3 @@ func IsClaudeCodeClient(ctx context.Context) bool { func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context { return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode) } - -// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号 -// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串 -func (v *ClaudeCodeValidator) ExtractVersion(ua string) string { - matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua) - if len(matches) >= 2 { - return matches[1] - } - return "" -} - -// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中 -func SetClaudeCodeVersion(ctx context.Context, version string) context.Context { - return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version) -} - -// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号 -func GetClaudeCodeVersion(ctx context.Context) 
string { - if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok { - return v - } - return "" -} - -// CompareVersions 比较两个 semver 版本号 -// 返回: -1 (a < b), 0 (a == b), 1 (a > b) -func CompareVersions(a, b string) int { - aParts := parseSemver(a) - bParts := parseSemver(b) - for i := 0; i < 3; i++ { - if aParts[i] < bParts[i] { - return -1 - } - if aParts[i] > bParts[i] { - return 1 - } - } - return 0 -} - -// parseSemver 解析 semver 版本号为 [major, minor, patch] -func parseSemver(v string) [3]int { - v = strings.TrimPrefix(v, "v") - parts := strings.Split(v, ".") - result := [3]int{0, 0, 0} - for i := 0; i < len(parts) && i < 3; i++ { - if parsed, err := strconv.Atoi(parts[i]); err == nil { - result[i] = parsed - } - } - return result -} diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go index f87c56e83..a4cd18866 100644 --- a/backend/internal/service/claude_code_validator_test.go +++ b/backend/internal/service/claude_code_validator_test.go @@ -56,51 +56,3 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) { ok := validator.Validate(req, nil) require.True(t, ok) } - -func TestExtractVersion(t *testing.T) { - v := NewClaudeCodeValidator() - tests := []struct { - ua string - want string - }{ - {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"}, - {"claude-cli/1.0.0", "1.0.0"}, - {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感 - {"curl/8.0.0", ""}, // 非 Claude CLI - {"", ""}, // 空字符串 - {"claude-cli/", ""}, // 无版本号 - {"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号 - } - for _, tt := range tests { - got := v.ExtractVersion(tt.ua) - require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua) - } -} - -func TestCompareVersions(t *testing.T) { - tests := []struct { - a, b string - want int - }{ - {"2.1.0", "2.1.0", 0}, // 相等 - {"2.1.1", "2.1.0", 1}, // patch 更大 - {"2.0.0", "2.1.0", -1}, // minor 更小 - {"3.0.0", "2.99.99", 1}, // major 更大 - {"1.0.0", "2.0.0", -1}, // major 
更小 - {"0.0.1", "0.0.0", 1}, // patch 差异 - {"", "1.0.0", -1}, // 空字符串 vs 正常版本 - {"v2.1.0", "2.1.0", 0}, // v 前缀处理 - } - for _, tt := range tests { - got := CompareVersions(tt.a, tt.b) - require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b) - } -} - -func TestSetGetClaudeCodeVersion(t *testing.T) { - ctx := context.Background() - require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string") - - ctx = SetClaudeCodeVersion(ctx, "2.1.63") - require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx)) -} diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go index f6cab204d..996c829cd 100644 --- a/backend/internal/service/claude_token_provider.go +++ b/backend/internal/service/claude_token_provider.go @@ -15,6 +15,20 @@ const ( claudeLockWaitTime = 200 * time.Millisecond ) +func waitClaudeLockRetry(ctx context.Context, wait time.Duration) error { + if wait <= 0 { + return nil + } + timer := time.NewTimer(wait) + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + // ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) type ClaudeTokenCache = GeminiTokenCache @@ -168,7 +182,9 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou } } else { // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 - time.Sleep(claudeLockWaitTime) + if waitErr := waitClaudeLockRetry(ctx, claudeLockWaitTime); waitErr != nil { + return "", waitErr + } if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID) return token, nil diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go index 3e21f6f4a..09f3a31e8 100644 --- a/backend/internal/service/claude_token_provider_test.go +++ b/backend/internal/service/claude_token_provider_test.go 
@@ -800,6 +800,34 @@ func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) { require.NotEmpty(t, token) } +func TestClaudeTokenProvider_Real_LockFailedWait_ContextCanceled(t *testing.T) { + cache := newClaudeTokenCacheStub() + cache.lockAcquired = false + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 3001, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + provider := NewClaudeTokenProvider(nil, cache, nil) + start := time.Now() + token, err := provider.GetAccessToken(ctx, account) + elapsed := time.Since(start) + + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, token) + require.Less(t, elapsed, claudeLockWaitTime/2, "context canceled should short-circuit lock wait") +} + func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) { cache := newClaudeTokenCacheStub() cache.lockAcquired = false // Lock acquisition fails diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go index 6a9167400..040b2357b 100644 --- a/backend/internal/service/crs_sync_service.go +++ b/backend/internal/service/crs_sync_service.go @@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username, AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts, }) if err != nil { - return nil, fmt.Errorf("create http client failed: %w", err) + client = &http.Client{Timeout: 20 * time.Second} } adminToken, err := crsLogin(ctx, client, normalizedURL, username, password) diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 2af43386b..4528def3d 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -140,14 +140,6 @@ func (s 
*DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi return stats, nil } -func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) { - stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) - if err != nil { - return nil, fmt.Errorf("get group stats with filters: %w", err) - } - return stats, nil -} - func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) { data, err := s.cache.GetDashboardStats(ctx) if err != nil { diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go index b525c0fae..b0d4d6da8 100644 --- a/backend/internal/service/data_management_service.go +++ b/backend/internal/service/data_management_service.go @@ -63,11 +63,11 @@ func NewDataManagementService() *DataManagementService { } func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService { - _ = dialTimeout path := strings.TrimSpace(socketPath) if path == "" { path = DefaultDataManagementAgentSocketPath } + _ = dialTimeout return &DataManagementService{ socketPath: path, } diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index df2130027..ce5710945 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -118,9 +118,8 @@ const ( SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 - SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultConcurrency = 
"default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 // 管理员 API Key SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) @@ -128,6 +127,9 @@ const ( // Gemini 配额策略(JSON) SettingKeyGeminiQuotaPolicy = "gemini_quota_policy" + // Bulk edit template library(JSON) + SettingKeyBulkEditTemplateLibrary = "bulk_edit_template_library_v1" + // Model fallback settings SettingKeyEnableModelFallback = "enable_model_fallback" SettingKeyFallbackModelAnthropic = "fallback_model_anthropic" @@ -194,13 +196,6 @@ const ( // ========================= SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) - - // ========================= - // Claude Code Version Check - // ========================= - - // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查) - SettingKeyMinClaudeCodeVersion = "min_claude_code_version" ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). 
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 02f9a6a3a..0385ab562 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -520,7 +520,6 @@ type GatewayService struct { concurrencyService *ConcurrencyService claudeTokenProvider *ClaudeTokenProvider sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken) userGroupRateCache *gocache.Cache userGroupRateSF singleflight.Group modelsListCache *gocache.Cache @@ -550,7 +549,6 @@ func NewGatewayService( deferredService *DeferredService, claudeTokenProvider *ClaudeTokenProvider, sessionLimitCache SessionLimitCache, - rpmCache RPMCache, digestStore *DigestSessionStore, ) *GatewayService { userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) @@ -576,7 +574,6 @@ func NewGatewayService( deferredService: deferredService, claudeTokenProvider: claudeTokenProvider, sessionLimitCache: sessionLimitCache, - rpmCache: rpmCache, userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), modelsListCache: gocache.New(modelsListTTL, time.Minute), modelsListCacheTTL: modelsListTTL, @@ -1157,7 +1154,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) isExcluded := func(accountID int64) bool { if excludedIDs == nil { @@ -1233,10 +1229,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredWindowCost++ continue } - // RPM 检查(非粘性会话路径) - if !s.isAccountSchedulableForRPM(ctx, account, false) { - continue - } routingCandidates = append(routingCandidates, account) } @@ -1260,9 +1252,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && 
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) && - - s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查 + s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1416,9 +1406,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && - s.isAccountSchedulableForWindowCost(ctx, account, true) && - - s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查 + s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 @@ -1484,10 +1472,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { continue } - // RPM 检查(非粘性会话路径) - if !s.isAccountSchedulableForRPM(ctx, acc, false) { - continue - } candidates = append(candidates, acc) } @@ -2174,88 +2158,6 @@ checkSchedulability: return true } -// rpmPrefetchContextKey is the context key for prefetched RPM counts. 
-type rpmPrefetchContextKeyType struct{} - -var rpmPrefetchContextKey = rpmPrefetchContextKeyType{} - -func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) { - if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok { - count, found := v[accountID] - return count, found - } - return 0, false -} - -// withRPMPrefetch 批量预取所有候选账号的 RPM 计数 -func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context { - if s.rpmCache == nil { - return ctx - } - - var ids []int64 - for i := range accounts { - if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 { - ids = append(ids, accounts[i].ID) - } - } - if len(ids) == 0 { - return ctx - } - - counts, err := s.rpmCache.GetRPMBatch(ctx, ids) - if err != nil { - return ctx // 失败开放 - } - return context.WithValue(ctx, rpmPrefetchContextKey, counts) -} - -// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度 -// 仅适用于 Anthropic OAuth/SetupToken 账号 -func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool { - if !account.IsAnthropicOAuthOrSetupToken() { - return true - } - baseRPM := account.GetBaseRPM() - if baseRPM <= 0 { - return true - } - - // 尝试从预取缓存获取 - var currentRPM int - if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok { - currentRPM = count - } else if s.rpmCache != nil { - if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil { - currentRPM = count - } - // 失败开放:GetRPM 错误时允许调度 - } - - schedulability := account.CheckRPMSchedulability(currentRPM) - switch schedulability { - case WindowCostSchedulable: - return true - case WindowCostStickyOnly: - return isSticky - case WindowCostNotSchedulable: - return false - } - return true -} - -// IncrementAccountRPM increments the RPM counter for the given account. 
-// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口, -// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit -// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。 -func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error { - if s.rpmCache == nil { - return nil - } - _, err := s.rpmCache.IncrementRPM(ctx, accountID) - return err -} - // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 // sessionID: 会话标识符(使用粘性会话的 hash) @@ -2450,7 +2352,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool { // shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。 // // 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。 -// 因此这里采用"组内分区 + 分区内 shuffle"的方式: +// 因此这里采用“组内分区 + 分区内 shuffle”的方式: // - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前; // - 再分别在各段内随机打散,避免热点。 func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { @@ -2590,7 +2492,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ 
-2613,10 +2515,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } accountsLoaded = true - // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) - routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2644,12 +2542,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } - if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { - continue - } - if !s.isAccountSchedulableForRPM(ctx, acc, false) { - continue - } if selected == nil { selected = acc continue @@ -2700,7 +2592,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { return account, nil } } @@ -2721,10 +2613,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } } - // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) - // 3. 
按优先级+最久未用选择(考虑模型支持) var selected *Account for i := range accounts { @@ -2743,12 +2631,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } - if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { - continue - } - if !s.isAccountSchedulableForRPM(ctx, acc, false) { - continue - } if selected == nil { selected = acc continue @@ -2818,7 +2700,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2839,10 +2721,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } accountsLoaded = true - // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存 - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) - routingSet := make(map[int64]struct{}, len(routingAccountIDs)) for _, id := range routingAccountIDs { if id > 0 { @@ -2874,12 +2752,6 @@ func 
(s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } - if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { - continue - } - if !s.isAccountSchedulableForRPM(ctx, acc, false) { - continue - } if selected == nil { selected = acc continue @@ -2930,7 +2802,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2949,10 +2821,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } } - // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1) - ctx = s.withWindowCostPrefetch(ctx, accounts) - ctx = s.withRPMPrefetch(ctx, accounts) - // 3. 
按优先级+最久未用选择(考虑模型支持和混合调度) var selected *Account for i := range accounts { @@ -2975,12 +2843,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } - if !s.isAccountSchedulableForWindowCost(ctx, acc, false) { - continue - } - if !s.isAccountSchedulableForRPM(ctx, acc, false) { - continue - } if selected == nil { selected = acc continue @@ -5332,7 +5194,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool { } func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { - // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。 + // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。 // 默认保守:无法识别则不切换。 msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody))) if msg == "" { @@ -6522,7 +6384,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) + return fmt.Errorf("create usage log: %w", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { @@ -6531,7 +6393,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu return nil } - shouldBill := inserted || err != nil + shouldBill := inserted // 根据计费类型执行扣费 if isSubscriptionBilling { @@ -6720,7 +6582,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * inserted, err := s.usageLogRepo.Create(ctx, usageLog) if err != nil { - logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err) + return fmt.Errorf("create usage log: %w", err) } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { @@ -6729,7 +6591,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * return nil } - shouldBill := inserted || err != nil + shouldBill := inserted // 根据计费类型执行扣费 if isSubscriptionBilling { diff --git 
a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index 08a74a372..e866bdc36 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR ValidateResolvedIP: true, }) if err != nil { - return "", fmt.Errorf("create http client failed: %w", err) + client = &http.Client{Timeout: 30 * time.Second} } resp, err := client.Do(req) diff --git a/backend/internal/service/http_upstream_profile.go b/backend/internal/service/http_upstream_profile.go new file mode 100644 index 000000000..fef1dc191 --- /dev/null +++ b/backend/internal/service/http_upstream_profile.go @@ -0,0 +1,41 @@ +package service + +import "context" + +// HTTPUpstreamProfile 标识上游 HTTP 请求的协议策略分类。 +type HTTPUpstreamProfile string + +const ( + HTTPUpstreamProfileDefault HTTPUpstreamProfile = "" + HTTPUpstreamProfileOpenAI HTTPUpstreamProfile = "openai" +) + +type httpUpstreamProfileContextKey struct{} + +// WithHTTPUpstreamProfile 在请求上下文中注入上游协议策略分类。 +func WithHTTPUpstreamProfile(ctx context.Context, profile HTTPUpstreamProfile) context.Context { + if ctx == nil { + ctx = context.Background() + } + if profile == HTTPUpstreamProfileDefault { + return ctx + } + return context.WithValue(ctx, httpUpstreamProfileContextKey{}, profile) +} + +// HTTPUpstreamProfileFromContext 从请求上下文中解析上游协议策略分类。 +func HTTPUpstreamProfileFromContext(ctx context.Context) HTTPUpstreamProfile { + if ctx == nil { + return HTTPUpstreamProfileDefault + } + profile, ok := ctx.Value(httpUpstreamProfileContextKey{}).(HTTPUpstreamProfile) + if !ok { + return HTTPUpstreamProfileDefault + } + switch profile { + case HTTPUpstreamProfileOpenAI: + return profile + default: + return HTTPUpstreamProfileDefault + } +} diff --git a/backend/internal/service/http_upstream_profile_test.go b/backend/internal/service/http_upstream_profile_test.go new file 
mode 100644 index 000000000..0ed14d93d --- /dev/null +++ b/backend/internal/service/http_upstream_profile_test.go @@ -0,0 +1,33 @@ +package service + +import ( + "context" + "testing" +) + +func TestWithHTTPUpstreamProfile_DefaultKeepsContext(t *testing.T) { + ctx := context.Background() + got := WithHTTPUpstreamProfile(ctx, HTTPUpstreamProfileDefault) + if got != ctx { + t.Fatalf("expected default profile to keep original context") + } +} + +func TestWithHTTPUpstreamProfile_TODOContextSetsProfile(t *testing.T) { + ctx := WithHTTPUpstreamProfile(context.TODO(), HTTPUpstreamProfileOpenAI) + if ctx == nil { + t.Fatalf("expected non-nil context") + } + if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileOpenAI { + t.Fatalf("expected profile %q, got %q", HTTPUpstreamProfileOpenAI, profile) + } +} + +func TestHTTPUpstreamProfileFromContext_UnknownValueFallsBackDefault(t *testing.T) { + type badKey struct{} + ctx := context.WithValue(context.Background(), httpUpstreamProfileContextKey{}, HTTPUpstreamProfile("unknown")) + ctx = context.WithValue(ctx, badKey{}, "x") + if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileDefault { + t.Fatalf("expected default profile, got %q", profile) + } +} diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index f3130c91c..dc59010d7 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -46,7 +46,6 @@ type Fingerprint struct { StainlessArch string StainlessRuntime string StainlessRuntimeVersion string - UpdatedAt int64 `json:",omitempty"` // Unix timestamp,用于判断是否需要续期TTL } // IdentityCache defines cache operations for identity service @@ -79,26 +78,14 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 尝试从缓存获取指纹 cached, err := s.cache.GetFingerprint(ctx, accountID) if err == nil && cached != nil { - needWrite := false - // 
检查客户端的user-agent是否是更新版本 clientUA := headers.Get("User-Agent") if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) { - // 版本升级:merge 语义 — 仅更新请求中实际携带的字段,保留缓存值 - // 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致) - mergeHeadersIntoFingerprint(cached, headers) - needWrite = true - logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA) - } else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour { - // 距上次写入超过24小时,续期TTL - needWrite = true - } - - if needWrite { - cached.UpdatedAt = time.Now().Unix() - if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil { - logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err) - } + // 更新user-agent + cached.UserAgent = clientUA + // 保存更新后的指纹 + _ = s.cache.SetFingerprint(ctx, accountID, cached) + logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA) } return cached, nil } @@ -108,9 +95,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID // 生成随机ClientID fp.ClientID = generateClientID() - fp.UpdatedAt = time.Now().Unix() - // 保存到缓存(7天TTL,每24小时自动续期) + // 保存到缓存(永不过期) if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil { logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err) } @@ -141,31 +127,6 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin return fp } -// mergeHeadersIntoFingerprint 将请求头中实际存在的字段合并到现有指纹中(用于版本升级场景) -// 关键语义:请求中有的字段 → 用新值覆盖;缺失的头 → 保留缓存中的已有值 -// 与 createFingerprintFromHeaders 的区别:后者用于首次创建,缺失头回退到 defaultFingerprint; -// 本函数用于升级更新,缺失头保留缓存值,避免将已知的真实值退化为硬编码默认值 -func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) { - // User-Agent:版本升级的触发条件,一定存在 - if ua := headers.Get("User-Agent"); ua != "" { - fp.UserAgent = ua - } - // X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值 - 
mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang) - mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion) - mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS) - mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch) - mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime) - mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion) -} - -// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值 -func mergeHeader(headers http.Header, key string, target *string) { - if v := headers.Get(key); v != "" { - *target = v - } -} - // getHeaderOrDefault 获取header值,如果不存在则返回默认值 func getHeaderOrDefault(headers http.Header, key, defaultValue string) string { if v := headers.Get(key); v != "" { @@ -410,25 +371,8 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) { return major, minor, patch, true } -// extractProduct 提取 User-Agent 中 "/" 前的产品名 -// 例如:claude-cli/2.1.22 (external, cli) -> "claude-cli" -func extractProduct(ua string) string { - if idx := strings.Index(ua, "/"); idx > 0 { - return strings.ToLower(ua[:idx]) - } - return "" -} - // isNewerVersion 比较版本号,判断newUA是否比cachedUA更新 -// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本) func isNewerVersion(newUA, cachedUA string) bool { - // 校验产品名一致性 - newProduct := extractProduct(newUA) - cachedProduct := extractProduct(cachedUA) - if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct { - return false - } - newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA) cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA) diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index 99013ce55..785d66b88 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -43,34 +43,42 @@ type OpenAIAccountScheduleDecision struct { } type OpenAIAccountSchedulerMetricsSnapshot 
struct { - SelectTotal int64 - StickyPreviousHitTotal int64 - StickySessionHitTotal int64 - LoadBalanceSelectTotal int64 - AccountSwitchTotal int64 - SchedulerLatencyMsTotal int64 - SchedulerLatencyMsAvg float64 - StickyHitRatio float64 - AccountSwitchRate float64 - LoadSkewAvg float64 - RuntimeStatsAccountCount int + SelectTotal int64 + StickyPreviousHitTotal int64 + StickySessionHitTotal int64 + LoadBalanceSelectTotal int64 + AccountSwitchTotal int64 + SchedulerLatencyMsTotal int64 + SchedulerLatencyMsAvg float64 + StickyHitRatio float64 + AccountSwitchRate float64 + LoadSkewAvg float64 + RuntimeStatsAccountCount int + CircuitBreakerOpenTotal int64 + CircuitBreakerRecoverTotal int64 + StickyReleaseErrorTotal int64 + StickyReleaseCircuitOpenTotal int64 } type OpenAIAccountScheduler interface { Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) - ReportResult(accountID int64, success bool, firstTokenMs *int) + ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) ReportSwitch() SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot } type openAIAccountSchedulerMetrics struct { - selectTotal atomic.Int64 - stickyPreviousHitTotal atomic.Int64 - stickySessionHitTotal atomic.Int64 - loadBalanceSelectTotal atomic.Int64 - accountSwitchTotal atomic.Int64 - latencyMsTotal atomic.Int64 - loadSkewMilliTotal atomic.Int64 + selectTotal atomic.Int64 + stickyPreviousHitTotal atomic.Int64 + stickySessionHitTotal atomic.Int64 + loadBalanceSelectTotal atomic.Int64 + accountSwitchTotal atomic.Int64 + latencyMsTotal atomic.Int64 + loadSkewMilliTotal atomic.Int64 + circuitBreakerOpenTotal atomic.Int64 + circuitBreakerRecoverTotal atomic.Int64 + stickyReleaseErrorTotal atomic.Int64 + stickyReleaseCircuitOpenTotal atomic.Int64 } func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { @@ -99,19 +107,439 @@ func (m 
*openAIAccountSchedulerMetrics) recordSwitch() { } type openAIAccountRuntimeStats struct { - accounts sync.Map - accountCount atomic.Int64 + accounts sync.Map + circuitBreakers sync.Map // accountID → *accountCircuitBreaker + accountCount atomic.Int64 + cleanupCounter atomic.Int64 // report call counter for periodic cleanup +} + +// --------------------------------------------------------------------------- +// Account-level Circuit Breaker (three-state: CLOSED → OPEN → HALF_OPEN) +// --------------------------------------------------------------------------- + +const ( + circuitBreakerStateClosed int32 = 0 + circuitBreakerStateOpen int32 = 1 + circuitBreakerStateHalfOpen int32 = 2 + + // Defaults (used when config values are zero/unset) + defaultCircuitBreakerFailThreshold = 5 + defaultCircuitBreakerCooldownSec = 30 + defaultCircuitBreakerHalfOpenMax = 2 +) + +type accountCircuitBreaker struct { + state atomic.Int32 // circuitBreakerState* + consecutiveFails atomic.Int32 + lastFailureNano atomic.Int64 // time.Now().UnixNano() + halfOpenInFlight atomic.Int32 // current in-flight probes (decremented by release) + halfOpenAdmitted atomic.Int32 // total probes admitted this half-open cycle (never decremented by release) + halfOpenSuccess atomic.Int32 +} + +// allow returns true if the circuit breaker allows a request to pass through. +func (cb *accountCircuitBreaker) allow(cooldown time.Duration, halfOpenMax int) bool { + switch cb.state.Load() { + case circuitBreakerStateClosed: + return true + case circuitBreakerStateOpen: + lastFail := time.Unix(0, cb.lastFailureNano.Load()) + if time.Since(lastFail) <= cooldown { + return false + } + // Cooldown elapsed — attempt transition to HALF_OPEN. + // Reset counters before CAS to avoid a window where another goroutine + // sees HALF_OPEN but stale counter values. 
+ cb.halfOpenInFlight.Store(0) + cb.halfOpenAdmitted.Store(0) + cb.halfOpenSuccess.Store(0) + cb.state.CompareAndSwap(circuitBreakerStateOpen, circuitBreakerStateHalfOpen) + // Either we transitioned or another goroutine did; fall through to + // HALF_OPEN gate below. + return cb.allowHalfOpen(halfOpenMax) + case circuitBreakerStateHalfOpen: + return cb.allowHalfOpen(halfOpenMax) + default: + return true + } +} + +func (cb *accountCircuitBreaker) isHalfOpen() bool { + if cb == nil { + return false + } + return cb.state.Load() == circuitBreakerStateHalfOpen +} + +// releaseHalfOpenPermit releases one HALF_OPEN probe permit when a candidate +// passed filtering but was not actually selected to execute a request. +func (cb *accountCircuitBreaker) releaseHalfOpenPermit() { + if cb == nil || cb.state.Load() != circuitBreakerStateHalfOpen { + return + } + for { + cur := cb.halfOpenInFlight.Load() + if cur <= 0 { + return + } + if cb.halfOpenInFlight.CompareAndSwap(cur, cur-1) { + return + } + } +} + +func (cb *accountCircuitBreaker) allowHalfOpen(halfOpenMax int) bool { + for { + cur := cb.halfOpenInFlight.Load() + if int(cur) >= halfOpenMax { + return false + } + if cb.halfOpenInFlight.CompareAndSwap(cur, cur+1) { + cb.halfOpenAdmitted.Add(1) + return true + } + } +} + +// recordSuccess is called when a request succeeds. +func (cb *accountCircuitBreaker) recordSuccess() { + cb.consecutiveFails.Store(0) + if cb.state.Load() == circuitBreakerStateHalfOpen { + newSucc := cb.halfOpenSuccess.Add(1) + // Compare against halfOpenAdmitted (total probes ever admitted in + // this half-open cycle). Unlike halfOpenInFlight, this is never + // decremented by releaseHalfOpenPermit, so the recovery threshold + // remains stable regardless of candidate filtering outcomes. 
+ admitted := cb.halfOpenAdmitted.Load() + if newSucc >= admitted && admitted > 0 { + if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateClosed) { + cb.halfOpenInFlight.Store(0) + cb.halfOpenAdmitted.Store(0) + cb.halfOpenSuccess.Store(0) + } + } + } +} + +// recordFailure is called when a request fails. +func (cb *accountCircuitBreaker) recordFailure(threshold int) { + cb.lastFailureNano.Store(time.Now().UnixNano()) + newFails := cb.consecutiveFails.Add(1) + + switch cb.state.Load() { + case circuitBreakerStateClosed: + if int(newFails) >= threshold { + cb.state.CompareAndSwap(circuitBreakerStateClosed, circuitBreakerStateOpen) + } + case circuitBreakerStateHalfOpen: + if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateOpen) { + cb.halfOpenInFlight.Store(0) + cb.halfOpenAdmitted.Store(0) + cb.halfOpenSuccess.Store(0) + } + } +} + +// isOpen returns true if the circuit breaker is currently in OPEN state. +func (cb *accountCircuitBreaker) isOpen() bool { + return cb.state.Load() == circuitBreakerStateOpen +} + +// stateString returns a human-readable state name. +func (cb *accountCircuitBreaker) stateString() string { + switch cb.state.Load() { + case circuitBreakerStateClosed: + return "CLOSED" + case circuitBreakerStateOpen: + return "OPEN" + case circuitBreakerStateHalfOpen: + return "HALF_OPEN" + default: + return "UNKNOWN" + } +} + +// loadCircuitBreaker returns the CB for accountID if it exists, or nil. +// Use this on hot paths (e.g. candidate filtering) to avoid allocating CB +// objects for accounts that have never received a report. 
+func (s *openAIAccountRuntimeStats) loadCircuitBreaker(accountID int64) *accountCircuitBreaker { + if val, ok := s.circuitBreakers.Load(accountID); ok { + if cb, _ := val.(*accountCircuitBreaker); cb != nil { + return cb + } + } + return nil +} + +func (s *openAIAccountRuntimeStats) getCircuitBreaker(accountID int64) *accountCircuitBreaker { + if val, ok := s.circuitBreakers.Load(accountID); ok { + if cb, _ := val.(*accountCircuitBreaker); cb != nil { + return cb + } + } + cb := &accountCircuitBreaker{} + actual, _ := s.circuitBreakers.LoadOrStore(accountID, cb) + if existing, _ := actual.(*accountCircuitBreaker); existing != nil { + return existing + } + return cb +} + +func (s *openAIAccountRuntimeStats) isCircuitOpen(accountID int64) bool { + val, ok := s.circuitBreakers.Load(accountID) + if !ok { + return false + } + cb, _ := val.(*accountCircuitBreaker) + if cb == nil { + return false + } + return cb.isOpen() +} + +// --------------------------------------------------------------------------- +// Dual-EWMA: fast (α=0.5) reacts quickly to degradation; slow (α=0.1) +// stabilises over many samples. The pessimistic envelope max(fast,slow) means +// we *sense* errors fast but only *confirm* recovery slowly. +// --------------------------------------------------------------------------- + +const ( + dualEWMAAlphaFast = 0.5 + dualEWMAAlphaSlow = 0.1 + + // Per-model TTFT defaults + defaultPerModelTTFTMaxModels = 100 + defaultPerModelTTFTTTL = 30 * time.Minute +) + +// dualEWMA tracks a [0,1] signal (e.g. error rate) with two speeds. +type dualEWMA struct { + fastBits atomic.Uint64 // α = dualEWMAAlphaFast, reacts in ~3 requests + slowBits atomic.Uint64 // α = dualEWMAAlphaSlow, stabilises over ~50 requests + sampleCount atomic.Int64 // total samples received; used for cold-start guard +} + +// dualEWMAMinSamples is the minimum number of samples required before the +// EWMA error rate is considered reliable for decision-making (e.g. sticky +// release). 
This prevents a single failure on a fresh account from yielding +// an artificially high error rate. +const dualEWMAMinSamples = 5 + +func (d *dualEWMA) update(sample float64) { + updateEWMAAtomic(&d.fastBits, sample, dualEWMAAlphaFast) + updateEWMAAtomic(&d.slowBits, sample, dualEWMAAlphaSlow) + d.sampleCount.Add(1) +} + +// isWarmedUp returns true when enough samples have been collected for the +// EWMA value to be meaningful. +func (d *dualEWMA) isWarmedUp() bool { + return d.sampleCount.Load() >= dualEWMAMinSamples +} + +// value returns the pessimistic envelope: max(fast, slow). +func (d *dualEWMA) value() float64 { + fast := math.Float64frombits(d.fastBits.Load()) + slow := math.Float64frombits(d.slowBits.Load()) + if fast >= slow { + return fast + } + return slow +} + +func (d *dualEWMA) fastValue() float64 { + return math.Float64frombits(d.fastBits.Load()) +} + +func (d *dualEWMA) slowValue() float64 { + return math.Float64frombits(d.slowBits.Load()) +} + +// dualEWMATTFT is like dualEWMA but handles the NaN-initialised first-sample +// case required by TTFT tracking. +type dualEWMATTFT struct { + fastBits atomic.Uint64 // α = dualEWMAAlphaFast + slowBits atomic.Uint64 // α = dualEWMAAlphaSlow +} + +// initNaN stores NaN into both channels. Called once at allocation time. 
+func (d *dualEWMATTFT) initNaN() { + nanBits := math.Float64bits(math.NaN()) + d.fastBits.Store(nanBits) + d.slowBits.Store(nanBits) +} + +func (d *dualEWMATTFT) update(sample float64) { + sampleBits := math.Float64bits(sample) + // fast channel + for { + oldBits := d.fastBits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if d.fastBits.CompareAndSwap(oldBits, sampleBits) { + break + } + continue + } + newValue := dualEWMAAlphaFast*sample + (1-dualEWMAAlphaFast)*oldValue + if d.fastBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } + // slow channel + for { + oldBits := d.slowBits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if d.slowBits.CompareAndSwap(oldBits, sampleBits) { + break + } + continue + } + newValue := dualEWMAAlphaSlow*sample + (1-dualEWMAAlphaSlow)*oldValue + if d.slowBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } +} + +// value returns (pessimistic TTFT, hasTTFT). If both channels are still NaN +// the caller gets (0, false). 
+func (d *dualEWMATTFT) value() (float64, bool) { + fast := math.Float64frombits(d.fastBits.Load()) + slow := math.Float64frombits(d.slowBits.Load()) + fastOK := !math.IsNaN(fast) + slowOK := !math.IsNaN(slow) + switch { + case fastOK && slowOK: + if fast >= slow { + return fast, true + } + return slow, true + case fastOK: + return fast, true + case slowOK: + return slow, true + default: + return 0, false + } +} + +func (d *dualEWMATTFT) fastValue() float64 { + return math.Float64frombits(d.fastBits.Load()) +} + +func (d *dualEWMATTFT) slowValue() float64 { + return math.Float64frombits(d.slowBits.Load()) +} + +// --------------------------------------------------------------------------- +// Load Trend Tracker (ring-buffer linear regression) +// --------------------------------------------------------------------------- + +const loadTrendRingSize = 10 + +// loadTrendTracker maintains a fixed-size ring buffer of (timestamp, loadRate) +// samples and computes the ordinary-least-squares slope to predict whether +// an account's load is rising, falling, or stable. +type loadTrendTracker struct { + mu sync.Mutex + samples [loadTrendRingSize]float64 // ring buffer of loadRate values + times [loadTrendRingSize]int64 // timestamps in UnixNano + head int // next write position + count int // number of valid samples (0..loadTrendRingSize) +} + +// record pushes a loadRate sample with the current wall-clock timestamp. +func (t *loadTrendTracker) record(loadRate float64) { + t.recordAt(loadRate, time.Now().UnixNano()) +} + +// recordAt pushes a loadRate sample with an explicit timestamp (for testing). +func (t *loadTrendTracker) recordAt(loadRate float64, tsNano int64) { + t.mu.Lock() + t.samples[t.head] = loadRate + t.times[t.head] = tsNano + t.head = (t.head + 1) % loadTrendRingSize + if t.count < loadTrendRingSize { + t.count++ + } + t.mu.Unlock() +} + +// slope computes the simple linear regression slope of loadRate over time. 
+// +// slope = (N*Sigma(xi*yi) - Sigma(xi)*Sigma(yi)) / (N*Sigma(xi^2) - (Sigma(xi))^2) +// +// where xi = seconds elapsed since the oldest sample, yi = loadRate. +// Returns 0 if fewer than 2 samples are available or if all timestamps are +// identical (degenerate case). +func (t *loadTrendTracker) slope() float64 { + t.mu.Lock() + n := t.count + if n < 2 { + t.mu.Unlock() + return 0 + } + + // Copy data under lock; computation happens outside. + var localSamples [loadTrendRingSize]float64 + var localTimes [loadTrendRingSize]int64 + copy(localSamples[:], t.samples[:]) + copy(localTimes[:], t.times[:]) + head := t.head + t.mu.Unlock() + + // Identify oldest entry index. + oldest := head // head points to the next write pos; for a full ring it's the oldest entry. + if n < loadTrendRingSize { + oldest = 0 + } + t0 := localTimes[oldest] + + var sumX, sumY, sumXY, sumX2 float64 + for i := 0; i < n; i++ { + idx := (oldest + i) % loadTrendRingSize + xi := float64(localTimes[idx]-t0) / 1e9 // relative seconds + yi := localSamples[idx] + sumX += xi + sumY += yi + sumXY += xi * yi + sumX2 += xi * xi + } + + nf := float64(n) + denom := nf*sumX2 - sumX*sumX + if denom == 0 { + // All timestamps identical (or single sample) — no meaningful slope. + return 0 + } + return (nf*sumXY - sumX*sumY) / denom } type openAIAccountRuntimeStat struct { - errorRateEWMABits atomic.Uint64 - ttftEWMABits atomic.Uint64 + errorRate dualEWMA + ttft dualEWMATTFT + modelTTFT sync.Map // key = model name (string), value = *dualEWMATTFT + modelTTFTLastUpdate sync.Map // key = model name (string), value = int64 (unix nano) + loadTrend loadTrendTracker + lastReportNano atomic.Int64 // last report timestamp for GC } func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats { return &openAIAccountRuntimeStats{} } +// loadExisting returns the stat for accountID if it exists, or nil. +// Unlike loadOrCreate, this never allocates a new stat. 
+func (s *openAIAccountRuntimeStats) loadExisting(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + return stat + } + return nil +} + func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat { if value, ok := s.accounts.Load(accountID); ok { stat, _ := value.(*openAIAccountRuntimeStat) @@ -121,7 +549,7 @@ func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccount } stat := &openAIAccountRuntimeStat{} - stat.ttftEWMABits.Store(math.Float64bits(math.NaN())) + stat.ttft.initNaN() actual, loaded := s.accounts.LoadOrStore(accountID, stat) if !loaded { s.accountCount.Add(1) @@ -134,6 +562,103 @@ func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccount return stat } +// getOrCreateModelTTFT returns the per-model TTFT tracker, creating it if +// it does not exist yet. Uses the LoadOrStore pattern for thread safety. +func (stat *openAIAccountRuntimeStat) getOrCreateModelTTFT(model string) *dualEWMATTFT { + if val, ok := stat.modelTTFT.Load(model); ok { + if d, _ := val.(*dualEWMATTFT); d != nil { + return d + } + } + d := &dualEWMATTFT{} + d.initNaN() + actual, _ := stat.modelTTFT.LoadOrStore(model, d) + if existing, _ := actual.(*dualEWMATTFT); existing != nil { + return existing + } + return d +} + +// reportModelTTFT updates both the per-model and global TTFT trackers. +func (stat *openAIAccountRuntimeStat) reportModelTTFT(model string, sampleMs float64) { + if model == "" || sampleMs <= 0 { + return + } + d := stat.getOrCreateModelTTFT(model) + d.update(sampleMs) + stat.modelTTFTLastUpdate.Store(model, time.Now().UnixNano()) + // Also update the global TTFT so that callers without a model still + // see a reasonable aggregate. + stat.ttft.update(sampleMs) +} + +// modelTTFTValue returns the per-model TTFT value if a tracker exists and has +// received at least one sample. 
Otherwise returns (0, false). +func (stat *openAIAccountRuntimeStat) modelTTFTValue(model string) (float64, bool) { + if model == "" { + return 0, false + } + val, ok := stat.modelTTFT.Load(model) + if !ok { + return 0, false + } + d, _ := val.(*dualEWMATTFT) + if d == nil { + return 0, false + } + return d.value() +} + +// cleanupStaleTTFT removes per-model TTFT entries that have not been updated +// within ttl, and enforces a hard cap of maxModels entries. Oldest entries +// are evicted first when the cap is exceeded. +func (stat *openAIAccountRuntimeStat) cleanupStaleTTFT(ttl time.Duration, maxModels int) { + now := time.Now().UnixNano() + cutoff := now - int64(ttl) + + // First pass: delete stale entries. + stat.modelTTFTLastUpdate.Range(func(key, value any) bool { + model, _ := key.(string) + ts, _ := value.(int64) + if ts < cutoff { + stat.modelTTFT.Delete(model) + stat.modelTTFTLastUpdate.Delete(model) + } + return true + }) + + if maxModels <= 0 { + return + } + + // Second pass: count remaining entries and evict oldest if over limit. + type modelEntry struct { + model string + ts int64 + } + var entries []modelEntry + stat.modelTTFTLastUpdate.Range(func(key, value any) bool { + model, _ := key.(string) + ts, _ := value.(int64) + entries = append(entries, modelEntry{model: model, ts: ts}) + return true + }) + + if len(entries) <= maxModels { + return + } + + // Sort by timestamp ascending (oldest first) and evict surplus. 
+ sort.Slice(entries, func(i, j int) bool { + return entries[i].ts < entries[j].ts + }) + evictCount := len(entries) - maxModels + for i := 0; i < evictCount; i++ { + stat.modelTTFT.Delete(entries[i].model) + stat.modelTTFTLastUpdate.Delete(entries[i].model) + } +} + func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { for { oldBits := target.Load() @@ -145,40 +670,77 @@ func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { } } -func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) { +func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) { + s.reportWithOptions( + accountID, + success, + firstTokenMs, + defaultCircuitBreakerFailThreshold, + true, + model, + ttftMs, + true, + defaultPerModelTTFTMaxModels, + ) +} + +func (s *openAIAccountRuntimeStats) reportWithOptions( + accountID int64, + success bool, + firstTokenMs *int, + cbThreshold int, + updateCircuitBreaker bool, + model string, + ttftMs float64, + perModelTTFTEnabled bool, + perModelTTFTMaxModels int, +) { if s == nil || accountID <= 0 { return } - const alpha = 0.2 stat := s.loadOrCreate(accountID) + stat.lastReportNano.Store(time.Now().UnixNano()) errorSample := 1.0 if success { errorSample = 0.0 } - updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha) + stat.errorRate.update(errorSample) - if firstTokenMs != nil && *firstTokenMs > 0 { - ttft := float64(*firstTokenMs) - ttftBits := math.Float64bits(ttft) - for { - oldBits := stat.ttftEWMABits.Load() - oldValue := math.Float64frombits(oldBits) - if math.IsNaN(oldValue) { - if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) { - break - } - continue - } - newValue := alpha*ttft + (1-alpha)*oldValue - if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { - break - } + // Per-model TTFT tracking: reportModelTTFT updates both per-model and + // global TTFT, so skip the 
separate global update to avoid double-counting. + if perModelTTFTEnabled && model != "" && ttftMs > 0 { + stat.reportModelTTFT(model, ttftMs) + } else if firstTokenMs != nil && *firstTokenMs > 0 { + stat.ttft.update(float64(*firstTokenMs)) + } + + // Update circuit breaker state only when feature is enabled. + if updateCircuitBreaker { + cb := s.getCircuitBreaker(accountID) + if success { + cb.recordSuccess() + } else { + cb.recordFailure(cbThreshold) + } + } + + // Periodic cleanup: every 100 reports. + cnt := s.cleanupCounter.Add(1) + if cnt%100 == 0 { + maxModels := defaultPerModelTTFTMaxModels + if perModelTTFTMaxModels > 0 { + maxModels = perModelTTFTMaxModels } + stat.cleanupStaleTTFT(defaultPerModelTTFTTTL, maxModels) + } + // GC inactive accounts and orphaned circuit breakers: every 1000 reports. + if cnt%1000 == 0 { + s.gcInactiveAccounts(time.Hour) } } -func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) { +func (s *openAIAccountRuntimeStats) snapshot(accountID int64, model ...string) (errorRate float64, ttft float64, hasTTFT bool) { if s == nil || accountID <= 0 { return 0, 0, false } @@ -190,12 +752,17 @@ func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64 if stat == nil { return 0, 0, false } - errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load())) - ttftValue := math.Float64frombits(stat.ttftEWMABits.Load()) - if math.IsNaN(ttftValue) { - return errorRate, 0, false + errorRate = clamp01(stat.errorRate.value()) + + // Try per-model TTFT first; fallback to global. 
+ if len(model) > 0 && model[0] != "" { + if mTTFT, mOK := stat.modelTTFTValue(model[0]); mOK { + return errorRate, mTTFT, true + } } - return errorRate, ttftValue, true + + ttft, hasTTFT = stat.ttft.value() + return errorRate, ttft, hasTTFT } func (s *openAIAccountRuntimeStats) size() int { @@ -205,6 +772,25 @@ func (s *openAIAccountRuntimeStats) size() int { return int(s.accountCount.Load()) } +// gcInactiveAccounts removes account stats and circuit breakers that have not +// received any report for longer than maxIdle. This prevents unbounded growth +// of the sync.Maps when accounts are created and then deleted/deactivated. +func (s *openAIAccountRuntimeStats) gcInactiveAccounts(maxIdle time.Duration) { + if s == nil { + return + } + cutoff := time.Now().UnixNano() - int64(maxIdle) + s.accounts.Range(func(key, value any) bool { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil || stat.lastReportNano.Load() < cutoff { + s.accounts.Delete(key) + s.circuitBreakers.Delete(key) + s.accountCount.Add(-1) + } + return true + }) +} + type defaultOpenAIAccountScheduler struct { service *OpenAIGatewayService metrics openAIAccountSchedulerMetrics @@ -331,6 +917,12 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash( return nil, nil } + // Conditional sticky: release binding if account is unhealthy or overloaded. + if s.shouldReleaseStickySession(accountID) { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil // Fall through to load balance + } + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if acquireErr == nil && result.Acquired { _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) @@ -557,6 +1149,175 @@ func buildOpenAIWeightedSelectionOrder( return order } +// selectP2COpenAICandidates selects candidates using Power-of-Two-Choices: +// randomly pick 2 candidates, return the one with the higher score. 
+// Repeat to build a full selection order for fallback. +func selectP2COpenAICandidates( + candidates []openAIAccountCandidateScore, + req OpenAIAccountScheduleRequest, +) []openAIAccountCandidateScore { + if len(candidates) <= 1 { + return append([]openAIAccountCandidateScore(nil), candidates...) + } + + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + pool := append([]openAIAccountCandidateScore(nil), candidates...) + order := make([]openAIAccountCandidateScore, 0, len(pool)) + + for len(pool) > 1 { + n := uint64(len(pool)) + // Pick first random index. + idx1 := int(rng.nextUint64() % n) + // Pick second random index, distinct from the first. + idx2 := int(rng.nextUint64() % (n - 1)) + if idx2 >= idx1 { + idx2++ + } + + // Compare: take the candidate with the higher score. + winner := idx1 + if isOpenAIAccountCandidateBetter(pool[idx2], pool[idx1]) { + winner = idx2 + } + + order = append(order, pool[winner]) + // Remove winner from pool (swap with last element for O(1) removal). + pool[winner] = pool[len(pool)-1] + pool = pool[:len(pool)-1] + } + // Append the last remaining candidate. + order = append(order, pool[0]) + return order +} + +// --------------------------------------------------------------------------- +// Softmax Temperature Sampling +// --------------------------------------------------------------------------- + +const defaultSoftmaxTemperature = 0.3 + +type softmaxConfig struct { + enabled bool + temperature float64 +} + +// softmaxConfigRead reads softmax scheduler config with fallback defaults. 
+func (s *defaultOpenAIAccountScheduler) softmaxConfigRead() softmaxConfig { + if s == nil || s.service == nil || s.service.cfg == nil { + return softmaxConfig{} + } + wsCfg := s.service.cfg.Gateway.OpenAIWS + temp := wsCfg.SchedulerSoftmaxTemperature + if temp <= 0 { + temp = defaultSoftmaxTemperature + } + return softmaxConfig{ + enabled: wsCfg.SchedulerSoftmaxEnabled, + temperature: temp, + } +} + +// selectSoftmaxOpenAICandidates applies softmax temperature sampling to select +// one candidate probabilistically, then returns the full list with the selected +// candidate first and the rest sorted by descending probability. +// +// Algorithm (numerically stable): +// +// maxScore = max(score[i]) +// weights[i] = exp((score[i] - maxScore) / temperature) +// probability[i] = weights[i] / sum(weights) +// +// A higher temperature yields more uniform selection (exploration); a lower +// temperature concentrates probability on the highest-scored candidates +// (exploitation). +func selectSoftmaxOpenAICandidates( + candidates []openAIAccountCandidateScore, + temperature float64, + rng *openAISelectionRNG, +) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if len(candidates) == 1 { + return append([]openAIAccountCandidateScore(nil), candidates...) + } + if temperature <= 0 { + temperature = defaultSoftmaxTemperature + } + + // Step 1: find max score for numerical stability. + maxScore := candidates[0].score + for i := 1; i < len(candidates); i++ { + if candidates[i].score > maxScore { + maxScore = candidates[i].score + } + } + + // Step 2: compute softmax weights. + type indexedProb struct { + index int + prob float64 + } + probs := make([]indexedProb, len(candidates)) + sumWeights := 0.0 + for i := range candidates { + w := math.Exp((candidates[i].score - maxScore) / temperature) + // Guard against NaN/Inf from degenerate inputs. 
+ if math.IsNaN(w) || math.IsInf(w, 0) { + w = 0 + } + probs[i] = indexedProb{index: i, prob: w} + sumWeights += w + } + + // Normalise to probabilities. If sumWeights is zero (all weights collapsed + // to zero, which can happen with extreme negative scores), fall back to + // uniform distribution. + if sumWeights > 0 { + for i := range probs { + probs[i].prob /= sumWeights + } + } else { + uniform := 1.0 / float64(len(probs)) + for i := range probs { + probs[i].prob = uniform + } + } + + // Step 3: sample ONE candidate via CDF. + r := rng.nextFloat64() + selectedIdx := probs[len(probs)-1].index // default to last if rounding issues + cumulative := 0.0 + for _, ip := range probs { + cumulative += ip.prob + if cumulative >= r { + selectedIdx = ip.index + break + } + } + + // Step 4: build result — selected candidate first, rest sorted by + // descending probability. + result := make([]openAIAccountCandidateScore, 0, len(candidates)) + result = append(result, candidates[selectedIdx]) + + // Sort remaining by probability descending for fallback order. + remaining := make([]indexedProb, 0, len(probs)-1) + for _, ip := range probs { + if ip.index != selectedIdx { + remaining = append(remaining, ip) + } + } + sort.Slice(remaining, func(i, j int) bool { + return remaining[i].prob > remaining[j].prob + }) + for _, ip := range remaining { + result = append(result, candidates[ip.index]) + } + + return result +} + func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ctx context.Context, req OpenAIAccountScheduleRequest, @@ -597,6 +1358,44 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( return nil, 0, 0, 0, errors.New("no available OpenAI accounts") } + // Circuit breaker filtering: remove accounts with open CBs, but if that + // would empty the candidate pool, keep all accounts (graceful degradation). 
+ cbEnabled, _, cbCooldown, cbHalfOpenMax := s.schedulerCircuitBreakerConfig() + heldHalfOpenPermits := make(map[int64]*accountCircuitBreaker) + releaseHalfOpenPermit := func(accountID int64) { + cb, ok := heldHalfOpenPermits[accountID] + if !ok || cb == nil { + return + } + cb.releaseHalfOpenPermit() + delete(heldHalfOpenPermits, accountID) + } + defer func() { + for accountID := range heldHalfOpenPermits { + releaseHalfOpenPermit(accountID) + } + }() + if cbEnabled { + healthy := make([]*Account, 0, len(filtered)) + healthyLoadReq := make([]AccountWithConcurrency, 0, len(loadReq)) + for i, account := range filtered { + cb := s.stats.loadCircuitBreaker(account.ID) + if cb == nil || cb.allow(cbCooldown, cbHalfOpenMax) { + healthy = append(healthy, account) + healthyLoadReq = append(healthyLoadReq, loadReq[i]) + if cb.isHalfOpen() { + heldHalfOpenPermits[account.ID] = cb + } + } + } + if len(healthy) > 0 { + filtered = healthy + loadReq = healthyLoadReq + } + // else: all accounts are circuit-open; fall through with the + // original set to avoid returning "no accounts". 
+ } + loadMap := map[int64]*AccountLoadInfo{} if s.service.concurrencyService != nil { if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil { @@ -604,8 +1403,16 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( } } + trendEnabled, trendMaxSlope := s.service.openAIWSSchedulerTrendConfig() + perModelTTFTEnabled, _ := s.schedulerPerModelTTFTConfig() + requestedModelForStats := "" + if perModelTTFTEnabled { + requestedModelForStats = req.RequestedModel + } + minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority maxWaiting := 1 + maxConcurrency := 0 loadRateSum := 0.0 loadRateSumSquares := 0.0 minTTFT, maxTTFT := 0.0, 0.0 @@ -625,7 +1432,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if loadInfo.WaitingCount > maxWaiting { maxWaiting = loadInfo.WaitingCount } - errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) + if account.Concurrency > maxConcurrency { + maxConcurrency = account.Concurrency + } + errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID, requestedModelForStats) if hasTTFT && ttft > 0 { if !hasTTFTSample { minTTFT, maxTTFT = ttft, ttft @@ -642,6 +1452,13 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( loadRate := float64(loadInfo.LoadRate) loadRateSum += loadRate loadRateSumSquares += loadRate * loadRate + + // Record current load rate sample for trend tracking. + if trendEnabled { + stat := s.stats.loadOrCreate(account.ID) + stat.loadTrend.record(loadRate) + } + candidates = append(candidates, openAIAccountCandidateScore{ account: account, loadInfo: loadInfo, @@ -659,8 +1476,33 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( if maxPriority > minPriority { priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) } + // Base load factor from percentage utilization. 
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + // Capacity-aware adjustment: accounts with more absolute headroom get a bonus. + if maxConcurrency > 0 && item.account.Concurrency > 0 { + remainingSlots := float64(item.account.Concurrency) * (1 - float64(item.loadInfo.LoadRate)/100.0) + capacityBonus := clamp01(remainingSlots / float64(maxConcurrency)) + // Blend: 70% relative load + 30% capacity bonus + loadFactor = 0.7*loadFactor + 0.3*capacityBonus + } + + // Trend adjustment: penalise accounts whose load is rising, reward those declining. + // trendAdj ranges [0, 1] where 0 = max rising slope, 1 = max falling/flat slope. + // loadFactor is scaled by a multiplier in [0.7, 1.0]: 1.0 for flat/falling load, 0.7 at the steepest rise. + if trendEnabled { + stat := s.stats.loadOrCreate(item.account.ID) + slope := stat.loadTrend.slope() + trendAdj := 1.0 - clamp01(slope/trendMaxSlope) + loadFactor *= (0.7 + 0.3*trendAdj) + } + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + // Queue depth relative to account's own capacity for capacity-aware blending. 
+ if item.account.Concurrency > 0 { + relativeQueue := clamp01(float64(item.loadInfo.WaitingCount) / float64(item.account.Concurrency)) + // Blend: 60% cross-account normalized + 40% self-relative + queueFactor = 0.6*queueFactor + 0.4*(1-relativeQueue) + } errorFactor := 1 - clamp01(item.errorRate) ttftFactor := 0.5 if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { @@ -674,23 +1516,40 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( weights.TTFT*ttftFactor } - topK := s.service.openAIWSLBTopK() - if topK > len(candidates) { - topK = len(candidates) - } - if topK <= 0 { - topK = 1 + var selectionOrder []openAIAccountCandidateScore + topK := 0 + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + smCfg := s.softmaxConfigRead() + p2cEnabled := s.service.openAIWSSchedulerP2CEnabled() + if smCfg.enabled && len(candidates) > 3 { + selectionOrder = selectSoftmaxOpenAICandidates(candidates, smCfg.temperature, &rng) + // topK = 0 signals softmax mode in metrics / decision struct. + } else if p2cEnabled { + selectionOrder = selectP2COpenAICandidates(candidates, req) + // topK = 0 signals P2C mode in metrics / decision struct. 
+ } else { + topK = s.service.openAIWSLBTopK() + if topK > len(candidates) { + topK = len(candidates) + } + if topK <= 0 { + topK = 1 + } + rankedCandidates := selectTopKOpenAICandidates(candidates, topK) + selectionOrder = buildOpenAIWeightedSelectionOrder(rankedCandidates, req) } - rankedCandidates := selectTopKOpenAICandidates(candidates, topK) - selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req) for i := 0; i < len(selectionOrder); i++ { candidate := selectionOrder[i] result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) if acquireErr != nil { + releaseHalfOpenPermit(candidate.account.ID) return nil, len(candidates), topK, loadSkew, acquireErr } if result != nil && result.Acquired { + // Keep HALF_OPEN permit for the selected account; the outcome will be + // settled by ReportResult(success/failure) after the request finishes. + delete(heldHalfOpenPermits, candidate.account.ID) if req.SessionHash != "" { _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) } @@ -700,10 +1559,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( ReleaseFunc: result.ReleaseFunc, }, len(candidates), topK, loadSkew, nil } + releaseHalfOpenPermit(candidate.account.ID) } cfg := s.service.schedulingConfig() candidate := selectionOrder[0] + releaseHalfOpenPermit(candidate.account.ID) return &AccountSelectionResult{ Account: candidate.account, WaitPlan: &AccountWaitPlan{ @@ -726,11 +1587,54 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport } -func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { +func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) { if s == nil || s.stats == nil { return } - 
s.stats.report(accountID, success, firstTokenMs) + perModelTTFTEnabled, perModelTTFTMaxModels := s.schedulerPerModelTTFTConfig() + enabled, threshold, _, _ := s.schedulerCircuitBreakerConfig() + if !enabled { + // Circuit breaker disabled: only update runtime signals (error-rate/TTFT), + // do not mutate circuit breaker state. + s.stats.reportWithOptions( + accountID, + success, + firstTokenMs, + 0, + false, + model, + ttftMs, + perModelTTFTEnabled, + perModelTTFTMaxModels, + ) + return + } + + // Snapshot state before the update for metrics tracking. + cb := s.stats.getCircuitBreaker(accountID) + stateBefore := cb.state.Load() + + s.stats.reportWithOptions( + accountID, + success, + firstTokenMs, + threshold, + true, + model, + ttftMs, + perModelTTFTEnabled, + perModelTTFTMaxModels, + ) + + stateAfter := cb.state.Load() + // CLOSED/HALF_OPEN → OPEN: circuit tripped. + if stateBefore != circuitBreakerStateOpen && stateAfter == circuitBreakerStateOpen { + s.metrics.circuitBreakerOpenTotal.Add(1) + } + // OPEN/HALF_OPEN → CLOSED: circuit recovered. + if stateBefore != circuitBreakerStateClosed && stateAfter == circuitBreakerStateClosed { + s.metrics.circuitBreakerRecoverTotal.Add(1) + } } func (s *defaultOpenAIAccountScheduler) ReportSwitch() { @@ -740,6 +1644,107 @@ func (s *defaultOpenAIAccountScheduler) ReportSwitch() { s.metrics.recordSwitch() } +// schedulerCircuitBreakerConfig reads CB config with fallback defaults. 
+func (s *defaultOpenAIAccountScheduler) schedulerCircuitBreakerConfig() (enabled bool, threshold int, cooldown time.Duration, halfOpenMax int) { + threshold = defaultCircuitBreakerFailThreshold + cooldown = time.Duration(defaultCircuitBreakerCooldownSec) * time.Second + halfOpenMax = defaultCircuitBreakerHalfOpenMax + + if s == nil || s.service == nil || s.service.cfg == nil { + return false, threshold, cooldown, halfOpenMax + } + wsCfg := s.service.cfg.Gateway.OpenAIWS + enabled = wsCfg.SchedulerCircuitBreakerEnabled + if wsCfg.SchedulerCircuitBreakerFailThreshold > 0 { + threshold = wsCfg.SchedulerCircuitBreakerFailThreshold + } + if wsCfg.SchedulerCircuitBreakerCooldownSec > 0 { + cooldown = time.Duration(wsCfg.SchedulerCircuitBreakerCooldownSec) * time.Second + } + if wsCfg.SchedulerCircuitBreakerHalfOpenMax > 0 { + halfOpenMax = wsCfg.SchedulerCircuitBreakerHalfOpenMax + } + return enabled, threshold, cooldown, halfOpenMax +} + +func (s *defaultOpenAIAccountScheduler) schedulerPerModelTTFTConfig() (enabled bool, maxModels int) { + maxModels = defaultPerModelTTFTMaxModels + if s == nil || s.service == nil || s.service.cfg == nil { + return false, maxModels + } + wsCfg := s.service.cfg.Gateway.OpenAIWS + enabled = wsCfg.SchedulerPerModelTTFTEnabled + if wsCfg.SchedulerPerModelTTFTMaxModels > 0 { + maxModels = wsCfg.SchedulerPerModelTTFTMaxModels + } + return enabled, maxModels +} + +// --------------------------------------------------------------------------- +// Conditional Sticky Session Release +// --------------------------------------------------------------------------- + +const defaultStickyReleaseErrorThreshold = 0.3 + +type stickyReleaseConfig struct { + enabled bool + errorThreshold float64 +} + +// stickyReleaseConfigRead reads conditional sticky release config with defaults. 
+func (s *defaultOpenAIAccountScheduler) stickyReleaseConfigRead() stickyReleaseConfig { + if s == nil || s.service == nil || s.service.cfg == nil { + return stickyReleaseConfig{} + } + wsCfg := s.service.cfg.Gateway.OpenAIWS + threshold := wsCfg.StickyReleaseErrorThreshold + if threshold <= 0 { + threshold = defaultStickyReleaseErrorThreshold + } + return stickyReleaseConfig{ + enabled: wsCfg.StickyReleaseEnabled, + errorThreshold: threshold, + } +} + +// shouldReleaseStickySession checks whether a sticky binding should be +// released because the account is unhealthy (circuit breaker open) or has a +// high error rate. This runs BEFORE slot acquisition to avoid wasting +// concurrency capacity on degraded accounts. +func (s *defaultOpenAIAccountScheduler) shouldReleaseStickySession(accountID int64) bool { + if s == nil || s.stats == nil || s.service == nil { + return false + } + + cfg := s.stickyReleaseConfigRead() + if !cfg.enabled { + return false + } + + // Check 1: Circuit breaker is open -> immediate release. + // Only check if CB feature is actually enabled, because the default CB + // threshold (5) is very aggressive and may trip unexpectedly. + cbEnabled, _, _, _ := s.schedulerCircuitBreakerConfig() + if cbEnabled && s.stats.isCircuitOpen(accountID) { + s.metrics.stickyReleaseCircuitOpenTotal.Add(1) + return true + } + + // Check 2: Error rate exceeds threshold -> immediate release. + // Guard against cold-start: the EWMA error rate is unreliable when + // fewer than dualEWMAMinSamples have been collected. 
+ stat := s.stats.loadExisting(accountID) + if stat != nil && stat.errorRate.isWarmedUp() { + errorRate, _, _ := s.stats.snapshot(accountID) + if errorRate > cfg.errorThreshold { + s.metrics.stickyReleaseErrorTotal.Add(1) + return true + } + } + + return false +} + func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot { if s == nil { return OpenAIAccountSchedulerMetricsSnapshot{} @@ -753,13 +1758,17 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler loadSkewTotal := s.metrics.loadSkewMilliTotal.Load() snapshot := OpenAIAccountSchedulerMetricsSnapshot{ - SelectTotal: selectTotal, - StickyPreviousHitTotal: prevHit, - StickySessionHitTotal: sessionHit, - LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(), - AccountSwitchTotal: switchTotal, - SchedulerLatencyMsTotal: latencyTotal, - RuntimeStatsAccountCount: s.stats.size(), + SelectTotal: selectTotal, + StickyPreviousHitTotal: prevHit, + StickySessionHitTotal: sessionHit, + LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(), + AccountSwitchTotal: switchTotal, + SchedulerLatencyMsTotal: latencyTotal, + RuntimeStatsAccountCount: s.stats.size(), + CircuitBreakerOpenTotal: s.metrics.circuitBreakerOpenTotal.Load(), + CircuitBreakerRecoverTotal: s.metrics.circuitBreakerRecoverTotal.Load(), + StickyReleaseErrorTotal: s.metrics.stickyReleaseErrorTotal.Load(), + StickyReleaseCircuitOpenTotal: s.metrics.stickyReleaseCircuitOpenTotal.Load(), } if selectTotal > 0 { snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal) @@ -820,12 +1829,12 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler( }) } -func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { +func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) { scheduler := s.getOpenAIAccountScheduler() 
if scheduler == nil { return } - scheduler.ReportResult(accountID, success, firstTokenMs) + scheduler.ReportResult(accountID, success, firstTokenMs, model, ttftMs) } func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { @@ -855,7 +1864,14 @@ func (s *OpenAIGatewayService) openAIWSLBTopK() int { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 { return s.cfg.Gateway.OpenAIWS.LBTopK } - return 7 + return 999 +} + +func (s *OpenAIGatewayService) openAIWSSchedulerP2CEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.SchedulerP2CEnabled + } + return false } func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView { @@ -885,6 +1901,24 @@ type GatewayOpenAIWSSchedulerScoreWeightsView struct { TTFT float64 } +// defaultSchedulerTrendMaxSlope is the normalization ceiling for the trend +// slope. A slope of 5.0 means the account's load rate is increasing at 5 +// percentage points per second — a very steep rise. +const defaultSchedulerTrendMaxSlope = 5.0 + +// openAIWSSchedulerTrendConfig reads trend-prediction config with defaults. 
+func (s *OpenAIGatewayService) openAIWSSchedulerTrendConfig() (enabled bool, maxSlope float64) { + maxSlope = defaultSchedulerTrendMaxSlope + if s == nil || s.cfg == nil { + return false, maxSlope + } + enabled = s.cfg.Gateway.OpenAIWS.SchedulerTrendEnabled + if s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope > 0 { + maxSlope = s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope + } + return enabled, maxSlope +} + func clamp01(value float64) float64 { switch { case value < 0: diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go index 7f6f1b66a..6cc3ba8f0 100644 --- a/backend/internal/service/openai_account_scheduler_test.go +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -1,6 +1,7 @@ package service import ( + "container/heap" "context" "fmt" "math" @@ -12,6 +13,14 @@ import ( "github.com/stretchr/testify/require" ) +func mustDefaultOpenAIAccountScheduler(t *testing.T, svc *OpenAIGatewayService, stats *openAIAccountRuntimeStats) *defaultOpenAIAccountScheduler { + t.Helper() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + return scheduler +} + func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { ctx := context.Background() groupID := int64(9) @@ -447,7 +456,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) require.NoError(t, err) require.NotNil(t, selection) - svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120)) + svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120), "", 0) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() @@ -465,19 +474,120 @@ func intPtrForTest(v int) *int { func 
TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) { stats := newOpenAIAccountRuntimeStats() - stats.report(1001, true, nil) + stats.report(1001, true, nil, "", 0) // error: fast 0→0, slow 0→0 firstTTFT := 100 - stats.report(1001, false, &firstTTFT) + stats.report(1001, false, &firstTTFT, "", 0) // error: fast 0→0.5, slow 0→0.1; ttft: NaN→100 (both) secondTTFT := 200 - stats.report(1001, false, &secondTTFT) + stats.report(1001, false, &secondTTFT, "", 0) // error: fast 0.5→0.75, slow 0.1→0.19; ttft: fast 100→150, slow 100→110 errorRate, ttft, hasTTFT := stats.snapshot(1001) require.True(t, hasTTFT) - require.InDelta(t, 0.36, errorRate, 1e-9) - require.InDelta(t, 120.0, ttft, 1e-9) + // errorRate = max(fast=0.75, slow=0.19) = 0.75 + require.InDelta(t, 0.75, errorRate, 1e-9) + // ttft = max(fast=150, slow=110) = 150 + require.InDelta(t, 150.0, ttft, 1e-9) require.Equal(t, 1, stats.size()) } +func TestDualEWMA_UpdateAndValue(t *testing.T) { + var d dualEWMA + + // Initial state: both channels are 0. + require.Equal(t, 0.0, d.fastValue()) + require.Equal(t, 0.0, d.slowValue()) + require.Equal(t, 0.0, d.value()) + + // First sample = 1.0 + d.update(1.0) + // fast: 0.5*1 + 0.5*0 = 0.5 + require.InDelta(t, 0.5, d.fastValue(), 1e-12) + // slow: 0.1*1 + 0.9*0 = 0.1 + require.InDelta(t, 0.1, d.slowValue(), 1e-12) + // value = max(0.5, 0.1) = 0.5 + require.InDelta(t, 0.5, d.value(), 1e-12) + + // Second sample = 0.0 (recovery) + d.update(0.0) + // fast: 0.5*0 + 0.5*0.5 = 0.25 + require.InDelta(t, 0.25, d.fastValue(), 1e-12) + // slow: 0.1*0 + 0.9*0.1 = 0.09 + require.InDelta(t, 0.09, d.slowValue(), 1e-12) + // value = max(0.25, 0.09) = 0.25 + require.InDelta(t, 0.25, d.value(), 1e-12) +} + +func TestDualEWMA_SlowDominatesAfterRecovery(t *testing.T) { + var d dualEWMA + + // Spike: several failures. + for i := 0; i < 10; i++ { + d.update(1.0) + } + // Now fast is close to 1, slow is also rising. + + // Recovery: many successes. 
+ for i := 0; i < 20; i++ { + d.update(0.0) + } + // Fast should have dropped close to 0, slow should still be > fast. + require.Greater(t, d.slowValue(), d.fastValue(), + "after recovery, slow channel should dominate the pessimistic envelope") + require.Equal(t, d.slowValue(), d.value()) +} + +func TestDualEWMATTFT_NaNInitAndFirstSample(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + + // Before any sample, value should report no data. + v, ok := d.value() + require.False(t, ok) + require.Equal(t, 0.0, v) + + // First sample seeds both channels. + d.update(100.0) + require.InDelta(t, 100.0, d.fastValue(), 1e-12) + require.InDelta(t, 100.0, d.slowValue(), 1e-12) + v, ok = d.value() + require.True(t, ok) + require.InDelta(t, 100.0, v, 1e-12) + + // Second sample. + d.update(200.0) + // fast: 0.5*200 + 0.5*100 = 150 + require.InDelta(t, 150.0, d.fastValue(), 1e-12) + // slow: 0.1*200 + 0.9*100 = 110 + require.InDelta(t, 110.0, d.slowValue(), 1e-12) + v, ok = d.value() + require.True(t, ok) + require.InDelta(t, 150.0, v, 1e-12) +} + +func TestDualEWMATTFT_SlowDominatesWhenLatencyDrops(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + + // Warm up with high latency. + for i := 0; i < 20; i++ { + d.update(500.0) + } + // Now push many low-latency samples. + for i := 0; i < 20; i++ { + d.update(100.0) + } + // Fast should have adapted down quickly; slow should still be higher. 
+ require.Greater(t, d.slowValue(), d.fastValue(), + "after latency improvement, slow channel should dominate the pessimistic TTFT") + v, ok := d.value() + require.True(t, ok) + require.InDelta(t, d.slowValue(), v, 1e-12) +} + +func TestDualEWMAConstants(t *testing.T) { + require.Equal(t, 0.5, dualEWMAAlphaFast) + require.Equal(t, 0.1, dualEWMAAlphaSlow) +} + func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) { stats := newOpenAIAccountRuntimeStats() @@ -496,7 +606,7 @@ func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) { accountID := int64(i%accountCount + 1) success := (i+worker)%3 != 0 ttft := 80 + (i+worker)%40 - stats.report(accountID, success, &ttft) + stats.report(accountID, success, &ttft, "", 0) } }() } @@ -751,7 +861,7 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { require.True(t, ok) ttft := 100 - scheduler.ReportResult(1001, true, &ttft) + scheduler.ReportResult(1001, true, &ttft, "", 0) scheduler.ReportSwitch() scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ Layer: openAIAccountScheduleLayerLoadBalance, @@ -780,11 +890,11 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { svc := &OpenAIGatewayService{} ttft := 120 - svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft, "", 0) svc.RecordOpenAIAccountSwitch() snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) - require.Equal(t, 7, svc.openAIWSLBTopK()) + require.Equal(t, 999, svc.openAIWSLBTopK()) require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) defaultWeights := svc.openAIWSSchedulerWeights() @@ -836,6 +946,3268 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t * require.True(t, scheduler.isAccountTransportCompatible(account, 
OpenAIUpstreamTransportResponsesWebsocketV2)) } +func TestLoadFactorCapacityAwareness(t *testing.T) { + // Test that accounts with higher absolute capacity get better scores + // when percentage load is equal. + // + // Setup: + // Account A: Concurrency=100, LoadRate=50 (50 free slots) + // Account B: Concurrency=10, LoadRate=50 (5 free slots) + // Both at 50% load, but A should score higher due to more headroom. + + ctx := context.Background() + groupID := int64(20) + accounts := []Account{ + { + ID: 6001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 100, + Priority: 0, + }, + { + ID: 6002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 10, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + // Use only Load weight to isolate the capacity-aware loadFactor effect. + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 6001: {AccountID: 6001, LoadRate: 50, WaitingCount: 0}, + 6002: {AccountID: 6002, LoadRate: 50, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + // Verify account A (high capacity) is always selected first by score. + // Because weighted selection has randomness, we run multiple iterations + // and verify A is selected more often than B. 
+ countA := 0 + countB := 0 + iterations := 100 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("cap_aware_test_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + if selection.Account.ID == 6001 { + countA++ + } else { + countB++ + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // Account A (100 concurrency) should be selected significantly more often + // than Account B (10 concurrency) because A has 50 free slots vs 5 free slots. + require.Greater(t, countA, countB, + "high-capacity account (50 free slots) should be selected more often than low-capacity (5 free slots) at equal load percentage; got A=%d B=%d", countA, countB) + + // ----------------------------------------------------------------------- + // Verify score math directly via the capacity-aware loadFactor formula. + // ----------------------------------------------------------------------- + // maxConcurrency = 100 (from account A) + // + // Account A (Concurrency=100, LoadRate=50): + // base loadFactor = 1 - 50/100 = 0.5 + // remainingSlots = 100 * 0.5 = 50 + // capacityBonus = 50 / 100 = 0.5 + // loadFactor = 0.7*0.5 + 0.3*0.5 = 0.5 + // + // Account B (Concurrency=10, LoadRate=50): + // base loadFactor = 1 - 50/100 = 0.5 + // remainingSlots = 10 * 0.5 = 5 + // capacityBonus = 5 / 100 = 0.05 + // loadFactor = 0.7*0.5 + 0.3*0.05 = 0.365 + // + // With Load weight = 1.0 and all others 0.0, score = loadFactor. 
+ expectedScoreA := 0.7*0.5 + 0.3*0.5 // 0.5 + expectedScoreB := 0.7*0.5 + 0.3*(5.0/100.0) // 0.365 + require.Greater(t, expectedScoreA, expectedScoreB, "score sanity check") + require.InDelta(t, 0.5, expectedScoreA, 1e-9) + require.InDelta(t, 0.365, expectedScoreB, 1e-9) +} + +func TestQueueFactorCapacityAwareness(t *testing.T) { + // Test that the capacity-aware queue factor penalises accounts + // whose queue depth is high relative to their own concurrency. + // + // Account A: Concurrency=100, WaitingCount=10 (10% of capacity) + // Account B: Concurrency=10, WaitingCount=10 (100% of capacity) + // Both have same absolute waiting count, but B should score lower. + + ctx := context.Background() + groupID := int64(21) + accounts := []Account{ + { + ID: 7001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 100, + Priority: 0, + }, + { + ID: 7002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 10, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + // Use only Queue weight to isolate the capacity-aware queueFactor effect. 
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 7001: {AccountID: 7001, LoadRate: 30, WaitingCount: 10}, + 7002: {AccountID: 7002, LoadRate: 30, WaitingCount: 10}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + countA := 0 + countB := 0 + iterations := 100 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("queue_aware_test_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + if selection.Account.ID == 7001 { + countA++ + } else { + countB++ + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + require.Greater(t, countA, countB, + "account with lower relative queue depth should be selected more often; got A=%d B=%d", countA, countB) + + // ----------------------------------------------------------------------- + // Verify score math for the capacity-aware queueFactor. 
+ // ----------------------------------------------------------------------- + // maxWaiting = 10 (both accounts have WaitingCount=10) + // + // Account A (Concurrency=100, WaitingCount=10): + // base queueFactor = 1 - 10/10 = 0.0 + // relativeQueue = 10/100 = 0.1 + // queueFactor = 0.6*0.0 + 0.4*(1-0.1) = 0.36 + // + // Account B (Concurrency=10, WaitingCount=10): + // base queueFactor = 1 - 10/10 = 0.0 + // relativeQueue = clamp01(10/10) = 1.0 + // queueFactor = 0.6*0.0 + 0.4*(1-1.0) = 0.0 + expectedQueueA := 0.6*0.0 + 0.4*(1-0.1) + expectedQueueB := 0.6*0.0 + 0.4*(1-1.0) + require.Greater(t, expectedQueueA, expectedQueueB) + require.InDelta(t, 0.36, expectedQueueA, 1e-9) + require.InDelta(t, 0.0, expectedQueueB, 1e-9) +} + +func TestLoadFactorCapacityAwareness_ZeroConcurrencyFallback(t *testing.T) { + // When Concurrency is 0, the capacity-aware blending should be skipped + // and loadFactor should fall back to the simple loadRate/100 formula. + + ctx := context.Background() + groupID := int64(22) + accounts := []Account{ + { + ID: 8001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 0, // unset / zero + Priority: 0, + }, + { + ID: 8002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 0, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 8001: {AccountID: 8001, LoadRate: 30, WaitingCount: 0}, + 8002: {AccountID: 8002, LoadRate: 70, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: 
stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + // With Concurrency=0, maxConcurrency=0, so the capacity-aware path is skipped. + // Account 8001 (LoadRate=30) should have higher loadFactor than 8002 (LoadRate=70). + countLow := 0 + iterations := 60 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("zero_conc_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + if selection.Account.ID == 8001 { + countLow++ + } + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // 8001 (lower load) should be picked more often. + require.Greater(t, countLow, iterations/2, + "account with lower load should be selected more often when concurrency is 0; got %d/%d", countLow, iterations) +} + func int64PtrForTest(v int64) *int64 { return &v } + +// --------------------------------------------------------------------------- +// Circuit Breaker Tests +// --------------------------------------------------------------------------- + +func TestAccountCircuitBreaker_ClosedToOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 30 * time.Second + halfOpenMax := 2 + + // Initially CLOSED — should allow. + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "CLOSED", cb.stateString()) + require.False(t, cb.isOpen()) + + // Record 4 failures — should still be CLOSED (threshold is 5). + for i := 0; i < 4; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "CLOSED", cb.stateString()) + require.True(t, cb.allow(cooldown, halfOpenMax)) + + // 5th failure trips the breaker to OPEN. 
+ cb.recordFailure(defaultCircuitBreakerFailThreshold) + require.Equal(t, "OPEN", cb.stateString()) + require.True(t, cb.isOpen()) + require.False(t, cb.allow(cooldown, halfOpenMax)) +} + +func TestAccountCircuitBreaker_OpenToHalfOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 50 * time.Millisecond + halfOpenMax := 2 + + // Trip the breaker. + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "OPEN", cb.stateString()) + require.False(t, cb.allow(cooldown, halfOpenMax)) + + // Wait for cooldown to elapse. + time.Sleep(cooldown + 10*time.Millisecond) + + // Next allow() should transition to HALF_OPEN and admit the request. + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "HALF_OPEN", cb.stateString()) +} + +func TestAccountCircuitBreaker_HalfOpenToClose(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 50 * time.Millisecond + halfOpenMax := 2 + + // Trip the breaker. + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "OPEN", cb.stateString()) + + // Wait for cooldown. + time.Sleep(cooldown + 10*time.Millisecond) + + // Allow first probe — transitions to HALF_OPEN. + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Allow second probe. + require.True(t, cb.allow(cooldown, halfOpenMax)) + + // Third probe should be rejected (halfOpenMax=2). + require.False(t, cb.allow(cooldown, halfOpenMax)) + + // Both probes succeed — should close the circuit. + cb.recordSuccess() + // After first success, still HALF_OPEN (need both to succeed). + require.Equal(t, "HALF_OPEN", cb.stateString()) + cb.recordSuccess() + // Both probes succeeded — circuit should be CLOSED now. 
+ require.Equal(t, "CLOSED", cb.stateString()) + require.False(t, cb.isOpen()) + require.True(t, cb.allow(cooldown, halfOpenMax)) +} + +func TestAccountCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 10 * time.Millisecond + halfOpenMax := 2 + + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "OPEN", cb.stateString()) + + time.Sleep(cooldown + 5*time.Millisecond) + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + require.Equal(t, int32(1), cb.halfOpenInFlight.Load()) + + cb.releaseHalfOpenPermit() + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) + + // Idempotent release should not underflow. + cb.releaseHalfOpenPermit() + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestAccountCircuitBreaker_HalfOpenToOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 50 * time.Millisecond + halfOpenMax := 2 + + // Trip the breaker. + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "OPEN", cb.stateString()) + + // Wait for cooldown. + time.Sleep(cooldown + 10*time.Millisecond) + + // Allow a probe — transitions to HALF_OPEN. + require.True(t, cb.allow(cooldown, halfOpenMax)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Failure in HALF_OPEN should trip back to OPEN. + cb.recordFailure(defaultCircuitBreakerFailThreshold) + require.Equal(t, "OPEN", cb.stateString()) + require.True(t, cb.isOpen()) + require.False(t, cb.allow(cooldown, halfOpenMax)) +} + +func TestAccountCircuitBreaker_ResetOnSuccess(t *testing.T) { + cb := &accountCircuitBreaker{} + + // 4 failures followed by a success should reset the counter. 
+ for i := 0; i < 4; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, int32(4), cb.consecutiveFails.Load()) + + cb.recordSuccess() + require.Equal(t, int32(0), cb.consecutiveFails.Load()) + require.Equal(t, "CLOSED", cb.stateString()) + + // 4 more failures — still not tripped because counter was reset. + for i := 0; i < 4; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.Equal(t, "CLOSED", cb.stateString()) + + // 5th consecutive failure trips it. + cb.recordFailure(defaultCircuitBreakerFailThreshold) + require.Equal(t, "OPEN", cb.stateString()) +} + +func TestAccountCircuitBreaker_IntegrationWithScheduler(t *testing.T) { + ctx := context.Background() + groupID := int64(30) + accounts := []Account{ + { + ID: 9001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + { + ID: 9002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + // Enable circuit breaker with low threshold for testing. 
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9001: {AccountID: 9001, LoadRate: 10, WaitingCount: 0}, + 9002: {AccountID: 9002, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler := svc.getOpenAIAccountScheduler() + + // Report 3 consecutive failures for account 9001 — trips the circuit breaker. + for i := 0; i < 3; i++ { + scheduler.ReportResult(9001, false, nil, "", 0) + } + + // Now all selections should avoid account 9001 and pick 9002. + for i := 0; i < 20; i++ { + sessionHash := fmt.Sprintf("cb_integration_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(9002), selection.Account.ID, + "circuit-open account 9001 should be skipped, got %d on iteration %d", selection.Account.ID, i) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // Verify metrics tracked the trip. 
+ snapshot := scheduler.SnapshotMetrics() + require.GreaterOrEqual(t, snapshot.CircuitBreakerOpenTotal, int64(1)) +} + +func TestAccountCircuitBreaker_AllOpenFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(31) + accounts := []Account{ + { + ID: 9101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9101: {AccountID: 9101, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler := svc.getOpenAIAccountScheduler() + + // Trip the only account. + for i := 0; i < 3; i++ { + scheduler.ReportResult(9101, false, nil, "", 0) + } + + // Even though the only account is circuit-open, the scheduler should + // still return it (graceful degradation — never return "no accounts"). 
+ selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "cb_fallback_test", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(9101), selection.Account.ID) +} + +func TestAccountCircuitBreaker_SelectReleasesUnselectedHalfOpenPermit(t *testing.T) { + ctx := context.Background() + groupID := int64(311) + accounts := []Account{ + { + ID: 9111, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + { + ID: 9112, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 1 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9111: {AccountID: 9111, LoadRate: 10, WaitingCount: 0}, + 9112: {AccountID: 9112, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler, ok := svc.getOpenAIAccountScheduler().(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + // Trip both accounts to OPEN so next select will transition 
both to HALF_OPEN. + scheduler.ReportResult(9111, false, nil, "", 0) + scheduler.ReportResult(9112, false, nil, "", 0) + scheduler.stats.getCircuitBreaker(9111).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano()) + scheduler.stats.getCircuitBreaker(9112).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano()) + + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "cb_release_unselected", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + + selectedID := selection.Account.ID + otherID := int64(9111) + if selectedID == otherID { + otherID = 9112 + } + otherCB := scheduler.stats.getCircuitBreaker(otherID) + require.Equal(t, int32(0), otherCB.halfOpenInFlight.Load(), + "unselected HALF_OPEN candidate should release probe permit") + + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestAccountCircuitBreaker_DisabledByConfig(t *testing.T) { + ctx := context.Background() + groupID := int64(32) + accounts := []Account{ + { + ID: 9201, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + { + ID: 9202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + // Circuit breaker explicitly DISABLED. 
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = false + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 9201: {AccountID: 9201, LoadRate: 10, WaitingCount: 0}, + 9202: {AccountID: 9202, LoadRate: 10, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + scheduler := svc.getOpenAIAccountScheduler() + internalScheduler, ok := scheduler.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + // Report many failures — should NOT affect scheduling when disabled. + for i := 0; i < 10; i++ { + scheduler.ReportResult(9201, false, nil, "", 0) + } + + // Both accounts should still be eligible. + selected := map[int64]int{} + for i := 0; i < 40; i++ { + sessionHash := fmt.Sprintf("cb_disabled_%d", i) + selection, _, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + // When disabled, 9201 should still appear as a candidate. 
+ require.Greater(t, selected[int64(9201)]+selected[int64(9202)], 0) + require.Len(t, selected, 2, "both accounts should be selectable when CB is disabled") + cb := internalScheduler.stats.getCircuitBreaker(9201) + require.False(t, cb.isOpen(), "circuit breaker should not transition to OPEN when feature is disabled") + require.Equal(t, int64(0), internalScheduler.metrics.circuitBreakerOpenTotal.Load()) +} + +func TestAccountCircuitBreaker_RecoveryMetrics(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + svc := &OpenAIGatewayService{} + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + // Manually enable CB by setting config on the service. + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 0 // immediate cooldown for test + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1 + scheduler.service.cfg = cfg + + // Trip the breaker: 3 consecutive failures. + for i := 0; i < 3; i++ { + scheduler.ReportResult(5001, false, nil, "", 0) + } + require.Equal(t, int64(1), scheduler.metrics.circuitBreakerOpenTotal.Load()) + + // Let the cooldown expire (0 seconds) and call allow to trigger HALF_OPEN. + cb := stats.getCircuitBreaker(5001) + require.Equal(t, "OPEN", cb.stateString()) + allowed := cb.allow(0, 1) + require.True(t, allowed) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Report success — should transition HALF_OPEN → CLOSED. 
+ scheduler.ReportResult(5001, true, nil, "", 0) + require.Equal(t, "CLOSED", cb.stateString()) + require.Equal(t, int64(1), scheduler.metrics.circuitBreakerRecoverTotal.Load()) +} + +func TestAccountCircuitBreaker_StateString(t *testing.T) { + cb := &accountCircuitBreaker{} + require.Equal(t, "CLOSED", cb.stateString()) + + cb.state.Store(circuitBreakerStateOpen) + require.Equal(t, "OPEN", cb.stateString()) + + cb.state.Store(circuitBreakerStateHalfOpen) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + cb.state.Store(99) + require.Equal(t, "UNKNOWN", cb.stateString()) +} + +func TestAccountCircuitBreaker_GetAndIsCircuitOpen(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + + // isCircuitOpen on a non-existent account should return false. + require.False(t, stats.isCircuitOpen(1234)) + + // getCircuitBreaker should create on first access. + cb := stats.getCircuitBreaker(1234) + require.NotNil(t, cb) + require.Equal(t, "CLOSED", cb.stateString()) + require.False(t, stats.isCircuitOpen(1234)) + + // Trip it and verify isCircuitOpen returns true. + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.True(t, stats.isCircuitOpen(1234)) + + // Second call to getCircuitBreaker should return same instance. 
+ cb2 := stats.getCircuitBreaker(1234) + require.True(t, cb == cb2, "should return same pointer") +} + +func TestAccountCircuitBreaker_ConcurrentAllowAndRecord(t *testing.T) { + cb := &accountCircuitBreaker{} + cooldown := 50 * time.Millisecond + halfOpenMax := 4 + + var wg sync.WaitGroup + const workers = 16 + const iterations = 200 + + wg.Add(workers) + for w := 0; w < workers; w++ { + w := w + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _ = cb.allow(cooldown, halfOpenMax) + if (i+w)%3 == 0 { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } else { + cb.recordSuccess() + } + } + }() + } + wg.Wait() + + // Just verify it doesn't panic or deadlock, and state is valid. + state := cb.state.Load() + require.True(t, state == circuitBreakerStateClosed || + state == circuitBreakerStateOpen || + state == circuitBreakerStateHalfOpen, + "unexpected state: %d", state) +} + +// --------------------------------------------------------------------------- +// P2C (Power-of-Two-Choices) Tests +// --------------------------------------------------------------------------- + +func TestSelectP2COpenAICandidates_BasicSelection(t *testing.T) { + // P2C should return all candidates in some order, and higher-scored + // candidates should tend to appear earlier in the selection order. 
+ candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 1}, score: 0.9}, + {account: &Account{ID: 2, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 2}, score: 0.5}, + {account: &Account{ID: 3, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 3}, score: 0.1}, + {account: &Account{ID: 4, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 4}, score: 0.7}, + {account: &Account{ID: 5, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 5}, score: 0.3}, + } + + req := OpenAIAccountScheduleRequest{ + SessionHash: "p2c_basic_test", + } + + result := selectP2COpenAICandidates(candidates, req) + + // All candidates must be present exactly once. + require.Len(t, result, len(candidates)) + seen := map[int64]bool{} + for _, c := range result { + require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID) + seen[c.account.ID] = true + } + for _, c := range candidates { + require.True(t, seen[c.account.ID], "missing account ID %d", c.account.ID) + } + + // Statistical check: over many runs the highest-scored candidate (ID=1, + // score=0.9) should appear in position 0 more often than the lowest-scored + // candidate (ID=3, score=0.1). 
+ topCount := map[int64]int{} + iterations := 500 + for i := 0; i < iterations; i++ { + iterReq := OpenAIAccountScheduleRequest{ + SessionHash: fmt.Sprintf("p2c_stat_%d", i), + } + order := selectP2COpenAICandidates(candidates, iterReq) + topCount[order[0].account.ID]++ + } + require.Greater(t, topCount[int64(1)], topCount[int64(3)], + "highest-scored candidate should appear first more often than lowest-scored; got best=%d worst=%d", + topCount[int64(1)], topCount[int64(3)]) +} + +func TestSelectP2COpenAICandidates_SingleCandidate(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 42, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 42}, score: 1.0}, + } + req := OpenAIAccountScheduleRequest{SessionHash: "single"} + + result := selectP2COpenAICandidates(candidates, req) + require.Len(t, result, 1) + require.Equal(t, int64(42), result[0].account.ID) +} + +func TestSelectP2COpenAICandidates_DeterministicWithSameSeed(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 10, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 10}, score: 0.8}, + {account: &Account{ID: 20, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 20}, score: 0.6}, + {account: &Account{ID: 30, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 30}, score: 0.4}, + {account: &Account{ID: 40, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 40}, score: 0.2}, + } + // Use a session hash to ensure the seed is deterministic (no time entropy). 
+ req := OpenAIAccountScheduleRequest{ + SessionHash: "deterministic_p2c_seed", + } + + first := selectP2COpenAICandidates(candidates, req) + for i := 0; i < 10; i++ { + again := selectP2COpenAICandidates(candidates, req) + require.Len(t, again, len(first)) + for j := range first { + require.Equal(t, first[j].account.ID, again[j].account.ID, + "iteration %d position %d mismatch", i, j) + } + } +} + +func TestP2CLoadBalanceIntegration(t *testing.T) { + // End-to-end test: enable P2C via config, verify it distributes across + // accounts and that decision.TopK == 0 (P2C mode indicator). + ctx := context.Background() + groupID := int64(50) + accounts := []Account{ + { + ID: 5001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0, + }, + { + ID: 5002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0, + }, + { + ID: 5003, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5001: {AccountID: 5001, LoadRate: 20, WaitingCount: 0}, + 5002: {AccountID: 5002, LoadRate: 30, WaitingCount: 0}, + 5003: {AccountID: 5003, LoadRate: 40, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selected := 
map[int64]int{} + iterations := 100 + for i := 0; i < iterations; i++ { + sessionHash := fmt.Sprintf("p2c_integration_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // P2C mode: TopK should be 0. + require.Equal(t, 0, decision.TopK, "P2C mode should set TopK=0") + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // P2C with 3 candidates: the two better-scored accounts (5001, 5002) + // should be selected, while the weakest (5003) may rarely or never win + // a P2C tournament. Verify at least 2 distinct accounts are picked and + // the lowest-loaded account dominates. + require.GreaterOrEqual(t, len(selected), 2, + "P2C should distribute across at least 2 accounts; got %v", selected) + require.Greater(t, selected[int64(5001)], 0, + "lowest-loaded account 5001 should be selected at least once") + require.Greater(t, selected[int64(5001)], selected[int64(5003)], + "lowest-loaded account should be favored over highest-loaded; got 5001=%d 5003=%d", + selected[int64(5001)], selected[int64(5003)]) +} + +func TestP2CFallbackToTopK(t *testing.T) { + // When P2C is disabled (default), the Top-K path should be used. + // Verify topK > 0 in decision. + ctx := context.Background() + groupID := int64(51) + accounts := []Account{ + { + ID: 5101, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0, + }, + { + ID: 5102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, + Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0, + }, + } + + cfg := &config.Config{} + // Explicitly disable P2C (or leave at default false). 
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 10, WaitingCount: 0}, + 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "topk_fallback_test", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // Top-K mode: TopK should be > 0. + require.Greater(t, decision.TopK, 0, "Top-K mode should set TopK > 0") + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + + // Also verify P2C helper returns false when disabled. + require.False(t, svc.openAIWSSchedulerP2CEnabled()) +} + +// --------------------------------------------------------------------------- +// Conditional Sticky Session Release Tests +// --------------------------------------------------------------------------- + +// buildConditionalStickyTestService creates a minimal OpenAIGatewayService and +// scheduler with injectable runtime stats for conditional sticky tests. 
+func buildConditionalStickyTestService( + accounts []Account, + stickyKey string, + stickyAccountID int64, + stickyReleaseEnabled bool, + cbEnabled bool, +) (*OpenAIGatewayService, *defaultOpenAIAccountScheduler) { + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + stickyKey: stickyAccountID, + }, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = stickyReleaseEnabled + // Leave StickyReleaseErrorThreshold at 0 to use the default (0.3). + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = cbEnabled + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3 + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 5 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + stats := newOpenAIAccountRuntimeStats() + scheduler := &defaultOpenAIAccountScheduler{ + service: svc, + stats: stats, + } + // Wire the scheduler into the service so that SelectAccountWithScheduler + // uses it via getOpenAIAccountScheduler. 
+ svc.openaiScheduler = scheduler + svc.openaiAccountStats = stats + return svc, scheduler +} + +func TestConditionalSticky_ReleaseOnHighErrorRate(t *testing.T) { + ctx := context.Background() + groupID := int64(30001) + stickyAccount := Account{ + ID: 5001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + fallbackAccount := Account{ + ID: 5002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_err_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + false, // cbEnabled (not needed for error rate test) + ) + + // Pump the error rate above the 0.3 default threshold. + // With alpha=0.5 (fast EWMA), after ~5 consecutive failures the rate + // converges well above 0.3. + for i := 0; i < 10; i++ { + scheduler.stats.report(stickyAccount.ID, false, nil, "", 0) + } + errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID) + require.Greater(t, errRate, 0.3, "error rate should exceed threshold before test") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_err_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + // The sticky account should have been released; the scheduler should + // have fallen through to load balance and selected one of the accounts. 
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer, + "should fall through to load balance after sticky release") + require.False(t, decision.StickySessionHit, "sticky hit should be false") +} + +func TestConditionalSticky_ReleaseOnCircuitOpen(t *testing.T) { + ctx := context.Background() + groupID := int64(30002) + stickyAccount := Account{ + ID: 5011, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + fallbackAccount := Account{ + ID: 5012, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_cb_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + true, // cbEnabled + ) + + // Trip the circuit breaker by reporting consecutive failures beyond + // the configured threshold (3). + for i := 0; i < 5; i++ { + scheduler.ReportResult(stickyAccount.ID, false, nil, "", 0) + } + require.True(t, scheduler.stats.isCircuitOpen(stickyAccount.ID), + "circuit breaker should be OPEN before test") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_cb_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer, + "should fall through to load balance after sticky release due to CB open") + require.False(t, decision.StickySessionHit) +} + +func TestConditionalSticky_KeepsHealthySticky(t *testing.T) { + ctx := context.Background() + groupID := int64(30003) + stickyAccount := Account{ + ID: 5021, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := 
buildConditionalStickyTestService( + []Account{stickyAccount}, + fmt.Sprintf("openai:sticky_ok_%d", groupID), + stickyAccount.ID, + true, // stickyReleaseEnabled + false, // cbEnabled + ) + + // Report some successes so error rate stays at 0. + for i := 0; i < 5; i++ { + scheduler.stats.report(stickyAccount.ID, true, nil, "", 0) + } + errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID) + require.Less(t, errRate, 0.3, "error rate should be below threshold") + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_ok_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, stickyAccount.ID, selection.Account.ID, + "healthy sticky account should be kept") + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestConditionalSticky_DisabledByConfig(t *testing.T) { + ctx := context.Background() + groupID := int64(30004) + stickyAccount := Account{ + ID: 5031, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + + svc, scheduler := buildConditionalStickyTestService( + []Account{stickyAccount}, + fmt.Sprintf("openai:sticky_off_%d", groupID), + stickyAccount.ID, + false, // stickyReleaseEnabled = OFF + false, // cbEnabled + ) + + // Pump error rate very high, but since sticky release is disabled, + // the sticky binding should still hold. 
+	for i := 0; i < 10; i++ {
+		scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+	}
+	errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+	require.Greater(t, errRate, 0.3, "error rate should exceed threshold")
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "",
+		fmt.Sprintf("sticky_off_%d", groupID),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, stickyAccount.ID, selection.Account.ID,
+		"sticky should be kept when feature is disabled")
+	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+	require.True(t, decision.StickySessionHit)
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+}
+
+func TestConditionalSticky_Metrics(t *testing.T) {
+	groupID := int64(30005)
+	ctx := context.Background()
+
+	// --- Error rate release metric ---
+	stickyAccount := Account{
+		ID:          5041,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+	fallbackAccount := Account{
+		ID:          5042,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+
+	svc, scheduler := buildConditionalStickyTestService(
+		[]Account{stickyAccount, fallbackAccount},
+		fmt.Sprintf("openai:sticky_m1_%d", groupID),
+		stickyAccount.ID,
+		true, // stickyReleaseEnabled
+		true, // cbEnabled
+	)
+
+	// Trigger error-rate release. CB is also enabled with threshold=3, but
+	// stats.report trips the breaker at the default threshold (5) rather
+	// than the configured value, so send 10 failures to ensure both the
+	// error-rate threshold and the circuit breaker are triggered.
+ for i := 0; i < 10; i++ { + scheduler.stats.report(stickyAccount.ID, false, nil, "", 0) + } + + _, _, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", + fmt.Sprintf("sticky_m1_%d", groupID), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + + snap := scheduler.SnapshotMetrics() + // At least one of the two release metrics should have been incremented. + totalReleases := snap.StickyReleaseErrorTotal + snap.StickyReleaseCircuitOpenTotal + require.Greater(t, totalReleases, int64(0), + "at least one sticky release metric should be incremented") + + // --- Circuit breaker release metric (clean setup) --- + groupID2 := int64(30006) + svc2, scheduler2 := buildConditionalStickyTestService( + []Account{stickyAccount, fallbackAccount}, + fmt.Sprintf("openai:sticky_m2_%d", groupID2), + stickyAccount.ID, + true, // stickyReleaseEnabled + true, // cbEnabled + ) + + // Trip CB via ReportResult (which checks the configured threshold=3). + for i := 0; i < 5; i++ { + scheduler2.ReportResult(stickyAccount.ID, false, nil, "", 0) + } + require.True(t, scheduler2.stats.isCircuitOpen(stickyAccount.ID)) + + _, _, err = svc2.SelectAccountWithScheduler( + ctx, &groupID2, "", + fmt.Sprintf("sticky_m2_%d", groupID2), + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + + snap2 := scheduler2.SnapshotMetrics() + require.Greater(t, snap2.StickyReleaseCircuitOpenTotal, int64(0), + "circuit-open sticky release metric should be incremented") +} + +// --------------------------------------------------------------------------- +// Softmax Temperature Sampling Tests +// --------------------------------------------------------------------------- + +// makeSoftmaxCandidates builds N candidates with the given scores. 
+func makeSoftmaxCandidates(scores ...float64) []openAIAccountCandidateScore { + out := make([]openAIAccountCandidateScore, len(scores)) + for i, s := range scores { + out[i] = openAIAccountCandidateScore{ + account: &Account{ID: int64(i + 1), Priority: 0}, + loadInfo: &AccountLoadInfo{AccountID: int64(i + 1)}, + score: s, + } + } + return out +} + +func TestSoftmax_LowTemperatureApproximatesArgmax(t *testing.T) { + // With a very low temperature the highest-scored candidate should win + // almost every time. + candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5) + + winCount := 0 + trials := 100 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 0.01, &rng) + require.Len(t, result, len(candidates)) + if result[0].account.ID == 1 { // ID 1 has score 5.0 (the highest) + winCount++ + } + } + + require.Greater(t, winCount, 90, + "with temperature=0.01 the highest-scored candidate should win >90%% of trials; got %d/%d", winCount, trials) +} + +func TestSoftmax_HighTemperatureApproximatesUniform(t *testing.T) { + // With a very high temperature, all candidates should get roughly equal + // selection frequency. + candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 100.0, &rng) + require.Len(t, result, len(candidates)) + counts[result[0].account.ID]++ + } + + expected := float64(trials) / float64(len(candidates)) // 250 + for id, count := range counts { + require.InDelta(t, expected, float64(count), float64(trials)*0.10, + "candidate ID=%d expected ~%.0f selections, got %d", id, expected, count) + } +} + +func TestSoftmax_DefaultTemperature(t *testing.T) { + // With the default temperature (0.3), higher-scored candidates should be + // picked more often than lower-scored ones. 
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, defaultSoftmaxTemperature, &rng) + counts[result[0].account.ID]++ + } + + // The candidate with the highest score (ID=1, score=5.0) should be + // selected more often than the candidate with the lowest (ID=4, score=0.5). + require.Greater(t, counts[int64(1)], counts[int64(4)], + "highest-scored candidate should be picked more often; best=%d worst=%d", + counts[int64(1)], counts[int64(4)]) + + // Also check that the top-scored candidate beats the second-highest. + require.Greater(t, counts[int64(1)], counts[int64(2)], + "score=5.0 should beat score=3.0; got %d vs %d", + counts[int64(1)], counts[int64(2)]) +} + +func TestSoftmax_SingleCandidate(t *testing.T) { + candidates := makeSoftmaxCandidates(7.5) + + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng) + + require.Len(t, result, 1) + require.Equal(t, int64(1), result[0].account.ID) + require.Equal(t, 7.5, result[0].score) +} + +func TestSoftmax_TwoCandidates(t *testing.T) { + // Use a moderate score gap (1.0 vs 0.5) with temperature=1.0 so both + // candidates have meaningful selection probability. + candidates := makeSoftmaxCandidates(1.0, 0.5) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 1.0, &rng) + require.Len(t, result, 2) + counts[result[0].account.ID]++ + } + + // Both candidates should be selected at least once (proving no + // single-candidate monopoly), and the higher-scored one should dominate. 
+ require.Greater(t, counts[int64(1)], 0, "high-scored candidate must be selected at least once") + require.Greater(t, counts[int64(2)], 0, "low-scored candidate must be selected at least once") + require.Greater(t, counts[int64(1)], counts[int64(2)], + "higher-scored candidate should be picked more often; got %d vs %d", + counts[int64(1)], counts[int64(2)]) +} + +func TestSoftmax_EqualScores(t *testing.T) { + // When all scores are equal, selection should be approximately uniform. + candidates := makeSoftmaxCandidates(3.0, 3.0, 3.0, 3.0) + + counts := map[int64]int{} + trials := 1000 + for i := 0; i < trials; i++ { + rng := newOpenAISelectionRNG(uint64(i + 1)) + result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng) + counts[result[0].account.ID]++ + } + + expected := float64(trials) / float64(len(candidates)) // 250 + for id, count := range counts { + require.InDelta(t, expected, float64(count), float64(trials)*0.10, + "equal scores should yield ~uniform distribution; ID=%d expected ~%.0f got %d", + id, expected, count) + } +} + +func TestSoftmax_NumericalStability(t *testing.T) { + // Large score differences should not cause overflow or NaN. + candidates := makeSoftmaxCandidates(100.0, -100.0, 50.0, -50.0) + + rng := newOpenAISelectionRNG(12345) + result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng) + + require.Len(t, result, len(candidates)) + // Verify all scores are finite in the output (no NaN or Inf propagation). + for _, c := range result { + require.False(t, math.IsNaN(c.score), "score should not be NaN") + require.False(t, math.IsInf(c.score, 0), "score should not be Inf") + } + // All candidates must appear exactly once. 
+ seen := map[int64]bool{} + for _, c := range result { + require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID) + seen[c.account.ID] = true + } + require.Len(t, seen, len(candidates)) + + // With such extreme differences at temperature=0.3, the highest scorer (100.0) + // should always win because exp((100 - 100)/0.3) = 1 while + // exp((-100 - 100)/0.3) ~= 0 (numerically stable via maxScore subtraction). + winCount := 0 + for i := 0; i < 100; i++ { + rng2 := newOpenAISelectionRNG(uint64(i + 1)) + r := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng2) + if r[0].account.ID == 1 { // score 100.0 + winCount++ + } + } + require.Greater(t, winCount, 95, + "with extreme score gap, the highest scorer should win nearly always; got %d/100", winCount) +} + +func TestSoftmax_DisabledFallsThrough(t *testing.T) { + // When softmax is disabled, the scheduler should fall through to P2C or Top-K. + ctx := context.Background() + groupID := int64(40001) + accounts := make([]Account, 5) + for i := range accounts { + accounts[i] = Account{ + ID: int64(6001 + i), + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + } + + cache := &stubGatewayCache{} + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + // Softmax explicitly disabled (default). + cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = false + // P2C also disabled. 
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", "softmax_disabled_test", + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // TopK should be > 0, confirming Top-K path was taken. + require.Greater(t, decision.TopK, 0, "should fall through to Top-K when softmax is disabled") + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestSoftmax_FewCandidatesFallsThrough(t *testing.T) { + // When softmax is enabled but there are <= 3 candidates, it should fall + // through to the next strategy (P2C or Top-K). + ctx := context.Background() + groupID := int64(40002) + // Only 3 accounts — softmax guard requires >3. + accounts := make([]Account, 3) + for i := range accounts { + accounts[i] = Account{ + ID: int64(7001 + i), + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 5, + } + } + + cache := &stubGatewayCache{} + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + // Softmax enabled but should not activate with only 3 candidates. + cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true + cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.5 + // P2C disabled, so it should fall through to Top-K. 
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, &groupID, "", "softmax_few_candidates_test", + "gpt-5.1", nil, OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + // TopK should be > 0, confirming Top-K path was taken instead of softmax. + require.Greater(t, decision.TopK, 0, "should fall through to Top-K when candidate count <= 3") + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestSoftmax_ConfigDefaults(t *testing.T) { + // When config values are zero/unset, defaults should be applied. + + // Case 1: nil service — returns empty config. + nilScheduler := &defaultOpenAIAccountScheduler{} + cfg0 := nilScheduler.softmaxConfigRead() + require.False(t, cfg0.enabled) + require.Equal(t, 0.0, cfg0.temperature) // no default when service is nil + + // Case 2: zero temperature (unset) — should default to 0.3. + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + } + svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true + svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0 // unset + scheduler := &defaultOpenAIAccountScheduler{ + service: svc, + stats: newOpenAIAccountRuntimeStats(), + } + cfg1 := scheduler.softmaxConfigRead() + require.True(t, cfg1.enabled) + require.Equal(t, 0.3, cfg1.temperature, "default temperature should be 0.3") + + // Case 3: explicit temperature — should use the configured value. 
+ svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.7 + cfg2 := scheduler.softmaxConfigRead() + require.True(t, cfg2.enabled) + require.Equal(t, 0.7, cfg2.temperature, "should use explicitly configured temperature") + + // Case 4: negative temperature — should fall back to default 0.3. + svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = -1.0 + cfg3 := scheduler.softmaxConfigRead() + require.Equal(t, 0.3, cfg3.temperature, "negative temperature should fall back to default 0.3") +} + +// --------------------------------------------------------------------------- +// Per-Model TTFT Tests +// --------------------------------------------------------------------------- + +func TestPerModelTTFT_IndependentTracking(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8001) + + // Report different TTFT values for two different models on the same account. + stats.report(accountID, true, nil, "gpt-4o", 100) + stats.report(accountID, true, nil, "gpt-4o", 120) + stats.report(accountID, true, nil, "o3-pro", 500) + stats.report(accountID, true, nil, "o3-pro", 600) + + // Snapshot for model gpt-4o. + _, ttftGPT4o, hasTTFT := stats.snapshot(accountID, "gpt-4o") + require.True(t, hasTTFT, "gpt-4o should have TTFT data") + + // Snapshot for model o3-pro. + _, ttftO3Pro, hasO3Pro := stats.snapshot(accountID, "o3-pro") + require.True(t, hasO3Pro, "o3-pro should have TTFT data") + + // The two models should have different TTFT values because their + // sample inputs are very different (100-120 vs 500-600). + require.Greater(t, math.Abs(ttftGPT4o-ttftO3Pro), 50.0, + "different models should track independent TTFT values") + + // gpt-4o TTFT should be much lower than o3-pro. + require.Less(t, ttftGPT4o, ttftO3Pro, + "gpt-4o should have lower TTFT than o3-pro") +} + +func TestPerModelTTFT_FallbackToGlobal(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8002) + + // Report TTFT for a specific model. 
+ stats.report(accountID, true, nil, "gpt-4o", 200) + + // Snapshot with an unknown model should fall back to global TTFT. + _, ttftUnknown, hasUnknown := stats.snapshot(accountID, "unknown-model") + require.True(t, hasUnknown, "should fall back to global TTFT") + + // Global TTFT should have been updated by the gpt-4o report. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal, "global TTFT should exist") + + // The unknown-model fallback should equal the global. + require.InDelta(t, ttftGlobal, ttftUnknown, 1e-9, + "unknown model should return global TTFT as fallback") +} + +func TestPerModelTTFT_GlobalAlsoUpdated(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8003) + + // No data initially. + _, _, hasGlobal := stats.snapshot(accountID) + require.False(t, hasGlobal, "no global TTFT initially") + + // Report with model — should also update global. + stats.report(accountID, true, nil, "gpt-4o", 300) + + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal, "global TTFT should exist after model report") + require.InDelta(t, 300.0, ttftGlobal, 1e-9, + "global TTFT should be updated by model report") +} + +func TestPerModelTTFT_TTLCleanup(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + + // Manually insert a model entry with an old timestamp. + d := &dualEWMATTFT{} + d.initNaN() + d.update(100) + stat.modelTTFT.Store("old-model", d) + stat.modelTTFTLastUpdate.Store("old-model", time.Now().Add(-time.Hour).UnixNano()) + + // Insert a recent model entry. + d2 := &dualEWMATTFT{} + d2.initNaN() + d2.update(200) + stat.modelTTFT.Store("new-model", d2) + stat.modelTTFTLastUpdate.Store("new-model", time.Now().UnixNano()) + + // Cleanup with 30-minute TTL — old-model should be removed. 
+ stat.cleanupStaleTTFT(30*time.Minute, 100) + + _, hasOld := stat.modelTTFTValue("old-model") + require.False(t, hasOld, "old-model should be cleaned up") + + _, hasNew := stat.modelTTFTValue("new-model") + require.True(t, hasNew, "new-model should survive cleanup") +} + +func TestPerModelTTFT_MaxModelLimit(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + + now := time.Now() + // Insert 10 models with sequential timestamps. + for i := 0; i < 10; i++ { + model := fmt.Sprintf("model-%d", i) + d := &dualEWMATTFT{} + d.initNaN() + d.update(float64(100 + i*10)) + stat.modelTTFT.Store(model, d) + stat.modelTTFTLastUpdate.Store(model, now.Add(time.Duration(i)*time.Second).UnixNano()) + } + + // Enforce limit of 5 models — the 5 oldest should be evicted. + stat.cleanupStaleTTFT(time.Hour, 5) + + // Count remaining models. + remaining := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + remaining++ + return true + }) + require.Equal(t, 5, remaining, "should have exactly 5 models after cleanup") + + // The newest 5 (model-5 through model-9) should survive. + for i := 5; i < 10; i++ { + model := fmt.Sprintf("model-%d", i) + _, has := stat.modelTTFTValue(model) + require.True(t, has, "%s should survive", model) + } + // The oldest 5 (model-0 through model-4) should be evicted. + for i := 0; i < 5; i++ { + model := fmt.Sprintf("model-%d", i) + _, has := stat.modelTTFTValue(model) + require.False(t, has, "%s should be evicted", model) + } +} + +func TestPerModelTTFT_SnapshotUsesModelData(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8006) + + // Report two models with very different TTFT. + for i := 0; i < 5; i++ { + stats.report(accountID, true, nil, "fast-model", 50) + stats.report(accountID, true, nil, "slow-model", 500) + } + + // Snapshot with specific model should return that model's TTFT. 
+ _, ttftFast, hasFast := stats.snapshot(accountID, "fast-model") + require.True(t, hasFast) + + _, ttftSlow, hasSlow := stats.snapshot(accountID, "slow-model") + require.True(t, hasSlow) + + // Fast model should have much lower TTFT. + require.Less(t, ttftFast, 100.0, "fast-model TTFT should be close to 50") + require.Greater(t, ttftSlow, 400.0, "slow-model TTFT should be close to 500") + + // Global TTFT should be a blend of both. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal) + require.Greater(t, ttftGlobal, ttftFast, "global TTFT should be higher than fast-model") + require.Less(t, ttftGlobal, ttftSlow, "global TTFT should be lower than slow-model") +} + +func TestPerModelTTFT_ConcurrentAccess(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8007) + + const workers = 8 + const iterations = 200 + models := []string{"model-a", "model-b", "model-c", "model-d"} + + var wg sync.WaitGroup + wg.Add(workers) + for w := 0; w < workers; w++ { + w := w + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + model := models[(w+i)%len(models)] + ttft := float64(100 + (w*10+i)%200) + stats.report(accountID, true, nil, model, ttft) + + // Also read concurrently. + stats.snapshot(accountID, model) + stats.snapshot(accountID) + } + }() + } + wg.Wait() + + // All models should have TTFT data. + for _, model := range models { + _, ttft, has := stats.snapshot(accountID, model) + require.True(t, has, "%s should have TTFT", model) + require.Greater(t, ttft, 0.0, "%s TTFT should be positive", model) + } + + // Global should also have data. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal) + require.Greater(t, ttftGlobal, 0.0) +} + +func TestPerModelTTFT_EmptyModel(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + accountID := int64(8008) + + // Report with empty model — should only update global TTFT. 
+ ttft := 150 + stats.report(accountID, true, &ttft, "", 0) + + // Global should have data. + _, ttftGlobal, hasGlobal := stats.snapshot(accountID) + require.True(t, hasGlobal, "global TTFT should exist from firstTokenMs") + require.InDelta(t, 150.0, ttftGlobal, 1e-9) + + // Snapshot with empty model returns global. + _, ttftEmpty, hasEmpty := stats.snapshot(accountID, "") + require.True(t, hasEmpty) + require.InDelta(t, ttftGlobal, ttftEmpty, 1e-9, + "empty model snapshot should return global TTFT") + + // No per-model entries should exist. + stat := stats.loadOrCreate(accountID) + count := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + count++ + return true + }) + require.Equal(t, 0, count, "no per-model entries should exist for empty model") +} + +// --------------------------------------------------------------------------- +// Load Trend Prediction Tests +// --------------------------------------------------------------------------- + +func TestLoadTrend_RisingLoad(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.Greater(t, slope, 0.0, "rising load should produce positive slope") + require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second") +} + +func TestLoadTrend_FallingLoad(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(float64(100-i*10), base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.Less(t, slope, 0.0, "falling load should produce negative slope") + require.InDelta(t, -10.0, slope, 0.01, "slope should be ~-10 per second") +} + +func TestLoadTrend_StableLoad(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(50.0, base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + 
require.InDelta(t, 0.0, slope, 1e-9, "constant load should produce zero slope") +} + +func TestLoadTrend_RingBufferFull(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 5; i++ { + tracker.recordAt(100.0, base+int64(i)*int64(time.Second)) + } + for i := 0; i < 10; i++ { + tracker.recordAt(float64((i+1)*10), base+int64(5+i)*int64(time.Second)) + } + slope := tracker.slope() + require.Greater(t, slope, 0.0, "should reflect rising trend from last 10 samples") + require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second after ring wraps") +} + +func TestLoadTrend_InsufficientSamples(t *testing.T) { + var tracker loadTrendTracker + slope := tracker.slope() + require.Equal(t, 0.0, slope, "zero samples should return slope 0") +} + +func TestLoadTrend_SingleSample(t *testing.T) { + var tracker loadTrendTracker + tracker.recordAt(42.0, time.Now().UnixNano()) + slope := tracker.slope() + require.Equal(t, 0.0, slope, "single sample should return slope 0") +} + +func TestLoadTrend_TwoSamples(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + tracker.recordAt(10.0, base) + tracker.recordAt(30.0, base+int64(2*time.Second)) + slope := tracker.slope() + require.InDelta(t, 10.0, slope, 0.01, "two-sample slope should be exact delta/time") +} + +func TestLoadTrend_AllSameTimestamp(t *testing.T) { + var tracker loadTrendTracker + ts := time.Now().UnixNano() + for i := 0; i < 5; i++ { + tracker.recordAt(float64(i*10), ts) + } + slope := tracker.slope() + require.Equal(t, 0.0, slope, "all-same-timestamp should return slope 0 (degenerate)") +} + +func TestLoadTrend_NegativeSlope(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 5; i++ { + tracker.recordAt(50.0-float64(i)*5.0, base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.InDelta(t, -5.0, slope, 0.01) +} + +func TestLoadTrend_ScoringIntegration(t *testing.T) { + accounts := 
[]Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + loadMap := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 50}, + 2: {AccountID: 2, LoadRate: 50}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{ + loadMap: loadMap, + skipDefaultLoad: true, + }), + } + + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + base := time.Now().UnixNano() - int64(10*time.Second) + stat1 := stats.loadOrCreate(1) + for i := 0; i < 9; i++ { + stat1.loadTrend.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second)) + } + stat2 := stats.loadOrCreate(2) + for i := 0; i < 9; i++ { + stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second)) + } + + ctx := context.Background() + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + require.NotNil(t, selection) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + + slope1 := stat1.loadTrend.slope() + slope2 := stat2.loadTrend.slope() + require.Greater(t, slope1, slope2, + "rising-trend account should have 
higher slope than stable account; slope1=%f slope2=%f", slope1, slope2) + require.Greater(t, slope1, 0.0, "rising-trend account slope should be positive") + require.InDelta(t, 0.0, slope2, 1.0, + "stable-trend account slope should be near zero; got %f", slope2) + + trendAdj1 := 1.0 - clamp01(slope1/5.0) + trendAdj2 := 1.0 - clamp01(slope2/5.0) + require.Less(t, trendAdj1, trendAdj2, + "rising-trend trendAdj should be less than stable trendAdj; adj1=%f adj2=%f", trendAdj1, trendAdj2) +} + +func TestLoadTrend_DisabledByDefault(t *testing.T) { + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + base := time.Now().UnixNano() + stat1 := stats.loadOrCreate(1) + for i := 0; i < 10; i++ { + stat1.loadTrend.recordAt(float64(i*10), base+int64(i)*int64(time.Second)) + } + stat2 := stats.loadOrCreate(2) + for i := 0; i < 10; i++ { + stat2.loadTrend.recordAt(30.0, base+int64(i)*int64(time.Second)) + } + + ctx := context.Background() + selectedCounts := map[int64]int{} + const rounds = 100 + for r := 0; r < rounds; r++ { + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, 
OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + selectedCounts[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + require.Greater(t, selectedCounts[int64(1)], 10, + "with trend disabled, account 1 should still get selections; got %d", selectedCounts[int64(1)]) + require.Greater(t, selectedCounts[int64(2)], 10, + "with trend disabled, account 2 should still get selections; got %d", selectedCounts[int64(2)]) +} + +func TestLoadTrend_ConcurrentAccess(t *testing.T) { + var tracker loadTrendTracker + var wg sync.WaitGroup + const goroutines = 10 + const recordsPerGoroutine = 100 + + for g := 0; g < goroutines; g++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for i := 0; i < recordsPerGoroutine; i++ { + tracker.record(float64(id*100 + i)) + } + }(g) + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < goroutines*recordsPerGoroutine; i++ { + _ = tracker.slope() + } + }() + + wg.Wait() + + slope := tracker.slope() + require.False(t, math.IsNaN(slope), "slope should be finite after concurrent access") + require.False(t, math.IsInf(slope, 0), "slope should be finite after concurrent access") +} + +func TestLoadTrend_TrendConfigDefaults(t *testing.T) { + svc := &OpenAIGatewayService{} + enabled, maxSlope := svc.openAIWSSchedulerTrendConfig() + require.False(t, enabled, "trend should be disabled by default") + require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope, "maxSlope should use default") +} + +func TestLoadTrend_TrendConfigCustom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 8.0 + svc := &OpenAIGatewayService{cfg: cfg} + enabled, maxSlope := svc.openAIWSSchedulerTrendConfig() + require.True(t, enabled) + require.Equal(t, 8.0, maxSlope) +} + +func 
TestLoadTrend_TrendConfigZeroMaxSlope(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 0 + svc := &OpenAIGatewayService{cfg: cfg} + enabled, maxSlope := svc.openAIWSSchedulerTrendConfig() + require.True(t, enabled) + require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope) +} + +func TestLoadTrend_RingBufferCountTracking(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + + for i := 0; i < loadTrendRingSize; i++ { + tracker.recordAt(float64(i), base+int64(i)*int64(time.Second)) + } + require.Equal(t, loadTrendRingSize, tracker.count, "count should equal ring size after filling") + + tracker.recordAt(99.0, base+int64(loadTrendRingSize)*int64(time.Second)) + require.Equal(t, loadTrendRingSize, tracker.count, "count should remain capped at ring size") +} + +func TestLoadTrend_GentleRise(t *testing.T) { + var tracker loadTrendTracker + base := time.Now().UnixNano() + for i := 0; i < 10; i++ { + tracker.recordAt(50.0+float64(i)*0.1, base+int64(i)*int64(time.Second)) + } + slope := tracker.slope() + require.Greater(t, slope, 0.0, "gentle rise should produce positive slope") + require.InDelta(t, 0.1, slope, 0.01) +} + +func TestLoadTrend_RecordUpdatesRuntimeStat(t *testing.T) { + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: 
NewConcurrencyService(stubConcurrencyCache{}), + } + + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + ctx := context.Background() + for i := 0; i < 5; i++ { + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + if selection != nil && selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + stat := stats.loadOrCreate(1) + require.GreaterOrEqual(t, stat.loadTrend.count, 5, + "trend tracker should have received samples from scoring loop") +} + +func TestLoadTrend_FallingTrendBoostsScore(t *testing.T) { + accounts := []Account{ + {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10}, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true + cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0 + + loadMap := map[int64]*AccountLoadInfo{ + 1: {AccountID: 1, LoadRate: 50}, + 2: {AccountID: 2, LoadRate: 50}, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{ + loadMap: loadMap, + skipDefaultLoad: true, + }), + } + + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + 
require.True(t, ok) + + base := time.Now().UnixNano() + stat1 := stats.loadOrCreate(1) + for i := 0; i < 10; i++ { + stat1.loadTrend.recordAt(80.0-float64(i)*5.0, base+int64(i)*int64(time.Second)) + } + stat2 := stats.loadOrCreate(2) + for i := 0; i < 10; i++ { + stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second)) + } + + ctx := context.Background() + selectedCounts := map[int64]int{} + const rounds = 50 + for r := 0; r < rounds; r++ { + selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{ + RequiredTransport: OpenAIUpstreamTransportAny, + }) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + selectedCounts[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + require.Greater(t, selectedCounts[int64(1)]+selectedCounts[int64(2)], 0, + "both accounts should receive selections") +} + +// --------------------------------------------------------------------------- +// Circuit Breaker Coverage Tests +// --------------------------------------------------------------------------- + +func TestCircuitBreaker_AllowClosed(t *testing.T) { + cb := &accountCircuitBreaker{} + require.True(t, cb.allow(30*time.Second, 2), "CLOSED state should allow requests") +} + +func TestCircuitBreaker_AllowOpenWithinCooldown(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateOpen) + cb.lastFailureNano.Store(time.Now().UnixNano()) + require.False(t, cb.allow(30*time.Second, 2), "OPEN within cooldown should deny") +} + +func TestCircuitBreaker_AllowOpenAfterCooldown(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateOpen) + cb.lastFailureNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + require.True(t, cb.allow(30*time.Second, 2), "OPEN after cooldown should transition to HALF_OPEN and allow") + require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load()) +} + +func 
TestCircuitBreaker_AllowHalfOpenLimited(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + require.True(t, cb.allowHalfOpen(2)) + require.True(t, cb.allowHalfOpen(2)) + require.False(t, cb.allowHalfOpen(2), "should deny when in-flight reaches max") +} + +func TestCircuitBreaker_AllowHalfOpenViaAllow(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + require.True(t, cb.allow(30*time.Second, 1)) + require.False(t, cb.allow(30*time.Second, 1), "HALF_OPEN with max=1 should deny second") +} + +func TestCircuitBreaker_AllowDefaultState(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(99) // unknown state + require.True(t, cb.allow(30*time.Second, 2), "unknown state should default to allow") +} + +func TestCircuitBreaker_RecordSuccessClosed(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.consecutiveFails.Store(3) + cb.recordSuccess() + require.Equal(t, int32(0), cb.consecutiveFails.Load()) + require.Equal(t, circuitBreakerStateClosed, cb.state.Load()) +} + +func TestCircuitBreaker_RecordSuccessHalfOpenToClose(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(1) + cb.halfOpenAdmitted.Store(1) + cb.halfOpenSuccess.Store(0) + cb.recordSuccess() + require.Equal(t, circuitBreakerStateClosed, cb.state.Load(), "all probes succeeded should close") + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestCircuitBreaker_RecordSuccessHalfOpenPartial(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(3) + cb.halfOpenAdmitted.Store(3) + cb.halfOpenSuccess.Store(0) + cb.recordSuccess() + require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load(), "not all probes succeeded yet") +} + +func TestCircuitBreaker_RecordFailureTripsOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + for i := 0; i < 4; i++ { + 
cb.recordFailure(5) + } + require.Equal(t, circuitBreakerStateClosed, cb.state.Load()) + cb.recordFailure(5) + require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "5th failure should trip to OPEN") +} + +func TestCircuitBreaker_RecordFailureHalfOpenReverts(t *testing.T) { + cb := &accountCircuitBreaker{} + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(2) + cb.recordFailure(5) + require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "failure in HALF_OPEN should revert to OPEN") + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestCircuitBreaker_IsHalfOpen(t *testing.T) { + var cb *accountCircuitBreaker + require.False(t, cb.isHalfOpen(), "nil should return false") + + cb = &accountCircuitBreaker{} + require.False(t, cb.isHalfOpen()) + cb.state.Store(circuitBreakerStateHalfOpen) + require.True(t, cb.isHalfOpen()) +} + +func TestCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) { + var cb *accountCircuitBreaker + cb.releaseHalfOpenPermit() // should not panic + + cb = &accountCircuitBreaker{} + cb.releaseHalfOpenPermit() // not in HALF_OPEN, should be no-op + + cb.state.Store(circuitBreakerStateHalfOpen) + cb.halfOpenInFlight.Store(2) + cb.releaseHalfOpenPermit() + require.Equal(t, int32(1), cb.halfOpenInFlight.Load()) + + cb.halfOpenInFlight.Store(0) + cb.releaseHalfOpenPermit() // already at 0, should be no-op + require.Equal(t, int32(0), cb.halfOpenInFlight.Load()) +} + +func TestCircuitBreaker_StateString(t *testing.T) { + cb := &accountCircuitBreaker{} + require.Equal(t, "CLOSED", cb.stateString()) + cb.state.Store(circuitBreakerStateOpen) + require.Equal(t, "OPEN", cb.stateString()) + cb.state.Store(circuitBreakerStateHalfOpen) + require.Equal(t, "HALF_OPEN", cb.stateString()) + cb.state.Store(99) + require.Equal(t, "UNKNOWN", cb.stateString()) +} + +func TestCircuitBreaker_IsOpen(t *testing.T) { + cb := &accountCircuitBreaker{} + require.False(t, cb.isOpen()) + cb.state.Store(circuitBreakerStateOpen) + 
require.True(t, cb.isOpen()) +} + +func TestCircuitBreaker_FullLifecycle(t *testing.T) { + cb := &accountCircuitBreaker{} + threshold := 3 + cooldown := 50 * time.Millisecond + + // CLOSED: allow requests + require.True(t, cb.allow(cooldown, 2)) + require.Equal(t, "CLOSED", cb.stateString()) + + // Trip to OPEN + for i := 0; i < threshold; i++ { + cb.recordFailure(threshold) + } + require.Equal(t, "OPEN", cb.stateString()) + require.False(t, cb.allow(cooldown, 2), "should deny in OPEN within cooldown") + + // Wait for cooldown + time.Sleep(cooldown + 10*time.Millisecond) + + // Should transition to HALF_OPEN + require.True(t, cb.allow(cooldown, 2)) + require.Equal(t, "HALF_OPEN", cb.stateString()) + + // Success should close + cb.recordSuccess() + require.Equal(t, "CLOSED", cb.stateString()) +} + +// --------------------------------------------------------------------------- +// dualEWMATTFT Coverage Tests +// --------------------------------------------------------------------------- + +func TestDualEWMATTFT_InitNaN(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + require.True(t, math.IsNaN(d.fastValue())) + require.True(t, math.IsNaN(d.slowValue())) + _, hasTTFT := d.value() + require.False(t, hasTTFT, "NaN-initialized should return hasTTFT=false") +} + +func TestDualEWMATTFT_UpdateFromNaN(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + d.update(100.0) + v, ok := d.value() + require.True(t, ok) + require.InDelta(t, 100.0, v, 0.01, "first update should set sample directly") +} + +func TestDualEWMATTFT_UpdateMultiple(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + for i := 0; i < 20; i++ { + d.update(200.0) + } + v, ok := d.value() + require.True(t, ok) + require.InDelta(t, 200.0, v, 1.0, "after many updates of same value, should converge") +} + +func TestDualEWMATTFT_ValueFastOnly(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + // Set fast only, slow stays NaN + d.fastBits.Store(math.Float64bits(42.0)) + v, ok := d.value() + require.True(t, ok) 
+ require.Equal(t, 42.0, v) +} + +func TestDualEWMATTFT_ValueSlowOnly(t *testing.T) { + var d dualEWMATTFT + d.initNaN() + // Set slow only, fast stays NaN + d.slowBits.Store(math.Float64bits(55.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 55.0, v) +} + +func TestDualEWMATTFT_ValueSlowGreaterThanFast(t *testing.T) { + var d dualEWMATTFT + d.fastBits.Store(math.Float64bits(30.0)) + d.slowBits.Store(math.Float64bits(50.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 50.0, v, "pessimistic value should return max(fast, slow)") +} + +func TestDualEWMATTFT_ValueFastGreaterThanSlow(t *testing.T) { + var d dualEWMATTFT + d.fastBits.Store(math.Float64bits(80.0)) + d.slowBits.Store(math.Float64bits(50.0)) + v, ok := d.value() + require.True(t, ok) + require.Equal(t, 80.0, v, "pessimistic value should return max(fast, slow)") +} + +// --------------------------------------------------------------------------- +// Softmax Additional Coverage Tests +// --------------------------------------------------------------------------- + +func TestSoftmax_EmptyCandidates(t *testing.T) { + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(nil, 0.3, &rng) + require.Nil(t, result) +} + +func TestSoftmax_ZeroTemperatureFallsToDefault(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.9}, + {account: &Account{ID: 2}, score: 0.1}, + {account: &Account{ID: 3}, score: 0.5}, + {account: &Account{ID: 4}, score: 0.3}, + } + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(candidates, 0, &rng) + require.Len(t, result, 4) +} + +func TestSoftmax_NaNScoresUniform(t *testing.T) { + // Extreme negative scores that cause exp() to return 0 → uniform fallback + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: -1e308}, + {account: &Account{ID: 2}, score: -1e308}, + {account: &Account{ID: 3}, score: -1e308}, + {account: &Account{ID: 4}, 
score: -1e308}, + } + rng := newOpenAISelectionRNG(42) + result := selectSoftmaxOpenAICandidates(candidates, 0.001, &rng) + require.Len(t, result, 4) +} + +// --------------------------------------------------------------------------- +// Snapshot / Stats Coverage Tests +// --------------------------------------------------------------------------- + +func TestSnapshot_NilStats(t *testing.T) { + var s *openAIAccountRuntimeStats + errorRate, ttft, hasTTFT := s.snapshot(1) + require.Equal(t, 0.0, errorRate) + require.Equal(t, 0.0, ttft) + require.False(t, hasTTFT) +} + +func TestSnapshot_InvalidAccountID(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + errorRate, ttft, hasTTFT := s.snapshot(0) + require.Equal(t, 0.0, errorRate) + require.Equal(t, 0.0, ttft) + require.False(t, hasTTFT) +} + +func TestSnapshot_UnknownAccount(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + errorRate, ttft, hasTTFT := s.snapshot(999) + require.Equal(t, 0.0, errorRate) + require.Equal(t, 0.0, ttft) + require.False(t, hasTTFT) +} + +func TestSnapshot_WithModelFallback(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + // Set global TTFT + stat.ttft.update(100.0) + // Snapshot with unknown model should fallback to global + _, ttft, hasTTFT := s.snapshot(1, "unknown-model") + require.True(t, hasTTFT) + require.InDelta(t, 100.0, ttft, 0.01) +} + +func TestSnapshot_WithModelSpecific(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + stat.reportModelTTFT("gpt-4", 200.0) + stat.ttft.update(50.0) // global is different + _, ttft, hasTTFT := s.snapshot(1, "gpt-4") + require.True(t, hasTTFT) + require.InDelta(t, 200.0, ttft, 0.01, "should use per-model TTFT") +} + +func TestStatsSize_Nil(t *testing.T) { + var s *openAIAccountRuntimeStats + require.Equal(t, 0, s.size()) +} + +func TestStatsSize_Empty(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + require.Equal(t, 0, s.size()) +} + +func 
TestStatsSize_WithAccounts(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + s.loadOrCreate(1) + s.loadOrCreate(2) + require.Equal(t, 2, s.size()) +} + +func TestLoadOrCreate_ConcurrentSameID(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + var wg sync.WaitGroup + results := make([]*openAIAccountRuntimeStat, 10) + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = s.loadOrCreate(1) + }(i) + } + wg.Wait() + // All should return the same pointer + for i := 1; i < 10; i++ { + require.Same(t, results[0], results[i], "concurrent loadOrCreate should return same stat") + } +} + +// --------------------------------------------------------------------------- +// modelTTFTValue / reportModelTTFT Coverage Tests +// --------------------------------------------------------------------------- + +func TestModelTTFTValue_EmptyModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + v, ok := stat.modelTTFTValue("") + require.False(t, ok) + require.Equal(t, 0.0, v) +} + +func TestModelTTFTValue_UnknownModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + v, ok := stat.modelTTFTValue("nonexistent") + require.False(t, ok) + require.Equal(t, 0.0, v) +} + +func TestReportModelTTFT_EmptyModel(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + stat.reportModelTTFT("", 100.0) + // Should be no-op: global TTFT not updated + _, hasTTFT := stat.ttft.value() + require.False(t, hasTTFT) +} + +func TestReportModelTTFT_ZeroSample(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + stat.reportModelTTFT("gpt-4", 0) + _, hasTTFT := stat.ttft.value() + require.False(t, hasTTFT) +} + +func TestReportModelTTFT_NegativeSample(t *testing.T) { + stat := &openAIAccountRuntimeStat{} + stat.ttft.initNaN() + stat.reportModelTTFT("gpt-4", -10.0) + _, hasTTFT := stat.ttft.value() + require.False(t, hasTTFT) +} + +func TestGetOrCreateModelTTFT_ConcurrentSameModel(t *testing.T) { + stat := 
&openAIAccountRuntimeStat{} + var wg sync.WaitGroup + results := make([]*dualEWMATTFT, 10) + for i := 0; i < 10; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + results[idx] = stat.getOrCreateModelTTFT("gpt-4") + }(i) + } + wg.Wait() + for i := 1; i < 10; i++ { + require.Same(t, results[0], results[i]) + } +} + +// --------------------------------------------------------------------------- +// schedulerCircuitBreakerConfig Coverage Tests +// --------------------------------------------------------------------------- + +func TestSchedulerCircuitBreakerConfig_Defaults(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig() + require.False(t, enabled) + require.Equal(t, defaultCircuitBreakerFailThreshold, threshold) + require.Equal(t, time.Duration(defaultCircuitBreakerCooldownSec)*time.Second, cooldown) + require.Equal(t, defaultCircuitBreakerHalfOpenMax, halfOpenMax) +} + +func TestSchedulerCircuitBreakerConfig_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 10 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60 + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 5 + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig() + require.True(t, enabled) + require.Equal(t, 10, threshold) + require.Equal(t, 60*time.Second, cooldown) + require.Equal(t, 5, halfOpenMax) +} + +func 
TestSchedulerPerModelTTFTConfig_Defaults(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + enabled, maxModels := scheduler.schedulerPerModelTTFTConfig() + require.False(t, enabled) + require.Equal(t, defaultPerModelTTFTMaxModels, maxModels) +} + +func TestSchedulerPerModelTTFTConfig_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 64 + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + enabled, maxModels := scheduler.schedulerPerModelTTFTConfig() + require.True(t, enabled) + require.Equal(t, 64, maxModels) +} + +func TestReportResult_PerModelTTFTDisabled_NoPerModelTrackerCreated(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = false + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + ttft := 120 + + scheduler.ReportResult(7001, true, &ttft, "gpt-5.1", 120) + + stat := scheduler.stats.loadExisting(7001) + require.NotNil(t, stat) + count := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + count++ + return true + }) + require.Equal(t, 0, count, "per-model ttft should remain disabled") + _, globalTTFT, hasTTFT := scheduler.stats.snapshot(7001) + require.True(t, hasTTFT) + require.InDelta(t, 120.0, globalTTFT, 0.01) +} + +func TestReportResult_PerModelTTFTMaxModels_UsesConfigLimit(t *testing.T) { + cfg := &config.Config{} + 
cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true + cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 2 + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + ttft := 100 + + models := []string{"gpt-5.1", "gpt-4o", "o3", "o4-mini"} + for i := 0; i < 200; i++ { + model := models[i%len(models)] + scheduler.ReportResult(7002, true, &ttft, model, float64(ttft+i)) + } + + stat := scheduler.stats.loadExisting(7002) + require.NotNil(t, stat) + count := 0 + stat.modelTTFT.Range(func(_, _ any) bool { + count++ + return true + }) + require.LessOrEqual(t, count, 2, "model tracker count should honor scheduler_per_model_ttft_max_models") +} + +// --------------------------------------------------------------------------- +// P2C Edge Case Coverage +// --------------------------------------------------------------------------- + +func TestSelectP2C_EmptyCandidates(t *testing.T) { + result := selectP2COpenAICandidates(nil, OpenAIAccountScheduleRequest{}) + require.Nil(t, result) +} + +func TestSelectP2C_SingleCandidate(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.5}, + } + result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{}) + require.Len(t, result, 1) + require.Equal(t, int64(1), result[0].account.ID) +} + +func TestSelectP2C_TwoCandidatesPicksBetter(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.2}, + {account: &Account{ID: 2}, score: 0.8}, + } + result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{}) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID, "first should be higher-scored") +} + +// --------------------------------------------------------------------------- +// TopK Selection & Heap Coverage +// --------------------------------------------------------------------------- + +func 
TestSelectTopK_Empty(t *testing.T) { + result := selectTopKOpenAICandidates(nil, 3) + require.Nil(t, result) +} + +func TestSelectTopK_TopKZero(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}}, + } + result := selectTopKOpenAICandidates(candidates, 0) + require.Len(t, result, 1, "topK=0 should default to 1") +} + +func TestSelectTopK_TopKExceedsCandidates(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}}, + } + result := selectTopKOpenAICandidates(candidates, 10) + require.Len(t, result, 2) + require.Equal(t, int64(2), result[0].account.ID, "highest score first") +} + +func TestSelectTopK_ProperFiltering(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + {account: &Account{ID: 1, Priority: 1}, score: 0.1, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 3, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 4, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}}, + {account: &Account{ID: 5, Priority: 1}, score: 0.7, loadInfo: &AccountLoadInfo{}}, + } + result := selectTopKOpenAICandidates(candidates, 3) + require.Len(t, result, 3) + require.Equal(t, int64(2), result[0].account.ID) + require.Equal(t, int64(5), result[1].account.ID) + require.Equal(t, int64(3), result[2].account.ID) +} + +func TestIsOpenAIAccountCandidateBetter_AllTiebreakers(t *testing.T) { + // Equal scores, different priority + a := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + b := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 2}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, 
WaitingCount: 5}} + require.True(t, isOpenAIAccountCandidateBetter(a, b), "lower priority number = better") + require.False(t, isOpenAIAccountCandidateBetter(b, a)) + + // Equal scores and priority, different load rate + c := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 5}} + d := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 60, WaitingCount: 5}} + require.True(t, isOpenAIAccountCandidateBetter(c, d), "lower load rate = better") + + // Equal everything except waiting count + e := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}} + f := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 8}} + require.True(t, isOpenAIAccountCandidateBetter(e, f), "lower waiting count = better") + + // Equal everything except ID + g := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + h := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}} + require.True(t, isOpenAIAccountCandidateBetter(g, h), "lower ID = better") +} + +// --------------------------------------------------------------------------- +// shouldReleaseStickySession Coverage +// --------------------------------------------------------------------------- + +func TestShouldReleaseStickySession_NilScheduler(t *testing.T) { + var s *defaultOpenAIAccountScheduler + require.False(t, s.shouldReleaseStickySession(1)) +} + +func TestShouldReleaseStickySession_Disabled(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = false + svc := &OpenAIGatewayService{cfg: cfg} + stats := 
newOpenAIAccountRuntimeStats() + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + require.False(t, scheduler.shouldReleaseStickySession(1)) +} + +func TestShouldReleaseStickySession_CircuitOpen(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true + cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + // Trip the circuit breaker + cb := stats.getCircuitBreaker(1) + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + require.True(t, scheduler.shouldReleaseStickySession(1), "should release when circuit is open") +} + +func TestShouldReleaseStickySession_HighErrorRate(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + // Report many failures to push error rate above threshold + for i := 0; i < 20; i++ { + stats.report(1, false, nil, "", 0) + } + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + require.True(t, scheduler.shouldReleaseStickySession(1), "should release when error rate is high") +} + +func TestShouldReleaseStickySession_Healthy(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + stats := newOpenAIAccountRuntimeStats() + // Report successes + for i := 0; i < 20; i++ { + stats.report(1, true, nil, "", 0) + } + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + require.False(t, scheduler.shouldReleaseStickySession(1), "should not release when healthy") +} + +// --------------------------------------------------------------------------- +// stickyReleaseConfigRead Coverage +// --------------------------------------------------------------------------- + 
+func TestStickyReleaseConfigRead_NilConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + cfg := scheduler.stickyReleaseConfigRead() + require.False(t, cfg.enabled) + require.Equal(t, 0.0, cfg.errorThreshold, "nil config returns zero-value struct") +} + +func TestStickyReleaseConfigRead_Defaults(t *testing.T) { + c := &config.Config{} + // StickyReleaseErrorThreshold defaults to 0 → code uses defaultStickyReleaseErrorThreshold + svc := &OpenAIGatewayService{cfg: c} + stats := newOpenAIAccountRuntimeStats() + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + cfg := scheduler.stickyReleaseConfigRead() + require.False(t, cfg.enabled) + require.Equal(t, defaultStickyReleaseErrorThreshold, cfg.errorThreshold) +} + +func TestStickyReleaseConfigRead_Custom(t *testing.T) { + c := &config.Config{} + c.Gateway.OpenAIWS.StickyReleaseEnabled = true + c.Gateway.OpenAIWS.StickyReleaseErrorThreshold = 0.5 + svc := &OpenAIGatewayService{cfg: c} + stats := newOpenAIAccountRuntimeStats() + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + cfg := scheduler.stickyReleaseConfigRead() + require.True(t, cfg.enabled) + require.Equal(t, 0.5, cfg.errorThreshold) +} + +// --------------------------------------------------------------------------- +// RNG Coverage +// --------------------------------------------------------------------------- + +func TestNewOpenAISelectionRNG_ZeroSeed(t *testing.T) { + rng := newOpenAISelectionRNG(0) + require.NotEqual(t, uint64(0), rng.state, "zero seed should be replaced with default") + v := rng.nextFloat64() + require.True(t, v >= 0 && v < 1.0) +} + +// --------------------------------------------------------------------------- +// isCircuitOpen Coverage +// --------------------------------------------------------------------------- + +func TestIsCircuitOpen_UnknownAccount(t *testing.T) { + stats := 
newOpenAIAccountRuntimeStats() + require.False(t, stats.isCircuitOpen(999), "unknown account should not be circuit-open") +} + +func TestIsCircuitOpen_OpenAccount(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + cb := stats.getCircuitBreaker(1) + for i := 0; i < defaultCircuitBreakerFailThreshold; i++ { + cb.recordFailure(defaultCircuitBreakerFailThreshold) + } + require.True(t, stats.isCircuitOpen(1)) +} + +// --------------------------------------------------------------------------- +// openAIWSSchedulerP2CEnabled / openAIWSSchedulerWeights Coverage +// --------------------------------------------------------------------------- + +func TestP2CEnabled_NilConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + require.False(t, svc.openAIWSSchedulerP2CEnabled()) +} + +func TestP2CEnabled_Enabled(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true + svc := &OpenAIGatewayService{cfg: cfg} + require.True(t, svc.openAIWSSchedulerP2CEnabled()) +} + +func TestSchedulerWeights_NilConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + w := svc.openAIWSSchedulerWeights() + require.Equal(t, 1.0, w.Priority) + require.Equal(t, 1.0, w.Load) + require.Equal(t, 0.7, w.Queue) + require.Equal(t, 0.8, w.ErrorRate) + require.Equal(t, 0.5, w.TTFT) +} + +func TestSchedulerWeights_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 2.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 3.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1.5 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.8 + svc := &OpenAIGatewayService{cfg: cfg} + w := svc.openAIWSSchedulerWeights() + require.Equal(t, 2.0, w.Priority) + require.Equal(t, 3.0, w.Load) +} + +// --------------------------------------------------------------------------- +// Snapshot edge cases +// 
--------------------------------------------------------------------------- + +func TestSnapshot_WithEmptyModel(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + stat.ttft.update(100.0) + _, ttft, hasTTFT := s.snapshot(1, "") + require.True(t, hasTTFT) + require.InDelta(t, 100.0, ttft, 0.01, "empty model string should fall through to global") +} + +func TestSnapshot_NoModelArg(t *testing.T) { + s := newOpenAIAccountRuntimeStats() + stat := s.loadOrCreate(1) + stat.ttft.update(100.0) + _, ttft, hasTTFT := s.snapshot(1) + require.True(t, hasTTFT) + require.InDelta(t, 100.0, ttft, 0.01) +} + +// --------------------------------------------------------------------------- +// deriveOpenAISelectionSeed coverage +// --------------------------------------------------------------------------- + +func TestDeriveOpenAISelectionSeed_WithSessionHash(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{SessionHash: "abc123"}) + require.NotEqual(t, uint64(0), seed) +} + +func TestDeriveOpenAISelectionSeed_WithPreviousResponseID(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{PreviousResponseID: "resp_123"}) + require.NotEqual(t, uint64(0), seed) +} + +func TestDeriveOpenAISelectionSeed_WithGroupID(t *testing.T) { + gid := int64(42) + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{GroupID: &gid}) + require.NotEqual(t, uint64(0), seed) +} + +func TestDeriveOpenAISelectionSeed_Empty(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{}) + require.NotEqual(t, uint64(0), seed, "empty request should use time entropy") +} + +func TestDeriveOpenAISelectionSeed_WithModel(t *testing.T) { + seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{RequestedModel: "gpt-4"}) + require.NotEqual(t, uint64(0), seed) +} + +// --------------------------------------------------------------------------- +// SnapshotMetrics coverage +// 
--------------------------------------------------------------------------- + +func TestSnapshotMetrics_NilScheduler(t *testing.T) { + var s *defaultOpenAIAccountScheduler + metrics := s.SnapshotMetrics() + require.Equal(t, int64(0), metrics.SelectTotal) +} + +func TestSnapshotMetrics_Normal(t *testing.T) { + svc := &OpenAIGatewayService{} + stats := newOpenAIAccountRuntimeStats() + scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats) + metrics := scheduler.SnapshotMetrics() + require.Equal(t, int64(0), metrics.SelectTotal) +} + +// --------------------------------------------------------------------------- +// Heap Pop coverage +// --------------------------------------------------------------------------- + +func TestCandidateHeap_Pop(t *testing.T) { + h := &openAIAccountCandidateHeap{} + heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}}) + heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 2}, score: 0.9, loadInfo: &AccountLoadInfo{}}) + heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 3}, score: 0.3, loadInfo: &AccountLoadInfo{}}) + require.Equal(t, 3, h.Len()) + + // Pop returns the minimum (worst candidate in min-heap) + popped, ok := heap.Pop(h).(openAIAccountCandidateScore) + require.True(t, ok) + require.Equal(t, int64(3), popped.account.ID, "should pop the lowest-scored") + require.Equal(t, 2, h.Len()) +} + +// --------------------------------------------------------------------------- +// openAIWSSessionStickyTTL coverage +// --------------------------------------------------------------------------- + +func TestOpenAIWSSessionStickyTTL_DefaultConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + ttl := svc.openAIWSSessionStickyTTL() + require.Equal(t, openaiStickySessionTTL, ttl, "nil config should return default TTL") +} + +func TestOpenAIWSSessionStickyTTL_Custom(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 
1800 + svc := &OpenAIGatewayService{cfg: cfg} + ttl := svc.openAIWSSessionStickyTTL() + require.Equal(t, 1800*time.Second, ttl) +} diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go index c9cf32462..5ed5ff69a 100644 --- a/backend/internal/service/openai_client_transport.go +++ b/backend/internal/service/openai_client_transport.go @@ -64,8 +64,25 @@ func resolveOpenAIWSDecisionByClientTransport( decision OpenAIWSProtocolDecision, clientTransport OpenAIClientTransport, ) OpenAIWSProtocolDecision { - if clientTransport == OpenAIClientTransportHTTP { + // WSv2 upstream is only allowed for explicit WebSocket ingress. + // Unknown/missing transport is treated as HTTP to avoid accidental WS pool usage. + if clientTransport != OpenAIClientTransportWS { return openAIWSHTTPDecision("client_protocol_http") } return decision } + +func shouldWarnOpenAIWSUnknownTransportFallback( + decision OpenAIWSProtocolDecision, + clientTransport OpenAIClientTransport, +) bool { + if clientTransport != OpenAIClientTransportUnknown { + return false + } + switch decision.Transport { + case OpenAIUpstreamTransportResponsesWebsocketV2, OpenAIUpstreamTransportResponsesWebsocket: + return true + default: + return false + } +} diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go index ef90e6145..479534c79 100644 --- a/backend/internal/service/openai_client_transport_test.go +++ b/backend/internal/service/openai_client_transport_test.go @@ -103,5 +103,40 @@ func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) { require.Equal(t, base, wsDecision) unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown) - require.Equal(t, base, unknownDecision) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, unknownDecision.Transport) + require.Equal(t, "client_protocol_http", unknownDecision.Reason) +} + +func 
TestShouldWarnOpenAIWSUnknownTransportFallback(t *testing.T) { + require.True(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + }, + OpenAIClientTransportUnknown, + )) + + require.True(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_enabled", + }, + OpenAIClientTransportUnknown, + )) + + require.False(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportHTTPSSE, + Reason: "http_only", + }, + OpenAIClientTransportUnknown, + )) + + require.False(t, shouldWarnOpenAIWSUnknownTransportFallback( + OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + }, + OpenAIClientTransportHTTP, + )) } diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go new file mode 100644 index 000000000..9cb453ad0 --- /dev/null +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -0,0 +1,827 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type openAIRecordUsageLogRepoStub struct { + UsageLogRepository + + inserted bool + err error + calls int + lastLog *UsageLog + nextID int64 + + billingEntry *UsageBillingEntry + billingEntryErr error + upsertCalls int + getCalls int + markAppliedCalls int + markRetryCalls int + lastRetryAt time.Time + lastRetryErrMessage string + txCalls int +} + +func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) { + s.calls++ + if log != nil { + if log.ID == 0 { + if s.nextID == 0 { + s.nextID = 1000 + } + log.ID = s.nextID + s.nextID++ + } + } + s.lastLog = log + return s.inserted, s.err 
+} + +func (s *openAIRecordUsageLogRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) { + s.getCalls++ + if s.billingEntryErr != nil { + return nil, s.billingEntryErr + } + if s.billingEntry == nil { + return nil, ErrUsageBillingEntryNotFound + } + if s.billingEntry.UsageLogID != usageLogID { + return nil, ErrUsageBillingEntryNotFound + } + return s.billingEntry, nil +} + +func (s *openAIRecordUsageLogRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) { + s.upsertCalls++ + if s.billingEntryErr != nil { + return nil, false, s.billingEntryErr + } + if s.billingEntry != nil { + return s.billingEntry, false, nil + } + if entry == nil { + return nil, false, nil + } + copyEntry := *entry + copyEntry.ID = 9100 + int64(s.upsertCalls) + copyEntry.Status = UsageBillingEntryStatusPending + s.billingEntry = ©Entry + return s.billingEntry, true, nil +} + +func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error { + s.markAppliedCalls++ + if s.billingEntry != nil && s.billingEntry.ID == entryID { + s.billingEntry.Applied = true + s.billingEntry.Status = UsageBillingEntryStatusApplied + } + return nil +} + +func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error { + s.markRetryCalls++ + s.lastRetryAt = nextRetryAt + s.lastRetryErrMessage = lastError + if s.billingEntry != nil && s.billingEntry.ID == entryID { + s.billingEntry.Applied = false + s.billingEntry.Status = UsageBillingEntryStatusPending + msg := lastError + s.billingEntry.LastError = &msg + s.billingEntry.NextRetryAt = nextRetryAt + } + return nil +} + +func (s *openAIRecordUsageLogRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) { + return nil, nil +} + +func (s 
*openAIRecordUsageLogRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + s.txCalls++ + if fn == nil { + return nil + } + return fn(ctx) +} + +type openAIRecordUsageUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error +} + +func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + return s.deductErr +} + +type openAIRecordUsageSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error +} + +func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + return s.incrementErr +} + +type openAIRecordUsageBillingCacheStub struct { + BillingCache + + deductCalls int + deductErr error +} + +func (s *openAIRecordUsageBillingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + s.deductCalls++ + return s.deductErr +} + +func (s *openAIRecordUsageBillingCacheStub) GetUserBalance(context.Context, int64) (float64, error) { + return 0, errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) SetUserBalance(context.Context, int64, float64) error { + return errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) InvalidateUserBalance(context.Context, int64) error { + return nil +} + +func (s *openAIRecordUsageBillingCacheStub) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) { + return nil, errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error { + return errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error { + return errors.New("not implemented") +} + +func (s *openAIRecordUsageBillingCacheStub) 
InvalidateSubscriptionCache(context.Context, int64, int64) error { + return nil +} + +type openAIRecordUsageAPIKeyQuotaStub struct { + calls int + err error +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error { + s.calls++ + return s.err +} + +func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error { + return s.err +} + +func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *OpenAIGatewayService { + cfg := &config.Config{ + Default: config.DefaultConfig{ + RateMultiplier: 1, + }, + } + return &OpenAIGatewayService{ + usageLogRepo: usageRepo, + userRepo: userRepo, + userSubRepo: subRepo, + cfg: cfg, + billingService: NewBillingService(cfg, nil), + billingCacheService: &BillingCacheService{}, + deferredService: &DeferredService{}, + } +} + +func TestOpenAIGatewayServiceRecordUsage_NoBillingWhenCreateUsageLogFails(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + err: errors.New("write usage log failed"), + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_test_create_fail", + Usage: OpenAIUsage{ + InputTokens: 12, + OutputTokens: 8, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1001, + }, + User: &User{ + ID: 2001, + }, + Account: &Account{ + ID: 3001, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "create usage log") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingWhenUsageLogInserted(t *testing.T) { + usageRepo := 
&openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_test_inserted", + Usage: OpenAIUsage{ + InputTokens: 20, + OutputTokens: 10, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1002, + }, + User: &User{ + ID: 2002, + }, + Account: &Account{ + ID: 3002, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_PricingFailureReturnsError(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.billingService = &BillingService{ + cfg: &config.Config{}, + fallbackPrices: map[string]*ModelPricing{}, + } + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_pricing_fail", + Usage: OpenAIUsage{ + InputTokens: 1, + OutputTokens: 1, + }, + Model: "model_pricing_not_found_for_test", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1102, + }, + User: &User{ + ID: 2102, + }, + Account: &Account{ + ID: 3102, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "calculate cost") + require.Equal(t, 0, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DeductBalanceFailureReturnsError(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{ + 
deductErr: errors.New("db deduct failed"), + } + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_deduct_fail", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1003, + }, + User: &User{ + ID: 2003, + }, + Account: &Account{ + ID: 3003, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "deduct balance") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DeductBalanceCacheFailureReturnsError(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + cache := &openAIRecordUsageBillingCacheStub{ + deductErr: ErrInsufficientBalance, + } + svc.billingCacheService = &BillingCacheService{cache: cache} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_cache_deduct_fail", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1004, + }, + User: &User{ + ID: 2004, + }, + Account: &Account{ + ID: 3004, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "deduct balance cache") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, cache.deductCalls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionIncrementFailureReturnsError(t *testing.T) { + groupID := int64(12) + usageRepo := 
&openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{ + incrementErr: errors.New("subscription update failed"), + } + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.billingCacheService = &BillingCacheService{cache: &openAIRecordUsageBillingCacheStub{}} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_sub_increment_fail", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1005, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + SubscriptionType: SubscriptionTypeSubscription, + }, + }, + User: &User{ + ID: 2005, + }, + Account: &Account{ + ID: 3005, + }, + Subscription: &UserSubscription{ + ID: 4005, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "increment subscription usage") + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 1, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: false, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 4, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1006, + }, + User: &User{ + ID: 2006, + }, + Account: &Account{ + ID: 3006, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) +} + +func 
TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBilling(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + svc.cfg.RunMode = config.RunModeSimple + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_simple_mode", + Usage: OpenAIUsage{ + InputTokens: 5, + OutputTokens: 2, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1007, + Quota: 100, + }, + User: &User{ + ID: 2007, + }, + Account: &Account{ + ID: 3007, + }, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + require.Equal(t, 0, subRepo.incrementCalls) + require.Equal(t, 0, quotaSvc.calls) +} + +func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSuccess(t *testing.T) { + groupID := int64(13) + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_sub_success", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 8, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1008, + GroupID: &groupID, + Group: &Group{ + ID: groupID, + SubscriptionType: SubscriptionTypeSubscription, + }, + }, + User: &User{ + ID: 2008, + }, + Account: &Account{ + ID: 3008, + }, + Subscription: &UserSubscription{ + ID: 4008, + }, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 0, userRepo.deductCalls) + 
require.Equal(t, 1, subRepo.incrementCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{} + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_quota_update", + Usage: OpenAIUsage{ + InputTokens: 1, + OutputTokens: 1, + CacheReadInputTokens: 3, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1009, + Quota: 100, + }, + User: &User{ + ID: 2009, + }, + Account: &Account{ + ID: 3009, + }, + APIKeyService: quotaSvc, + }) + + require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, 0, usageRepo.lastLog.InputTokens, "input_tokens 小于 cache_read_tokens 时应被钳制为 0") + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 1, quotaSvc.calls) +} + +func TestOpenAIGatewayServiceRecordUsage_DuplicateWithPendingBillingEntryStillBills(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: false, + billingEntry: &UsageBillingEntry{ + ID: 9201, + UsageLogID: 1000, + Applied: false, + BillingType: BillingTypeBalance, + DeltaUSD: 1.25, + }, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_duplicate_pending_entry", + Usage: OpenAIUsage{ + InputTokens: 8, + OutputTokens: 2, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ID: 1011}, + User: &User{ID: 2011}, + Account: &Account{ + ID: 3011, + }, + }) + + 
require.NoError(t, err) + require.Equal(t, 1, usageRepo.calls) + require.Equal(t, 1, usageRepo.getCalls) + require.Equal(t, 1, usageRepo.markAppliedCalls) + require.Equal(t, 1, userRepo.deductCalls) +} + +func TestOpenAIGatewayServiceRecordUsage_BillingFailureMarksRetry(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{ + deductErr: errors.New("deduct failed"), + } + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_mark_retry", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 6, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1012, + }, + User: &User{ + ID: 2012, + }, + Account: &Account{ + ID: 3012, + }, + }) + + require.Error(t, err) + require.ErrorContains(t, err, "deduct balance") + require.Equal(t, 1, usageRepo.markRetryCalls) + require.NotZero(t, usageRepo.lastRetryAt) + require.NotEmpty(t, usageRepo.lastRetryErrMessage) + require.Equal(t, 0, usageRepo.markAppliedCalls) +} + +func TestResolveOpenAIUsageRequestID_FallbackDeterministic(t *testing.T) { + reasoning := "medium" + input := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_fallback_seed", + APIKey: &APIKey{ID: 11001}, + Account: &Account{ID: 21001}, + Result: &OpenAIForwardResult{ + RequestID: "", + Model: "gpt-5.1", + Usage: OpenAIUsage{ + InputTokens: 12, + OutputTokens: 8, + CacheCreationInputTokens: 2, + CacheReadInputTokens: 1, + }, + Duration: 2300 * time.Millisecond, + ReasoningEffort: &reasoning, + Stream: true, + OpenAIWSMode: true, + }, + } + + got1 := resolveOpenAIUsageRequestID(input) + got2 := resolveOpenAIUsageRequestID(input) + + require.NotEmpty(t, got1) + require.Equal(t, got1, got2, "fallback request id should be deterministic") + require.Contains(t, got1, "wsf_") +} + +func 
TestResolveOpenAIUsageRequestID_FallbackChangesWhenUsageChanges(t *testing.T) { + base := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_fallback_seed", + APIKey: &APIKey{ID: 11002}, + Account: &Account{ID: 21002}, + Result: &OpenAIForwardResult{ + Model: "gpt-5.1", + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 4, + }, + Duration: 2 * time.Second, + }, + } + changed := &OpenAIRecordUsageInput{ + FallbackRequestID: base.FallbackRequestID, + APIKey: base.APIKey, + Account: base.Account, + Result: &OpenAIForwardResult{ + Model: "gpt-5.1", + Usage: OpenAIUsage{ + InputTokens: 11, + OutputTokens: 4, + }, + Duration: 2 * time.Second, + }, + } + + baseID := resolveOpenAIUsageRequestID(base) + changedID := resolveOpenAIUsageRequestID(changed) + + require.NotEqual(t, baseID, changedID, "fallback request id should change when usage fingerprint changes") +} + +func TestResolveOpenAIUsageRequestID_FallbackChangesWhenWSIngressModeChanges(t *testing.T) { + base := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_fallback_seed", + APIKey: &APIKey{ID: 11003}, + Account: &Account{ID: 21003}, + Result: &OpenAIForwardResult{ + Model: "gpt-5.3-codex", + Stream: true, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModeCtxPool, + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 4, + }, + Duration: 2 * time.Second, + }, + } + changed := &OpenAIRecordUsageInput{ + FallbackRequestID: base.FallbackRequestID, + APIKey: base.APIKey, + Account: base.Account, + Result: &OpenAIForwardResult{ + Model: base.Result.Model, + Stream: base.Result.Stream, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModePassthrough, + Usage: OpenAIUsage{ + InputTokens: 10, + OutputTokens: 4, + }, + Duration: 2 * time.Second, + }, + } + + baseID := resolveOpenAIUsageRequestID(base) + changedID := resolveOpenAIUsageRequestID(changed) + + require.NotEqual(t, baseID, changedID, "fallback request id should change when ws ingress mode changes") +} + +func 
TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDWhenMissing(t *testing.T) { + usageRepo := &openAIRecordUsageLogRepoStub{ + inserted: true, + } + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo) + + input := &OpenAIRecordUsageInput{ + FallbackRequestID: "req_from_handler", + Result: &OpenAIForwardResult{ + RequestID: "", + Usage: OpenAIUsage{ + InputTokens: 9, + OutputTokens: 3, + }, + Model: "gpt-5.1", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1013, + }, + User: &User{ + ID: 2013, + }, + Account: &Account{ + ID: 3013, + }, + } + + expectedRequestID := resolveOpenAIUsageRequestID(input) + require.NotEmpty(t, expectedRequestID) + + err := svc.RecordUsage(context.Background(), input) + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.Equal(t, expectedRequestID, usageRepo.lastLog.RequestID) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 8606708f7..4a8f28a08 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "log/slog" "math/rand" "net/http" "sort" @@ -42,6 +43,8 @@ const ( // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 OpenAIParsedRequestBodyKey = "openai_parsed_request_body" + // OpenAIRequestMetaKey 缓存 handler 已提取的请求元数据,供 Service 层复用。 + OpenAIRequestMetaKey = "openai_request_meta" // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。 // 与 Codex 客户端保持一致:失败后最多重连 5 次。 openAIWSReconnectRetryLimit = 5 @@ -49,6 +52,13 @@ const ( openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond openAIWSRetryBackoffMaxDefault = 2 * time.Second openAIWSRetryJitterRatioDefault = 0.2 + openAICodexUsageUpdateConcurrency = 16 +) + +// SSE 热路径包级常量,避免循环内重复分配 +var ( + sseDataDone = []byte("[DONE]") + sseResponseCompletedMark = 
[]byte(`"response.completed"`) ) // OpenAI allowed headers whitelist (for non-passthrough). @@ -210,8 +220,16 @@ type OpenAIForwardResult struct { ReasoningEffort *string Stream bool OpenAIWSMode bool - Duration time.Duration - FirstTokenMs *int + // WSIngressMode records which WS v2 ingress path produced this result. + // Expected values: ctx_pool / passthrough. + WSIngressMode string + Duration time.Duration + FirstTokenMs *int + // TerminalEventType records the terminal event that ended the WS turn. + TerminalEventType string + // PendingFunctionCallIDs 表示该 response 中未完成的 function_call call_id 集合。 + // 仅在 WS ingress 连续对话场景用于续链自愈,不参与外部 API 返回。 + PendingFunctionCallIDs []string } type OpenAIWSRetryMetricsSnapshot struct { @@ -221,6 +239,17 @@ type OpenAIWSRetryMetricsSnapshot struct { NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"` } +type OpenAIWSTurnAbortMetricPoint struct { + Reason string `json:"reason"` + Expected bool `json:"expected"` + Total int64 `json:"total"` +} + +type OpenAIWSAbortMetricsSnapshot struct { + TurnAbortTotal []OpenAIWSTurnAbortMetricPoint `json:"turn_abort_total"` + TurnAbortRecoveredTotal int64 `json:"turn_abort_recovered_total"` +} + type OpenAICompatibilityFallbackMetricsSnapshot struct { SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"` SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"` @@ -243,6 +272,16 @@ type openAIWSRetryMetrics struct { nonRetryableFastFallback atomic.Int64 } +type openAIWSTurnAbortMetricKey struct { + reason string + expected bool +} + +type openAIWSAbortMetrics struct { + turnAbortTotal sync.Map // key: openAIWSTurnAbortMetricKey, value: *atomic.Int64 + turnAbortRecovered atomic.Int64 +} + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -263,17 +302,28 @@ type OpenAIGatewayService struct { toolCorrector *CodexToolCorrector 
openaiWSResolver OpenAIWSProtocolResolver - openaiWSPoolOnce sync.Once - openaiWSStateStoreOnce sync.Once - openaiSchedulerOnce sync.Once - openaiWSPool *openAIWSConnPool - openaiWSStateStore OpenAIWSStateStore - openaiScheduler OpenAIAccountScheduler - openaiAccountStats *openAIAccountRuntimeStats + openaiWSIngressCtxOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPassthroughDialerOnce sync.Once + openaiWSIngressCtxPool *openAIWSIngressContextPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiWSPassthroughDialer openAIWSClientDialer + openaiAccountStats *openAIAccountRuntimeStats openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time openaiWSRetryMetrics openAIWSRetryMetrics + openaiWSAbortMetrics openAIWSAbortMetrics responseHeaderFilter *responseheaders.CompiledHeaderFilter + + codexUsageUpdateOnce sync.Once + codexUsageUpdateSem chan struct{} + + usageBillingCompensation *UsageBillingCompensationService + + // test hook for deterministic tie-break in account selection. 
+ accountTieBreakIntnFn func(n int) int } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -313,15 +363,25 @@ func NewOpenAIGatewayService( openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.usageBillingCompensation = NewUsageBillingCompensationService(usageLogRepo, userRepo, userSubRepo, billingCacheService, cfg) + svc.usageBillingCompensation.Start() svc.logOpenAIWSModeBootstrap() return svc } -// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 +// CloseOpenAIWSCtxPool 关闭 OpenAI WebSocket ctx_pool 的后台 worker 与连接资源。 // 应在应用优雅关闭时调用。 -func (s *OpenAIGatewayService) CloseOpenAIWSPool() { - if s != nil && s.openaiWSPool != nil { - s.openaiWSPool.Close() +func (s *OpenAIGatewayService) CloseOpenAIWSCtxPool() { + if s != nil && s.openaiWSIngressCtxPool != nil { + s.openaiWSIngressCtxPool.Close() + } + if s != nil && s.openaiWSStateStore != nil { + if closer, ok := s.openaiWSStateStore.(interface{ Close() }); ok { + closer.Close() + } + } + if s != nil && s.usageBillingCompensation != nil { + s.usageBillingCompensation.Stop() } } @@ -537,6 +597,34 @@ func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context return true } +func (s *OpenAIGatewayService) writeOpenAIWSV1UnsupportedResponse(c *gin.Context, account *Account) error { + const ( + upstreamMessage = "openai ws v1 is temporarily unsupported; use ws v2" + clientMessage = "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2." 
+ ) + setOpsUpstreamError(c, http.StatusBadRequest, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: http.StatusBadRequest, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + if c != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": clientMessage, + }, + }) + c.Abort() + } + return errors.New(upstreamMessage) +} + func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { if attempt <= 0 { return 0 @@ -610,6 +698,16 @@ func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { return 0 } +func openAIWSRetryContextError(ctx context.Context) error { + if ctx == nil { + return nil + } + if err := ctx.Err(); err != nil { + return wrapOpenAIWSFallback("retry_context_canceled", err) + } + return nil +} + func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { if s == nil { return @@ -646,6 +744,70 @@ func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetri } } +func (s *OpenAIGatewayService) recordOpenAIWSTurnAbort(reason openAIWSIngressTurnAbortReason, expected bool) { + if s == nil { + return + } + normalizedReason := strings.TrimSpace(string(reason)) + if normalizedReason == "" { + normalizedReason = string(openAIWSIngressTurnAbortReasonUnknown) + } + key := openAIWSTurnAbortMetricKey{ + reason: normalizedReason, + expected: expected, + } + counterAny, _ := s.openaiWSAbortMetrics.turnAbortTotal.LoadOrStore(key, &atomic.Int64{}) + counter, ok := counterAny.(*atomic.Int64) + if !ok || counter == nil { + return + } + counter.Add(1) +} + +func (s *OpenAIGatewayService) recordOpenAIWSTurnAbortRecovered() { + if s == nil { + return + } + s.openaiWSAbortMetrics.turnAbortRecovered.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSAbortMetrics() 
OpenAIWSAbortMetricsSnapshot { + if s == nil { + return OpenAIWSAbortMetricsSnapshot{} + } + points := make([]OpenAIWSTurnAbortMetricPoint, 0, 8) + s.openaiWSAbortMetrics.turnAbortTotal.Range(func(key, value any) bool { + label, ok := key.(openAIWSTurnAbortMetricKey) + if !ok { + return true + } + counter, ok := value.(*atomic.Int64) + if !ok || counter == nil { + return true + } + total := counter.Load() + if total <= 0 { + return true + } + points = append(points, OpenAIWSTurnAbortMetricPoint{ + Reason: label.reason, + Expected: label.expected, + Total: total, + }) + return true + }) + sort.Slice(points, func(i, j int) bool { + if points[i].Reason == points[j].Reason { + return !points[i].Expected && points[j].Expected + } + return points[i].Reason < points[j].Reason + }) + return OpenAIWSAbortMetricsSnapshot{ + TurnAbortTotal: points, + TurnAbortRecoveredTotal: s.openaiWSAbortMetrics.turnAbortRecovered.Load(), + } +} + func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot { legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats() isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats() @@ -1041,6 +1203,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // Returns nil if no available account. 
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account { var selected *Account + tieCount := 0 for i := range accounts { acc := &accounts[i] @@ -1067,17 +1230,41 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo // Select highest priority and least recently used if selected == nil { selected = acc + tieCount = 1 continue } if s.isBetterAccount(acc, selected) { selected = acc + tieCount = 1 + continue + } + if !s.isBetterAccount(selected, acc) { + // 完全同分档时进行 reservoir tie-break,避免长期集中命中首个账号。 + tieCount++ + if s.accountTieBreakIntn(tieCount) == 0 { + selected = acc + } } } return selected } +func (s *OpenAIGatewayService) accountTieBreakIntn(n int) int { + if n <= 1 { + return 0 + } + if s != nil && s.accountTieBreakIntnFn != nil { + v := s.accountTieBreakIntnFn(n) + if v >= 0 && v < n { + return v + } + return 0 + } + return rand.Intn(n) +} + // isBetterAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。 // @@ -1381,6 +1568,9 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig // GetAccessToken gets the access token for an OpenAI account func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { + if account == nil { + return "", "", errors.New("account is nil") + } switch account.Type { case AccountTypeOAuth: // 使用 TokenProvider 获取缓存的 token @@ -1392,13 +1582,13 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco return accessToken, "oauth", nil } // 降级:TokenProvider 未配置时直接从账号读取 - accessToken := account.GetOpenAIAccessToken() + accessToken := strings.TrimSpace(account.GetOpenAIAccessToken()) if accessToken == "" { return "", "", errors.New("access_token not found in credentials") } return accessToken, "oauth", nil case AccountTypeAPIKey: - apiKey := account.GetOpenAIApiKey() + apiKey := 
strings.TrimSpace(account.GetOpenAIApiKey()) if apiKey == "" { return "", "", errors.New("api_key not found in credentials") } @@ -1417,8 +1607,19 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool } } -func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) { - body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) +func (s *OpenAIGatewayService) handleFailoverSideEffects( + ctx context.Context, + resp *http.Response, + account *Account, + respBody []byte, +) { + if s == nil || s.rateLimitService == nil || resp == nil || account == nil { + return + } + body := respBody + if len(body) == 0 && resp.Body != nil { + body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + } s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } @@ -1440,12 +1641,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } originalBody := body - reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMeta(c, body) originalModel := reqModel isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) clientTransport := GetOpenAIClientTransport(c) + if shouldWarnOpenAIWSUnknownTransportFallback(wsDecision, clientTransport) { + logOpenAIWSModeInfo( + "client_transport_unknown_fallback_http account_id=%d account_type=%s resolved_transport=%s resolved_reason=%s", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + ) + } // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) if c != nil { @@ -1465,15 +1675,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c 
*gin.Context, acco } // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { - if c != nil { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", - }, - }) - } - return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") + return nil, s.writeOpenAIWSV1UnsupportedResponse(c, account) } passthroughEnabled := account.IsOpenAIPassthroughEnabled() if passthroughEnabled { @@ -1762,6 +1964,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco retryStartedAt := time.Now() wsRetryLoop: for attempt := 1; attempt <= maxAttempts; attempt++ { + if cancelErr := openAIWSRetryContextError(ctx); cancelErr != nil { + wsErr = cancelErr + break + } wsAttempts = attempt wsResult, wsErr = s.forwardOpenAIWSV2( ctx, @@ -1942,8 +2148,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco Message: upstreamMsg, Detail: upstreamDetail, }) - - s.handleFailoverSideEffects(ctx, resp, account) + s.handleFailoverSideEffects(ctx, resp, account, respBody) return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleErrorResponse(ctx, resp, c, account, body) @@ -2036,7 +2241,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( } logger.LegacyPrintf("service.openai_gateway", - "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", + "[DEBUG] [OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v", account.ID, account.Name, account.Type, @@ -2202,6 +2407,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough( if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // 透传客户端请求头(安全白名单)。 allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed() @@ 
-2582,6 +2788,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin. if err != nil { return nil, err } + req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI)) // Set authentication header req.Header.Set("authorization", "Bearer "+token) @@ -2934,13 +3141,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp line = s.replaceModelInSSELine(line, mappedModel, originalModel) } - dataBytes := []byte(data) - // Correct Codex tool calls if needed (apply_patch -> edit, etc.) - if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { - dataBytes = correctedData - data = string(correctedData) - line = "data: " + data + // 仅在 toolCorrector 存在时才转换为 []byte,避免热路径无谓分配 + if s.toolCorrector != nil { + dataBytes := []byte(data) + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { + data = string(correctedData) + line = "data: " + data + } } // 写入客户端(客户端断开后继续 drain 上游) @@ -2969,7 +3177,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsageBytes(dataBytes, usage) + // 使用 string 版本解析 usage,避免 string→[]byte 转换 + s.parseSSEUsageString(data, usage) return } @@ -3143,24 +3352,50 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { - s.parseSSEUsageBytes([]byte(data), usage) + s.parseSSEUsageString(data, usage) +} + +// parseSSEUsageString 使用 gjson.Get(string 版本)解析 usage,避免 string→[]byte 转换 +func (s *OpenAIGatewayService) parseSSEUsageString(data string, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || data == "[DONE]" { + return + } + if len(data) < 80 || !strings.Contains(data, `"response.completed"`) { + return + } + if gjson.Get(data, "type").String() != 
"response.completed" { + return + } + usageFields := gjson.GetMany(data, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(usageFields[0].Int()) + usage.OutputTokens = int(usageFields[1].Int()) + usage.CacheReadInputTokens = int(usageFields[2].Int()) } func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { - if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + if usage == nil || len(data) == 0 || bytes.Equal(data, sseDataDone) { return } // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 - if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { + if len(data) < 80 || !bytes.Contains(data, sseResponseCompletedMark) { return } if gjson.GetBytes(data, "type").String() != "response.completed" { return } - - usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) - usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) - usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) + // 使用 GetManyBytes 一次提取 3 个 usage 字段 + usageFields := gjson.GetManyBytes(data, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(usageFields[0].Int()) + usage.OutputTokens = int(usageFields[1].Int()) + usage.CacheReadInputTokens = int(usageFields[2].Int()) } func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { @@ -3318,6 +3553,13 @@ func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel st } func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) { + if s == nil || s.cfg == nil { + normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{}) + if err != nil { + return "", fmt.Errorf("invalid base_url: %w", err) + } + return 
normalized, nil + } if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled { normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP) if err != nil { @@ -3365,14 +3607,224 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel // OpenAIRecordUsageInput input for recording usage type OpenAIRecordUsageInput struct { - Result *OpenAIForwardResult - APIKey *APIKey - User *User - Account *Account - Subscription *UserSubscription - UserAgent string // 请求的 User-Agent - IPAddress string // 请求的客户端 IP 地址 - APIKeyService APIKeyQuotaUpdater + Result *OpenAIForwardResult + APIKey *APIKey + User *User + Account *Account + Subscription *UserSubscription + UserAgent string // 请求的 User-Agent + IPAddress string // 请求的客户端 IP 地址 + FallbackRequestID string // 当上游 request_id 缺失时,用于生成稳定幂等 request_id + APIKeyService APIKeyQuotaUpdater +} + +func (s *OpenAIGatewayService) usageBillingEntryStore() UsageBillingEntryStore { + store, ok := s.usageLogRepo.(UsageBillingEntryStore) + if !ok { + return nil + } + return store +} + +func (s *OpenAIGatewayService) usageBillingTxRunner() UsageBillingTxRunner { + runner, ok := s.usageLogRepo.(UsageBillingTxRunner) + if !ok { + return nil + } + return runner +} + +func (s *OpenAIGatewayService) runUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + runner := s.usageBillingTxRunner() + if runner == nil { + return fn(ctx) + } + return runner.WithUsageBillingTx(ctx, fn) +} + +func (s *OpenAIGatewayService) prepareUsageBillingEntry( + ctx context.Context, + usageLog *UsageLog, + inserted bool, + billingType int8, + deltaUSD float64, +) (*UsageBillingEntry, bool, error) { + if deltaUSD <= 0 { + return nil, false, nil + } + + store := s.usageBillingEntryStore() + if store == nil { + if inserted { + return nil, true, nil + } + return nil, false, nil + } + + if !inserted { + entry, err := store.GetUsageBillingEntryByUsageLogID(ctx, usageLog.ID) + if err != 
nil { + if errors.Is(err, ErrUsageBillingEntryNotFound) { + logger.LegacyPrintf( + "service.openai_gateway", + "[BillingReconcile] missing billing entry for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s", + usageLog.ID, + usageLog.RequestID, + ) + return nil, false, nil + } + logger.LegacyPrintf( + "service.openai_gateway", + "[BillingReconcile] load billing entry failed for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s err=%v", + usageLog.ID, + usageLog.RequestID, + err, + ) + return nil, false, nil + } + return entry, !entry.Applied, nil + } + + entry, _, err := store.UpsertUsageBillingEntry(ctx, &UsageBillingEntry{ + UsageLogID: usageLog.ID, + UserID: usageLog.UserID, + APIKeyID: usageLog.APIKeyID, + SubscriptionID: usageLog.SubscriptionID, + BillingType: billingType, + DeltaUSD: deltaUSD, + Status: UsageBillingEntryStatusPending, + }) + if err != nil { + logger.LegacyPrintf( + "service.openai_gateway", + "[BillingReconcile] upsert billing entry failed, fallback to inline billing: usage_log=%d request_id=%s err=%v", + usageLog.ID, + usageLog.RequestID, + err, + ) + return nil, true, nil + } + + return entry, !entry.Applied, nil +} + +func (s *OpenAIGatewayService) markUsageBillingRetry(ctx context.Context, entry *UsageBillingEntry, cause error) { + if entry == nil || cause == nil { + return + } + store := s.usageBillingEntryStore() + if store == nil { + return + } + errMsg := strings.TrimSpace(cause.Error()) + if len(errMsg) > 500 { + errMsg = errMsg[:500] + } + nextRetryAt := time.Now().Add(usageBillingRetryBackoff(entry.AttemptCount + 1)) + if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil { + logger.LegacyPrintf("service.openai_gateway", "[BillingReconcile] mark retry failed: entry=%d err=%v", entry.ID, err) + } +} + +func resolveOpenAIUsageRequestID(input *OpenAIRecordUsageInput) string { + if input == nil || input.Result == nil { + return "" + } + if requestID := 
strings.TrimSpace(input.Result.RequestID); requestID != "" { + return requestID + } + return buildOpenAIUsageFallbackRequestID(input) +} + +func buildOpenAIUsageFallbackRequestID(input *OpenAIRecordUsageInput) string { + if input == nil || input.Result == nil { + return "" + } + result := input.Result + usage := result.Usage + + seed := strings.Builder{} + seed.Grow(192) + _, _ = seed.WriteString(strings.TrimSpace(input.FallbackRequestID)) + _ = seed.WriteByte('|') + if input.APIKey != nil { + _, _ = seed.WriteString(strconv.FormatInt(input.APIKey.ID, 10)) + } + _ = seed.WriteByte('|') + if input.Account != nil { + _, _ = seed.WriteString(strconv.FormatInt(input.Account.ID, 10)) + } + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strings.TrimSpace(result.Model)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strings.TrimSpace(result.TerminalEventType)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.FormatBool(result.Stream)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.FormatBool(result.OpenAIWSMode)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(normalizeOpenAIWSIngressMode(result.WSIngressMode)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.Itoa(usage.InputTokens)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.Itoa(usage.OutputTokens)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.Itoa(usage.CacheCreationInputTokens)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.Itoa(usage.CacheReadInputTokens)) + _ = seed.WriteByte('|') + _, _ = seed.WriteString(strconv.FormatInt(result.Duration.Milliseconds(), 10)) + _ = seed.WriteByte('|') + firstTokenMs := -1 + if result.FirstTokenMs != nil { + firstTokenMs = *result.FirstTokenMs + } + _, _ = seed.WriteString(strconv.Itoa(firstTokenMs)) + _ = seed.WriteByte('|') + if result.ReasoningEffort != nil { + _, _ = seed.WriteString(strings.TrimSpace(*result.ReasoningEffort)) + } + + sum := sha256.Sum256([]byte(seed.String())) + return 
"wsf_" + hex.EncodeToString(sum[:16]) +} + +func logOpenAIWSUsageRecorded( + result *OpenAIForwardResult, + usageLog *UsageLog, + inserted bool, + shouldBill bool, + billedAmount float64, +) { + if result == nil || usageLog == nil || !result.OpenAIWSMode { + return + } + wsIngressMode := normalizeOpenAIWSIngressMode(result.WSIngressMode) + if wsIngressMode == "" { + wsIngressMode = "-" + } + + logOpenAIWSModeInfo( + "ingress_ws_usage_recorded account_id=%d api_key_id=%d user_id=%d request_id=%s ws_ingress_mode=%s request_type=%s requests=1 model=%s input_tokens=%d output_tokens=%d cache_creation_tokens=%d cache_read_tokens=%d total_tokens=%d total_cost=%.6f actual_cost=%.6f billed_amount=%.6f should_bill=%v inserted=%v terminal_event=%s", + usageLog.AccountID, + usageLog.APIKeyID, + usageLog.UserID, + truncateOpenAIWSLogValue(usageLog.RequestID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(wsIngressMode), + normalizeOpenAIWSLogValue(usageLog.EffectiveRequestType().String()), + truncateOpenAIWSLogValue(usageLog.Model, openAIWSLogValueMaxLen), + usageLog.InputTokens, + usageLog.OutputTokens, + usageLog.CacheCreationTokens, + usageLog.CacheReadTokens, + usageLog.TotalTokens(), + usageLog.TotalCost, + usageLog.ActualCost, + billedAmount, + shouldBill, + inserted, + truncateOpenAIWSLogValue(result.TerminalEventType, openAIWSLogValueMaxLen), + ) } // RecordUsage records usage and deducts balance @@ -3406,7 +3858,25 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier) if err != nil { - cost = &CostBreakdown{ActualCost: 0} + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + logger.LegacyPrintf( + "service.openai_gateway", + "[PricingWarn] calculate cost failed in simple mode, fallback to zero cost: model=%s request_id=%s err=%v", + result.Model, + result.RequestID, + err, + ) + cost = &CostBreakdown{} + } else { + logger.LegacyPrintf( + 
"service.openai_gateway", + "[PricingAlert] calculate cost failed, reject usage record: model=%s request_id=%s err=%v", + result.Model, + result.RequestID, + err, + ) + return fmt.Errorf("calculate cost: %w", err) + } } // Determine billing type @@ -3419,11 +3889,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec // Create usage log durationMs := int(result.Duration.Milliseconds()) accountRateMultiplier := account.BillingRateMultiplier() + requestID := resolveOpenAIUsageRequestID(input) usageLog := &UsageLog{ UserID: user.ID, APIKeyID: apiKey.ID, AccountID: account.ID, - RequestID: result.RequestID, + RequestID: requestID, Model: result.Model, ReasoningEffort: result.ReasoningEffort, InputTokens: actualInputTokens, @@ -3464,24 +3935,68 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec } inserted, err := s.usageLogRepo.Create(ctx, usageLog) + if err != nil { + return fmt.Errorf("create usage log: %w", err) + } if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) + logOpenAIWSUsageRecorded(result, usageLog, inserted, false, 0) s.deferredService.ScheduleLastUsedUpdate(account.ID) return nil } - shouldBill := inserted || err != nil - - // Deduct based on billing type + billAmount := cost.ActualCost if isSubscriptionBilling { - if shouldBill && cost.TotalCost > 0 { - _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost) - s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) + billAmount = cost.TotalCost + } + billingEntry, shouldBill, err := s.prepareUsageBillingEntry(ctx, usageLog, inserted, billingType, billAmount) + if err != nil { + return fmt.Errorf("prepare usage billing entry: %w", err) + } + + if shouldBill { + cacheDeducted := false + if !isSubscriptionBilling && billAmount > 0 && 
s.billingCacheService != nil { + // 同步扣减缓存,避免并发场景下仅靠“先查后扣”产生透支窗口。 + if err := s.billingCacheService.DeductBalanceCache(ctx, user.ID, billAmount); err != nil { + s.markUsageBillingRetry(ctx, billingEntry, err) + return fmt.Errorf("deduct balance cache: %w", err) + } + cacheDeducted = true } - } else { - if shouldBill && cost.ActualCost > 0 { - _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost) - s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost) + + applyErr := s.runUsageBillingTx(ctx, func(txCtx context.Context) error { + if isSubscriptionBilling { + if err := s.userSubRepo.IncrementUsage(txCtx, subscription.ID, cost.TotalCost); err != nil { + return fmt.Errorf("increment subscription usage: %w", err) + } + } else if billAmount > 0 { + if err := s.userRepo.DeductBalance(txCtx, user.ID, billAmount); err != nil { + return fmt.Errorf("deduct balance: %w", err) + } + } + if billingEntry == nil { + return nil + } + store := s.usageBillingEntryStore() + if store == nil { + return nil + } + if err := store.MarkUsageBillingEntryApplied(txCtx, billingEntry.ID); err != nil { + return fmt.Errorf("mark usage billing entry applied: %w", err) + } + return nil + }) + if applyErr != nil { + if !isSubscriptionBilling && cacheDeducted && s.billingCacheService != nil { + _ = s.billingCacheService.InvalidateUserBalance(context.Background(), user.ID) + } + s.markUsageBillingRetry(ctx, billingEntry, applyErr) + return applyErr + } + + if isSubscriptionBilling && s.billingCacheService != nil && apiKey.GroupID != nil && cost.TotalCost > 0 { + s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost) } } @@ -3491,6 +4006,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err) } } + billedAmount := 0.0 + if shouldBill { + billedAmount = billAmount + } + logOpenAIWSUsageRecorded(result, usageLog, inserted, shouldBill, 
billedAmount) // Update API Key rate limit usage if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil { @@ -3676,15 +4196,48 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc if len(updates) == 0 { return } + if !s.tryAcquireCodexUsageUpdateSlot() { + slog.Warn("openai_gateway.codex_usage_update_dropped", + "account_id", accountID, + "reason", "concurrency_limit_reached", + ) + return + } // Update account's Extra field asynchronously go func() { + defer s.releaseCodexUsageUpdateSlot() updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() _ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates) }() } +func (s *OpenAIGatewayService) tryAcquireCodexUsageUpdateSlot() bool { + if s == nil { + return false + } + s.codexUsageUpdateOnce.Do(func() { + s.codexUsageUpdateSem = make(chan struct{}, openAICodexUsageUpdateConcurrency) + }) + select { + case s.codexUsageUpdateSem <- struct{}{}: + return true + default: + return false + } +} + +func (s *OpenAIGatewayService) releaseCodexUsageUpdateSlot() { + if s == nil || s.codexUsageUpdateSem == nil { + return + } + select { + case <-s.codexUsageUpdateSem: + default: + } +} + func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) { if reqBody == nil { return "", false @@ -3731,6 +4284,27 @@ func deriveOpenAIReasoningEffortFromModel(model string) string { return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } +// OpenAIRequestMeta 缓存已提取的请求元数据,避免重复解析 +type OpenAIRequestMeta struct { + Model string + Stream bool + PromptCacheKey string +} + +// extractOpenAIRequestMeta 优先从 context 读取已缓存的 meta(只读),回退到 body 解析。 +// Handler 层已完成所有字段提取(含 prompt_cache_key),此处不再修改 meta,避免并发竞态。 +func extractOpenAIRequestMeta(c *gin.Context, body []byte) (model string, stream bool, promptCacheKey string) { + if c != nil { + if cached, ok := c.Get(OpenAIRequestMetaKey); ok { + if meta, ok := 
cached.(*OpenAIRequestMeta); ok && meta != nil { + return meta.Model, meta.Stream, meta.PromptCacheKey + } + } + } + // 回退到原始解析(WebSocket 等其他入口) + return extractOpenAIRequestMetaFromBody(body) +} + func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { if len(body) == 0 { return "", false, "" diff --git a/backend/internal/service/openai_gateway_service_access_token_test.go b/backend/internal/service/openai_gateway_service_access_token_test.go new file mode 100644 index 000000000..32b5c8e4e --- /dev/null +++ b/backend/internal/service/openai_gateway_service_access_token_test.go @@ -0,0 +1,90 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayServiceGetAccessToken(t *testing.T) { + t.Parallel() + + svc := &OpenAIGatewayService{} + + t.Run("nil account", func(t *testing.T) { + token, tokenType, err := svc.GetAccessToken(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "account is nil") + require.Empty(t, token) + require.Empty(t, tokenType) + }) + + t.Run("oauth account", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "oauth-token", + }, + } + token, tokenType, err := svc.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "oauth-token", token) + require.Equal(t, "oauth", tokenType) + }) + + t.Run("oauth account trims token", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": " oauth-token-trim ", + }, + } + token, tokenType, err := svc.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "oauth-token-trim", token) + require.Equal(t, "oauth", tokenType) + }) + + t.Run("api key account", func(t *testing.T) { + account := &Account{ + Platform: 
PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": "sk-live-token", + }, + } + token, tokenType, err := svc.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sk-live-token", token) + require.Equal(t, "apikey", tokenType) + }) + + t.Run("api key account trims token", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "api_key": " sk-live-trim ", + }, + } + token, tokenType, err := svc.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "sk-live-trim", token) + require.Equal(t, "apikey", tokenType) + }) + + t.Run("unsupported account type", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: "unknown", + } + token, tokenType, err := svc.GetAccessToken(context.Background(), account) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported account type") + require.Empty(t, token) + require.Empty(t, tokenType) + }) +} diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go index f73c06c5e..a82e6e2a3 100644 --- a/backend/internal/service/openai_gateway_service_hotpath_test.go +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -139,3 +139,126 @@ func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) { require.True(t, ok) require.Equal(t, got, cachedMap) } + +// --- extractOpenAIRequestMeta context 缓存测试 --- + +func TestExtractOpenAIRequestMeta_CacheHit(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // 预设缓存(Handler 层已提取所有字段,包括 PromptCacheKey) + c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{ + Model: "gpt-5", + Stream: true, + PromptCacheKey: "key-1", + }) + + body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"key-other"}`) + 
model, stream, promptKey := extractOpenAIRequestMeta(c, body) + + // 应返回缓存值而非 body 中的值 + require.Equal(t, "gpt-5", model) + require.True(t, stream) + require.Equal(t, "key-1", promptKey) +} + +func TestExtractOpenAIRequestMeta_CacheHit_PromptCacheKeyFromHandler(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // Handler 层已提取 PromptCacheKey,meta 设置后只读不写 + meta := &OpenAIRequestMeta{Model: "gpt-5", Stream: false, PromptCacheKey: "pk-abc"} + c.Set(OpenAIRequestMetaKey, meta) + + body := []byte(`{"model":"gpt-4","prompt_cache_key":"pk-other"}`) + + // 应返回缓存中的值(Handler 层提取),而非 body 中的值 + _, _, promptKey1 := extractOpenAIRequestMeta(c, body) + require.Equal(t, "pk-abc", promptKey1) + + // 多次调用结果一致 + _, _, promptKey2 := extractOpenAIRequestMeta(c, body) + require.Equal(t, "pk-abc", promptKey2) +} + +func TestExtractOpenAIRequestMeta_CacheHit_EmptyBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{ + Model: "gpt-5", + Stream: true, + }) + + // body 为空时不应 panic,prompt_cache_key 应为空 + model, stream, promptKey := extractOpenAIRequestMeta(c, nil) + require.Equal(t, "gpt-5", model) + require.True(t, stream) + require.Equal(t, "", promptKey) +} + +func TestExtractOpenAIRequestMeta_FallbackToBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // 不设缓存,应回退到 body 解析 + body := []byte(`{"model":"gpt-4o","stream":true,"prompt_cache_key":"ses-2"}`) + model, stream, promptKey := extractOpenAIRequestMeta(c, body) + + require.Equal(t, "gpt-4o", model) + require.True(t, stream) + require.Equal(t, "ses-2", promptKey) +} + +func TestExtractOpenAIRequestMeta_NilContext(t *testing.T) { + body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"k"}`) + model, stream, promptKey := extractOpenAIRequestMeta(nil, body) + + 
require.Equal(t, "gpt-4", model) + require.False(t, stream) + require.Equal(t, "k", promptKey) +} + +func TestExtractOpenAIRequestMeta_InvalidCacheType(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + // 缓存类型错误,应回退到 body 解析 + c.Set(OpenAIRequestMetaKey, "invalid-type") + + body := []byte(`{"model":"gpt-4o","stream":true}`) + model, stream, _ := extractOpenAIRequestMeta(c, body) + + require.Equal(t, "gpt-4o", model) + require.True(t, stream) +} + +func TestExtractOpenAIRequestMeta_NilCacheValue(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpenAIRequestMetaKey, (*OpenAIRequestMeta)(nil)) + + body := []byte(`{"model":"gpt-5","stream":false}`) + model, stream, _ := extractOpenAIRequestMeta(c, body) + + require.Equal(t, "gpt-5", model) + require.False(t, stream) +} + +func TestOpenAIRequestMeta_Fields(t *testing.T) { + meta := &OpenAIRequestMeta{ + Model: "gpt-5", + Stream: true, + PromptCacheKey: "pk", + } + require.Equal(t, "gpt-5", meta.Model) + require.True(t, meta.Stream) + require.Equal(t, "pk", meta.PromptCacheKey) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 4f5f7f3c1..335495e9b 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -10,6 +10,8 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" + "sync/atomic" "testing" "time" @@ -61,6 +63,54 @@ func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Co return r.ListSchedulableByPlatform(ctx, platform) } +func (r stubOpenAIAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + if len(platforms) == 0 { + return nil, nil + } + platformSet := make(map[string]struct{}, len(platforms)) + for _, p := range platforms { + 
platformSet[p] = struct{}{} + } + result := make([]Account, 0, len(r.accounts)) + for _, acc := range r.accounts { + if _, ok := platformSet[acc.Platform]; ok { + result = append(result, acc) + } + } + return result, nil +} + +func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) { + return r.ListSchedulableByPlatforms(ctx, platforms) +} + +type codexUsageUpdateAccountRepoStub struct { + stubOpenAIAccountRepo + + calls atomic.Int32 + entered chan struct{} + release chan struct{} + enterSig sync.Once +} + +func (r *codexUsageUpdateAccountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error { + r.calls.Add(1) + r.enterSig.Do(func() { + if r.entered != nil { + close(r.entered) + } + }) + if r.release == nil { + return nil + } + select { + case <-r.release: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + type stubConcurrencyCache struct { ConcurrencyCache loadBatchErr error @@ -355,6 +405,37 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre } } +func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_BoundedAsyncUpdates(t *testing.T) { + repo := &codexUsageUpdateAccountRepoStub{ + entered: make(chan struct{}), + release: make(chan struct{}), + } + svc := &OpenAIGatewayService{ + accountRepo: repo, + codexUsageUpdateSem: make(chan struct{}, 1), + } + svc.codexUsageUpdateOnce.Do(func() {}) + + snapshot := &OpenAICodexUsageSnapshot{ + UpdatedAt: time.Now().Format(time.RFC3339), + } + + svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot) + + select { + case <-repo.entered: + case <-time.After(time.Second): + t.Fatal("first codex usage snapshot update did not start") + } + + // first async update is still holding the single slot + svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot) + time.Sleep(80 * time.Millisecond) + require.Equal(t, int32(1), repo.calls.Load(), "slot full 时应拒绝新异步写入,避免 goroutine 
无界增长") + + close(repo.release) +} + func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) { sessionHash := "session-1" repo := stubOpenAIAccountRepo{ @@ -832,6 +913,33 @@ func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing. } } +func TestOpenAISelectAccountForModelWithExclusions_EqualPriorityTieBreakRandomized(t *testing.T) { + repo := stubOpenAIAccountRepo{ + accounts: []Account{ + {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1}, + {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1}, + }, + } + cache := &stubGatewayCache{} + + calls := 0 + svc := &OpenAIGatewayService{ + accountRepo: repo, + cache: cache, + accountTieBreakIntnFn: func(n int) int { + calls++ + require.Equal(t, 2, n) + return 0 + }, + } + + acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil) + require.NoError(t, err) + require.NotNil(t, acc) + require.Equal(t, int64(2), acc.ID, "tie-break should allow later equal-priority candidate to be selected") + require.Equal(t, 1, calls, "tie-break should be invoked exactly once for two equal candidates") +} + func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) { groupID := int64(1) lastUsed := time.Now().Add(-1 * time.Hour) @@ -1189,6 +1297,53 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) { } } +func TestOpenAIBuildUpstreamRequestSetsHTTPUpstreamProfile(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + req, err := 
svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{}`), "token", false, "", false) + require.NoError(t, err) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context())) +} + +func TestOpenAIBuildUpstreamPassthroughRequestSetsHTTPUpstreamProfile(t *testing.T) { + gin.SetMode(gin.TestMode) + cfg := &config.Config{ + Security: config.SecurityConfig{ + URLAllowlist: config.URLAllowlistConfig{Enabled: false}, + }, + } + svc := &OpenAIGatewayService{cfg: cfg} + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`)) + c.Request.Header.Set("Content-Type", "application/json") + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + } + + req, err := svc.buildUpstreamRequestOpenAIPassthrough(context.Background(), c, account, []byte(`{}`), "token") + require.NoError(t, err) + require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context())) +} + func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) { cfg := &config.Config{ Security: config.SecurityConfig{ @@ -1248,6 +1403,21 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) { } } +func TestOpenAIValidateUpstreamBaseURLNilConfigDefaultsToStrictHTTPS(t *testing.T) { + svc := &OpenAIGatewayService{} + + if _, err := svc.validateUpstreamBaseURL("http://example.com"); err == nil { + t.Fatalf("expected http to be rejected when config is nil") + } + normalized, err := svc.validateUpstreamBaseURL("https://example.com") + if err != nil { + t.Fatalf("expected https to pass when config is nil, got %v", err) + } + if normalized != "https://example.com" { + t.Fatalf("expected normalized https url, got %q", normalized) + } +} + // ==================== P1-08 修复:model 替换性能优化测试 ==================== func TestReplaceModelInSSELine(t *testing.T) { diff --git 
a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 72f4bbb09..1e085923e 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -7,6 +7,7 @@ import ( "io" "log/slog" "net/http" + "net/url" "regexp" "sort" "strconv" @@ -14,7 +15,6 @@ import ( "time" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" ) @@ -273,13 +273,7 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi req.Header.Set("Referer", "https://sora.chatgpt.com/") req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") - client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 120 * time.Second, - }) - if err != nil { - return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err) - } + client := newOpenAIOAuthHTTPClient(proxyURL) resp, err := client.Do(req) if err != nil { return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err) @@ -471,7 +465,7 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A } var proxyURL string - if account.ProxyID != nil { + if account.ProxyID != nil && s.proxyRepo != nil { proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID) if err == nil && proxy != nil { proxyURL = proxy.URL() @@ -536,6 +530,19 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64 return proxy.URL(), nil } +func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { + transport := &http.Transport{} + if strings.TrimSpace(proxyURL) != "" { + if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" { + transport.Proxy = http.ProxyURL(parsed) + } + } + return &http.Client{ + Timeout: 120 * 
time.Second, + Transport: transport, + } +} + func normalizeOpenAIOAuthPlatform(platform string) string { switch strings.ToLower(strings.TrimSpace(platform)) { case PlatformSora: diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go index 292523288..c03352c1f 100644 --- a/backend/internal/service/openai_oauth_service_state_test.go +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -34,6 +34,29 @@ func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Contex return s.RefreshToken(ctx, refreshToken, proxyURL) } +type openaiOAuthClientRefreshStub struct { + refreshCalled int32 + lastClientID string +} + +func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "") +} + +func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + atomic.AddInt32(&s.refreshCalled, 1) + s.lastClientID = clientID + return &openai.TokenResponse{ + AccessToken: "new-at", + RefreshToken: "new-rt", + ExpiresIn: 3600, + }, nil +} + func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { client := &openaiOAuthClientStateStub{} svc := NewOpenAIOAuthService(nil, client) @@ -55,6 +78,30 @@ func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) { require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled)) } +func TestOpenAIOAuthService_RefreshAccountToken_NilProxyRepoWithProxyID(t *testing.T) { + client := &openaiOAuthClientRefreshStub{} + svc := NewOpenAIOAuthService(nil, client) + 
defer svc.Stop() + + proxyID := int64(123) + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + ProxyID: &proxyID, + Credentials: map[string]any{ + "refresh_token": "rt", + "client_id": "cid", + }, + } + + tokenInfo, err := svc.RefreshAccountToken(context.Background(), account) + require.NoError(t, err) + require.NotNil(t, tokenInfo) + require.Equal(t, "new-at", tokenInfo.AccessToken) + require.Equal(t, int32(1), atomic.LoadInt32(&client.refreshCalled)) + require.Equal(t, "cid", client.lastClientID) +} + func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) { client := &openaiOAuthClientStateStub{} svc := NewOpenAIOAuthService(nil, client) diff --git a/backend/internal/service/openai_sse_zero_alloc_test.go b/backend/internal/service/openai_sse_zero_alloc_test.go new file mode 100644 index 000000000..fe853d59a --- /dev/null +++ b/backend/internal/service/openai_sse_zero_alloc_test.go @@ -0,0 +1,276 @@ +package service + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/require" +) + +// --- 包级常量验证 --- + +func TestSSEPackageLevelConstants(t *testing.T) { + require.Equal(t, []byte("[DONE]"), sseDataDone) + require.Equal(t, []byte(`"response.completed"`), sseResponseCompletedMark) +} + +func TestSSEDataDone_UsedInBytesEqual(t *testing.T) { + require.True(t, bytes.Equal([]byte("[DONE]"), sseDataDone)) + require.False(t, bytes.Equal([]byte("[done]"), sseDataDone)) + require.False(t, bytes.Equal([]byte(""), sseDataDone)) +} + +func TestSSEResponseCompletedMark_UsedInBytesContains(t *testing.T) { + data := []byte(`{"type":"response.completed","response":{"usage":{}}}`) + require.True(t, bytes.Contains(data, sseResponseCompletedMark)) + + unrelated := []byte(`{"type":"response.in_progress"}`) + require.False(t, bytes.Contains(unrelated, sseResponseCompletedMark)) +} + +// --- parseSSEUsageString 测试 --- + +func TestParseSSEUsageString_CompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := 
&OpenAIUsage{} + + data := `{"type":"response.completed","response":{"usage":{"input_tokens":100,"output_tokens":50,"input_tokens_details":{"cached_tokens":20}}}}` + svc.parseSSEUsageString(data, usage) + + require.Equal(t, 100, usage.InputTokens) + require.Equal(t, 50, usage.OutputTokens) + require.Equal(t, 20, usage.CacheReadInputTokens) +} + +func TestParseSSEUsageString_NonCompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 99, OutputTokens: 88} + + data := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}}}` + svc.parseSSEUsageString(data, usage) + + // 非 completed 事件不应修改 usage + require.Equal(t, 99, usage.InputTokens) + require.Equal(t, 88, usage.OutputTokens) +} + +func TestParseSSEUsageString_DoneEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 10} + + svc.parseSSEUsageString("[DONE]", usage) + require.Equal(t, 10, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_EmptyString(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 5} + + svc.parseSSEUsageString("", usage) + require.Equal(t, 5, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_NilUsage(t *testing.T) { + svc := &OpenAIGatewayService{} + + // 不应 panic + require.NotPanics(t, func() { + svc.parseSSEUsageString(`{"type":"response.completed"}`, nil) + }) +} + +func TestParseSSEUsageString_ShortData(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 7} + + // 短于 80 字节的数据直接跳过 + svc.parseSSEUsageString(`{"type":"response.completed"}`, usage) + require.Equal(t, 7, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_ContainsCompletedButWrongType(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 42} + + // 包含 "response.completed" 子串但 type 字段不匹配 + data := 
`{"type":"response.in_progress","description":"not response.completed at all","padding":"aaaaaaaaaaaaaaaaaaaaaaaaaaaa"}` + svc.parseSSEUsageString(data, usage) + require.Equal(t, 42, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageString_ZeroUsageValues(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 99} + + data := `{"type":"response.completed","response":{"usage":{"input_tokens":0,"output_tokens":0,"input_tokens_details":{"cached_tokens":0}}}}` + svc.parseSSEUsageString(data, usage) + + require.Equal(t, 0, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) + require.Equal(t, 0, usage.CacheReadInputTokens) +} + +func TestParseSSEUsageString_MissingUsageFields(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + // response.usage 存在但缺少某些子字段 + data := `{"type":"response.completed","response":{"usage":{"input_tokens":10},"padding":"aaaaaaaaaaaaaaaaaaa"}}` + svc.parseSSEUsageString(data, usage) + + require.Equal(t, 10, usage.InputTokens) + require.Equal(t, 0, usage.OutputTokens) + require.Equal(t, 0, usage.CacheReadInputTokens) +} + +// --- parseSSEUsageBytes 与包级常量集成测试 --- + +func TestParseSSEUsageBytes_CompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":200,"output_tokens":80,"input_tokens_details":{"cached_tokens":30}}}}`) + svc.parseSSEUsageBytes(data, usage) + + require.Equal(t, 200, usage.InputTokens) + require.Equal(t, 80, usage.OutputTokens) + require.Equal(t, 30, usage.CacheReadInputTokens) +} + +func TestParseSSEUsageBytes_DoneEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 10} + + svc.parseSSEUsageBytes([]byte("[DONE]"), usage) + require.Equal(t, 10, usage.InputTokens) // 不应修改 +} + +func TestParseSSEUsageBytes_EmptyData(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 5} + + 
svc.parseSSEUsageBytes(nil, usage) + require.Equal(t, 5, usage.InputTokens) + + svc.parseSSEUsageBytes([]byte{}, usage) + require.Equal(t, 5, usage.InputTokens) +} + +func TestParseSSEUsageBytes_NilUsage(t *testing.T) { + svc := &OpenAIGatewayService{} + + require.NotPanics(t, func() { + svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), nil) + }) +} + +func TestParseSSEUsageBytes_ShortData(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 7} + + svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), usage) + require.Equal(t, 7, usage.InputTokens) +} + +func TestParseSSEUsageBytes_NonCompletedEvent(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 99} + + data := []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxxx"}`) + svc.parseSSEUsageBytes(data, usage) + + require.Equal(t, 99, usage.InputTokens) +} + +func TestParseSSEUsageBytes_GetManyBytesExtraction(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + // 验证 GetManyBytes 一次提取 3 个字段的正确性 + data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":111,"output_tokens":222,"input_tokens_details":{"cached_tokens":333}}}}`) + svc.parseSSEUsageBytes(data, usage) + + require.Equal(t, 111, usage.InputTokens) + require.Equal(t, 222, usage.OutputTokens) + require.Equal(t, 333, usage.CacheReadInputTokens) +} + +// --- parseSSEUsage wrapper 测试 --- + +func TestParseSSEUsage_DelegatesToString(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{} + + // 验证 parseSSEUsage 最终正确提取 usage + data := `{"type":"response.completed","response":{"usage":{"input_tokens":55,"output_tokens":66,"input_tokens_details":{"cached_tokens":77}}}}` + svc.parseSSEUsage(data, usage) + + require.Equal(t, 55, usage.InputTokens) + require.Equal(t, 66, usage.OutputTokens) + require.Equal(t, 77, 
usage.CacheReadInputTokens) +} + +func TestParseSSEUsage_DoneNotParsed(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 123} + + svc.parseSSEUsage("[DONE]", usage) + require.Equal(t, 123, usage.InputTokens) +} + +func TestParseSSEUsage_EmptyNotParsed(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 456} + + svc.parseSSEUsage("", usage) + require.Equal(t, 456, usage.InputTokens) +} + +// --- string 和 bytes 一致性测试 --- + +func TestParseSSEUsage_StringAndBytesConsistency(t *testing.T) { + svc := &OpenAIGatewayService{} + + completedData := `{"type":"response.completed","response":{"usage":{"input_tokens":300,"output_tokens":150,"input_tokens_details":{"cached_tokens":50}}}}` + + usageStr := &OpenAIUsage{} + svc.parseSSEUsageString(completedData, usageStr) + + usageBytes := &OpenAIUsage{} + svc.parseSSEUsageBytes([]byte(completedData), usageBytes) + + require.Equal(t, usageStr.InputTokens, usageBytes.InputTokens) + require.Equal(t, usageStr.OutputTokens, usageBytes.OutputTokens) + require.Equal(t, usageStr.CacheReadInputTokens, usageBytes.CacheReadInputTokens) +} + +func TestParseSSEUsage_StringAndBytesConsistency_NonCompleted(t *testing.T) { + svc := &OpenAIGatewayService{} + + inProgressData := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxx"}` + + usageStr := &OpenAIUsage{InputTokens: 10} + svc.parseSSEUsageString(inProgressData, usageStr) + + usageBytes := &OpenAIUsage{InputTokens: 10} + svc.parseSSEUsageBytes([]byte(inProgressData), usageBytes) + + // 两者都不应修改 + require.Equal(t, 10, usageStr.InputTokens) + require.Equal(t, 10, usageBytes.InputTokens) +} + +func TestParseSSEUsage_StringAndBytesConsistency_LargeTokenCounts(t *testing.T) { + svc := &OpenAIGatewayService{} + + data := 
`{"type":"response.completed","response":{"usage":{"input_tokens":1000000,"output_tokens":500000,"input_tokens_details":{"cached_tokens":200000}}}}` + + usageStr := &OpenAIUsage{} + svc.parseSSEUsageString(data, usageStr) + + usageBytes := &OpenAIUsage{} + svc.parseSSEUsageBytes([]byte(data), usageBytes) + + require.Equal(t, 1000000, usageStr.InputTokens) + require.Equal(t, usageStr, usageBytes) +} diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go index e897debc2..0f576b664 100644 --- a/backend/internal/service/openai_sticky_compat.go +++ b/backend/internal/service/openai_sticky_compat.go @@ -35,12 +35,31 @@ func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash return "", "" } - currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized)) - sum := sha256.Sum256([]byte(normalized)) - legacyHash = hex.EncodeToString(sum[:]) + currentHash = deriveOpenAISessionHash(normalized) + legacyHash = deriveOpenAILegacySessionHash(normalized) return currentHash, legacyHash } +// deriveOpenAISessionHash returns the fast xxhash-based session hash. +func deriveOpenAISessionHash(sessionID string) string { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "" + } + return fmt.Sprintf("%016x", xxhash.Sum64String(normalized)) +} + +// deriveOpenAILegacySessionHash returns the SHA-256 legacy hash. +// Only call this when legacy fallback or dual-write is enabled. 
+func deriveOpenAILegacySessionHash(sessionID string) string { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "" + } + sum := sha256.Sum256([]byte(normalized)) + return hex.EncodeToString(sum[:]) +} + func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context { if ctx == nil { return nil diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index a8a6b96c5..d12a57638 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -151,9 +151,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } // 从数据库获取最新账户信息 - fresh, err := p.accountRepo.GetByID(ctx, account.ID) - if err == nil && fresh != nil { - account = fresh + if p.accountRepo != nil { + fresh, err := p.accountRepo.GetByID(ctx, account.ID) + if err == nil && fresh != nil { + account = fresh + } } expiresAt = account.GetCredentialAsTime("expires_at") if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { @@ -181,8 +183,10 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + if p.accountRepo != nil { + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + } } expiresAt = account.GetCredentialAsTime("expires_at") } @@ -233,8 +237,10 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } account.Credentials = newCredentials - if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { - slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + 
if p.accountRepo != nil { + if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil { + slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr) + } } expiresAt = account.GetCredentialAsTime("expires_at") } diff --git a/backend/internal/service/openai_token_provider_nil_repo_test.go b/backend/internal/service/openai_token_provider_nil_repo_test.go new file mode 100644 index 000000000..f2e7e91e8 --- /dev/null +++ b/backend/internal/service/openai_token_provider_nil_repo_test.go @@ -0,0 +1,101 @@ +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openAITokenCacheNilRepoStub struct { + token string + lockCalls atomic.Int32 + setCalls atomic.Int32 + getCalls atomic.Int32 + lockEnabled bool +} + +func (s *openAITokenCacheNilRepoStub) GetAccessToken(context.Context, string) (string, error) { + s.getCalls.Add(1) + return s.token, nil +} + +func (s *openAITokenCacheNilRepoStub) SetAccessToken(context.Context, string, string, time.Duration) error { + s.setCalls.Add(1) + return nil +} + +func (s *openAITokenCacheNilRepoStub) DeleteAccessToken(context.Context, string) error { + return nil +} + +func (s *openAITokenCacheNilRepoStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) { + s.lockCalls.Add(1) + return s.lockEnabled, nil +} + +func (s *openAITokenCacheNilRepoStub) ReleaseRefreshLock(context.Context, string) error { + return nil +} + +type openAIOAuthClientNilRepoStub struct{} + +func (s *openAIOAuthClientNilRepoStub) ExchangeCode( + context.Context, + string, + string, + string, + string, + string, +) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openAIOAuthClientNilRepoStub) RefreshToken( + context.Context, + string, + string, +) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") 
+} + +func (s *openAIOAuthClientNilRepoStub) RefreshTokenWithClientID( + context.Context, + string, + string, + string, +) (*openai.TokenResponse, error) { + return &openai.TokenResponse{ + AccessToken: "fresh-token", + RefreshToken: "fresh-refresh-token", + ExpiresIn: 3600, + }, nil +} + +func TestOpenAITokenProviderRefreshWithNilAccountRepo(t *testing.T) { + cache := &openAITokenCacheNilRepoStub{lockEnabled: true} + oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientNilRepoStub{}) + provider := NewOpenAITokenProvider(nil, cache, oauthSvc) + + expiresAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) + account := &Account{ + ID: 3001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "stale-token", + "refresh_token": "refresh-token", + "expires_at": expiresAt, + }, + } + + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "fresh-token", token) + require.Equal(t, int32(1), cache.lockCalls.Load()) + require.Equal(t, int32(1), cache.setCalls.Load()) +} diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go index 9f3c47b7b..f46eda8a5 100644 --- a/backend/internal/service/openai_ws_client.go +++ b/backend/internal/service/openai_ws_client.go @@ -11,6 +11,7 @@ import ( "sync/atomic" "time" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" coderws "github.com/coder/websocket" "github.com/coder/websocket/wsjson" ) @@ -77,8 +78,9 @@ func (d *coderOpenAIWSClientDialer) Dial( } opts := &coderws.DialOptions{ - HTTPHeader: cloneHeader(headers), - CompressionMode: coderws.CompressionContextTakeover, + HTTPHeader: cloneHeader(headers), + // 高频长连接场景优先降低内存/CPU 抖动,避免 context takeover 带来的状态累积。 + CompressionMode: coderws.CompressionNoContextTakeover, } if proxy := strings.TrimSpace(proxyURL); proxy != "" { proxyClient, err := d.proxyHTTPClient(proxy) @@ -230,6 +232,9 @@ func (d 
*coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransport } } +// 编译期断言:coderOpenAIWSClientConn 必须实现 openai_ws_v2.FrameConn 接口(passthrough 路径依赖)。 +var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil) + type coderOpenAIWSClientConn struct { conn *coderws.Conn } @@ -264,6 +269,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro } } +func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return coderws.MessageText, nil, err + } + return msgType, payload, nil +} + +func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Write(ctx, msgType, payload) +} + func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { if c == nil || c.conn == nil { return errOpenAIWSConnClosed diff --git a/backend/internal/service/openai_ws_client_preempt_test.go b/backend/internal/service/openai_ws_client_preempt_test.go new file mode 100644 index 000000000..cbbc149e9 --- /dev/null +++ b/backend/internal/service/openai_ws_client_preempt_test.go @@ -0,0 +1,1073 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "testing" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// errOpenAIWSClientPreempted 哨兵错误基础测试 +// --------------------------------------------------------------------------- + +func TestErrOpenAIWSClientPreempted_NotNil(t *testing.T) { + t.Parallel() + 
require.NotNil(t, errOpenAIWSClientPreempted) + require.Contains(t, errOpenAIWSClientPreempted.Error(), "client preempted") +} + +func TestErrOpenAIWSClientPreempted_ErrorsIs(t *testing.T) { + t.Parallel() + + // 直接匹配 + require.True(t, errors.Is(errOpenAIWSClientPreempted, errOpenAIWSClientPreempted)) + + // 包裹后仍可匹配 + wrapped := fmt.Errorf("outer: %w", errOpenAIWSClientPreempted) + require.True(t, errors.Is(wrapped, errOpenAIWSClientPreempted)) + + // 不同错误不匹配 + require.False(t, errors.Is(errors.New("other"), errOpenAIWSClientPreempted)) +} + +func TestErrOpenAIWSClientPreempted_WrapInTurnError(t *testing.T) { + t.Parallel() + + // 用 wrapOpenAIWSIngressTurnErrorWithPartial 包裹后 errors.Is 仍能识别 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + require.Error(t, turnErr) + require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted)) +} + +func TestErrOpenAIWSClientPreempted_WrapInTurnError_WithPartialResult(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_preempt_partial", + Usage: OpenAIUsage{ + InputTokens: 100, + OutputTokens: 50, + }, + } + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + partial, + ) + require.Error(t, turnErr) + require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted)) + + // 验证 partial result 可提取 + got, ok := OpenAIWSIngressTurnPartialResult(turnErr) + require.True(t, ok) + require.NotNil(t, got) + require.Equal(t, partial.RequestID, got.RequestID) + require.Equal(t, partial.Usage.InputTokens, got.Usage.InputTokens) +} + +// --------------------------------------------------------------------------- +// classifyOpenAIWSIngressTurnAbortReason 对 client_preempted 的识别测试 +// --------------------------------------------------------------------------- + +func TestClassifyAbortReason_ClientPreempted_Direct(t *testing.T) { + t.Parallel() + + // 直接哨兵错误 + reason, expected 
:= classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError(t *testing.T) { + t.Parallel() + + // 包裹在 turnError 中 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError_WroteDownstream(t *testing.T) { + t.Parallel() + + // 包裹在 turnError 中,wroteDownstream=true + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + &OpenAIForwardResult{RequestID: "resp_partial"}, + ) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_DoubleWrapped(t *testing.T) { + t.Parallel() + + // 多层 fmt.Errorf 包裹 + inner := fmt.Errorf("relay failed: %w", errOpenAIWSClientPreempted) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(inner) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) +} + +func TestClassifyAbortReason_ClientPreempted_NotConfusedWithOther(t *testing.T) { + t.Parallel() + + // 确保其他错误不会被误分类为 client_preempted + others := []error{ + errors.New("client preempted"), // 文本相同但不是同一哨兵 + context.Canceled, // context 取消 + io.EOF, // 客户端断连 + errors.New("random error"), // 随机错误 + } + + for _, err := range others { + reason, _ := classifyOpenAIWSIngressTurnAbortReason(err) + require.NotEqual(t, openAIWSIngressTurnAbortReasonClientPreempted, reason, + "error %q should not classify as client_preempted", err) + } +} + 
+// --------------------------------------------------------------------------- +// openAIWSIngressTurnAbortDispositionForReason 对 ClientPreempted 的处置测试 +// --------------------------------------------------------------------------- + +func TestDisposition_ClientPreempted_IsContinueTurn(t *testing.T) { + t.Parallel() + + disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) +} + +func TestDisposition_ClientPreempted_SameAsPreviousResponse(t *testing.T) { + t.Parallel() + + // client_preempted 与 previous_response_not_found 应有相同的处置 + prevDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonPreviousResponse) + preemptDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted) + require.Equal(t, prevDisp, preemptDisp) +} + +func TestDisposition_AllContinueTurnReasons(t *testing.T) { + t.Parallel() + + // 验证所有应归为 ContinueTurn 的 reason 列表完整且正确 + continueTurnReasons := []openAIWSIngressTurnAbortReason{ + openAIWSIngressTurnAbortReasonPreviousResponse, + openAIWSIngressTurnAbortReasonToolOutput, + openAIWSIngressTurnAbortReasonUpstreamError, + openAIWSIngressTurnAbortReasonClientPreempted, + } + + for _, reason := range continueTurnReasons { + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition, + "reason %q should be ContinueTurn", reason) + } +} + +func TestDisposition_ClientPreempted_NotCloseGracefully(t *testing.T) { + t.Parallel() + + disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted) + require.NotEqual(t, openAIWSIngressTurnAbortDispositionCloseGracefully, disposition) + require.NotEqual(t, openAIWSIngressTurnAbortDispositionFailRequest, disposition) +} + +// 
--------------------------------------------------------------------------- +// 端到端 classify → disposition 链路测试 +// --------------------------------------------------------------------------- + +func TestClientPreempted_ClassifyToDisposition_EndToEnd(t *testing.T) { + t.Parallel() + + // 模拟 sendAndRelay 返回 client_preempted 错误的完整链路 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + + // 1. classify + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + // 2. disposition + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + // 3. wroteDownstream + require.False(t, openAIWSIngressTurnWroteDownstream(turnErr)) +} + +func TestClientPreempted_ClassifyToDisposition_WroteDownstream(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_half", + Usage: OpenAIUsage{ + InputTokens: 200, + OutputTokens: 100, + }, + } + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + partial, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + require.True(t, openAIWSIngressTurnWroteDownstream(turnErr)) + + got, ok := OpenAIWSIngressTurnPartialResult(turnErr) + require.True(t, ok) + require.Equal(t, "resp_half", got.RequestID) +} + +// --------------------------------------------------------------------------- +// ContinueTurn 分支对 client_preempted 的特殊行为验证 +// 
--------------------------------------------------------------------------- + +func TestClientPreempted_ShouldNotSendErrorEvent(t *testing.T) { + t.Parallel() + + // 核心语义:client_preempted 时客户端已发出新请求,不需要旧 turn 的 error 事件。 + // 验证 abortReason 为 client_preempted 时不应产生 error 通知。 + abortReason := openAIWSIngressTurnAbortReasonClientPreempted + + // 模拟 ContinueTurn 分支的判断逻辑 + shouldSendError := abortReason != openAIWSIngressTurnAbortReasonClientPreempted + require.False(t, shouldSendError, "client_preempted 不应发送 error 事件") +} + +func TestClientPreempted_ShouldNotClearLastResponseID(t *testing.T) { + t.Parallel() + + // 核心语义:被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链。 + // 验证 abortReason 为 client_preempted 时不应调用 clearSessionLastResponseID。 + abortReason := openAIWSIngressTurnAbortReasonClientPreempted + + shouldClearLastResponseID := abortReason != openAIWSIngressTurnAbortReasonClientPreempted + require.False(t, shouldClearLastResponseID, + "client_preempted 不应清除 lastResponseID") +} + +func TestNonPreempted_ContinueTurn_ShouldSendErrorAndClearID(t *testing.T) { + t.Parallel() + + // 对照测试:非 client_preempted 的 ContinueTurn reason 应正常发送 error 并清除 ID + otherReasons := []openAIWSIngressTurnAbortReason{ + openAIWSIngressTurnAbortReasonPreviousResponse, + openAIWSIngressTurnAbortReasonToolOutput, + openAIWSIngressTurnAbortReasonUpstreamError, + } + + for _, reason := range otherReasons { + shouldSendError := reason != openAIWSIngressTurnAbortReasonClientPreempted + shouldClearID := reason != openAIWSIngressTurnAbortReasonClientPreempted + require.True(t, shouldSendError, + "reason %q (non-preempted) should send error event", reason) + require.True(t, shouldClearID, + "reason %q (non-preempted) should clear lastResponseID", reason) + } +} + +// --------------------------------------------------------------------------- +// ContinueTurn abort 路径中 client_preempted 的 error 事件格式验证 +// --------------------------------------------------------------------------- + +func 
TestClientPreempted_ErrorEventNotGenerated(t *testing.T) { + t.Parallel() + + // 在实际的 ContinueTurn 分支中,client_preempted 分支根本不会构造 error 事件。 + // 此测试验证如果误走错误路径(防御性),error 事件格式仍然正确。 + abortReason := openAIWSIngressTurnAbortReasonClientPreempted + abortMessage := "turn failed: " + string(abortReason) + + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + + var parsed map[string]any + err := json.Unmarshal(errorEvent, &parsed) + require.NoError(t, err, "hypothetical error event should be valid JSON") + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "client_preempted", errorObj["code"]) + require.Contains(t, errorObj["message"], "client_preempted") +} + +// --------------------------------------------------------------------------- +// openAIWSIngressTurnAbortReason 常量值验证 +// --------------------------------------------------------------------------- + +func TestClientPreempted_ReasonStringValue(t *testing.T) { + t.Parallel() + + require.Equal(t, openAIWSIngressTurnAbortReason("client_preempted"), + openAIWSIngressTurnAbortReasonClientPreempted) +} + +// --------------------------------------------------------------------------- +// classifyOpenAIWSIngressTurnAbortReason 完整 table-driven 测试(含 client_preempted) +// --------------------------------------------------------------------------- + +func TestClassifyAbortReason_AllReasons_IncludeClientPreempted(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantReason openAIWSIngressTurnAbortReason + wantExpected bool + }{ + { + name: "client_preempted_sentinel", + err: errOpenAIWSClientPreempted, + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "client_preempted_wrapped_in_turn_error", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, 
+ nil, + ), + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "client_preempted_wrapped_in_turn_error_wrote_downstream", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + &OpenAIForwardResult{RequestID: "resp_x"}, + ), + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "client_preempted_double_wrapped", + err: fmt.Errorf("relay: %w", errOpenAIWSClientPreempted), + wantReason: openAIWSIngressTurnAbortReasonClientPreempted, + wantExpected: true, + }, + { + name: "previous_response_not_confused_with_preempt", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("not found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonPreviousResponse, + wantExpected: true, + }, + { + name: "tool_output_not_confused_with_preempt", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("tool output not found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonToolOutput, + wantExpected: true, + }, + { + name: "context_canceled_not_preempted", + err: context.Canceled, + wantReason: openAIWSIngressTurnAbortReasonContextCanceled, + wantExpected: true, + }, + { + name: "eof_not_preempted", + err: io.EOF, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + { + name: "ws_normal_closure_not_preempted", + err: coderws.CloseError{Code: coderws.StatusNormalClosure}, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err) + require.Equal(t, tt.wantReason, reason) + require.Equal(t, tt.wantExpected, expected) + }) + } +} + +// --------------------------------------------------------------------------- +// 
classify 优先级测试:client_preempted 在 context.Canceled 之前 +// --------------------------------------------------------------------------- + +func TestClassifyAbortReason_ClientPreempted_PriorityOverContextCanceled(t *testing.T) { + t.Parallel() + + // errOpenAIWSClientPreempted 不会同时匹配 context.Canceled, + // 但若将来有包裹 context.Canceled 的情况,client_preempted 检测应在前。 + reason, _ := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason, + "client_preempted 检测应优先于 context.Canceled") +} + +// --------------------------------------------------------------------------- +// openAIWSIngressTurnAbortDispositionForReason table-driven 测试(含 client_preempted) +// --------------------------------------------------------------------------- + +func TestDisposition_AllReasons_IncludeClientPreempted(t *testing.T) { + t.Parallel() + + tests := []struct { + reason openAIWSIngressTurnAbortReason + wantDisp openAIWSIngressTurnAbortDisposition + }{ + {openAIWSIngressTurnAbortReasonPreviousResponse, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonToolOutput, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonUpstreamError, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonClientPreempted, openAIWSIngressTurnAbortDispositionContinueTurn}, + {openAIWSIngressTurnAbortReasonContextCanceled, openAIWSIngressTurnAbortDispositionCloseGracefully}, + {openAIWSIngressTurnAbortReasonClientClosed, openAIWSIngressTurnAbortDispositionCloseGracefully}, + {openAIWSIngressTurnAbortReasonUnknown, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonContextDeadline, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonWriteUpstream, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonReadUpstream, openAIWSIngressTurnAbortDispositionFailRequest}, + 
{openAIWSIngressTurnAbortReasonWriteClient, openAIWSIngressTurnAbortDispositionFailRequest}, + {openAIWSIngressTurnAbortReasonContinuationUnavailable, openAIWSIngressTurnAbortDispositionFailRequest}, + } + + for _, tt := range tests { + tt := tt + t.Run(string(tt.reason), func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.wantDisp, openAIWSIngressTurnAbortDispositionForReason(tt.reason)) + }) + } +} + +// --------------------------------------------------------------------------- +// isOpenAIWSIngressTurnRetryable 与 client_preempted 的交互 +// --------------------------------------------------------------------------- + +func TestIsRetryable_ClientPreempted_NotRetryable(t *testing.T) { + t.Parallel() + + // client_preempted 有专门的恢复路径(ContinueTurn),不走通用重试 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + require.False(t, isOpenAIWSIngressTurnRetryable(turnErr), + "client_preempted 不应被标记为 retryable") +} + +func TestIsRetryable_ClientPreempted_WroteDownstream(t *testing.T) { + t.Parallel() + + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, + nil, + ) + require.False(t, isOpenAIWSIngressTurnRetryable(turnErr), + "client_preempted wroteDownstream=true 不应被标记为 retryable") +} + +// --------------------------------------------------------------------------- +// sendAndRelay 中 clientMsgCh / clientReadErrCh 行为的单元级测试 +// --------------------------------------------------------------------------- + +func TestClientMsgCh_BufferedOne(t *testing.T) { + t.Parallel() + + // 验证 clientMsgCh(buffered 1) 的语义:goroutine 在 sendAndRelay 返回到 + // advanceToNextClientTurn 的间隙不阻塞 + ch := make(chan []byte, 1) + + // 非阻塞写入 + select { + case ch <- []byte(`{"type":"response.create"}`): + // ok + default: + t.Fatal("buffered(1) channel should not block on first write") + } + + // 第二次写入应阻塞 + select { + case ch <- []byte(`{"type":"response.create"}`): + 
t.Fatal("buffered(1) channel should block on second write") + default: + // expected + } +} + +func TestClientReadErrCh_BufferedOne(t *testing.T) { + t.Parallel() + + ch := make(chan error, 1) + + // 非阻塞写入 + select { + case ch <- io.EOF: + default: + t.Fatal("buffered(1) channel should not block on first write") + } + + // 第二次写入应阻塞 + select { + case ch <- io.EOF: + t.Fatal("buffered(1) channel should block on second write") + default: + // expected + } +} + +func TestClientMsgCh_CloseSignalsClosed(t *testing.T) { + t.Parallel() + + ch := make(chan []byte, 1) + close(ch) + + msg, ok := <-ch + require.False(t, ok, "closed channel should return ok=false") + require.Nil(t, msg) +} + +// --------------------------------------------------------------------------- +// 客户端抢占暂存(nextClientPreemptedPayload)行为测试 +// --------------------------------------------------------------------------- + +func TestPreemptedPayload_ConsumedOnce(t *testing.T) { + t.Parallel() + + // 模拟 advanceToNextClientTurn 中预存消息的消费行为 + var nextPreempted []byte + nextPreempted = []byte(`{"type":"response.create","model":"gpt-5.1"}`) + + // 第一次消费 + require.NotNil(t, nextPreempted) + msg := nextPreempted + nextPreempted = nil + + require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(msg)) + require.Nil(t, nextPreempted, "消费后应置空") +} + +func TestPreemptedPayload_NilFallsBackToChannel(t *testing.T) { + t.Parallel() + + // 模拟 advanceToNextClientTurn 中无预存消息时走 channel + var nextPreempted []byte + clientMsgCh := make(chan []byte, 1) + clientMsgCh <- []byte(`{"type":"response.create","model":"gpt-5.1"}`) + + var nextClientMessage []byte + if nextPreempted != nil { + nextClientMessage = nextPreempted + } else { + msg, ok := <-clientMsgCh + require.True(t, ok) + nextClientMessage = msg + } + + require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(nextClientMessage)) +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 
路径:pumpEventCh 关闭 → goto pumpClosed +// --------------------------------------------------------------------------- + +func TestSelectLoop_PumpClosed_GoToPumpClosed(t *testing.T) { + t.Parallel() + + // 模拟 pumpEventCh 关闭时的行为 + pumpEventCh := make(chan openAIWSUpstreamPumpEvent) + close(pumpEventCh) + + evt, ok := <-pumpEventCh + require.False(t, ok, "closed pumpEventCh should return ok=false") + require.Nil(t, evt.message) + require.Nil(t, evt.err) +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:clientMsgCh 收到消息 → client preempt +// --------------------------------------------------------------------------- + +func TestSelectLoop_ClientPreempt_ReturnsCorrectError(t *testing.T) { + t.Parallel() + + // 模拟 select 中收到客户端抢占消息后生成的 turnError + preemptPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + + // 模拟 sendAndRelay 返回的错误 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, // buildPartialResult 在没有 usage 时返回 nil + ) + + // 验证错误分类 + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + // 验证处置 + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + // 验证预存消息可供 advanceToNextClientTurn 使用 + require.NotEmpty(t, preemptPayload) +} + +func TestSelectLoop_ClientPreempt_WithPartialUsage(t *testing.T) { + t.Parallel() + + // 模拟上游已发送部分 token 后被客户端抢占 + partial := &OpenAIForwardResult{ + RequestID: "resp_interrupted", + Usage: OpenAIUsage{ + InputTokens: 500, + OutputTokens: 200, + }, + } + + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + true, // 已写过下游 + partial, + ) + + require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted)) + 
require.True(t, openAIWSIngressTurnWroteDownstream(turnErr)) + + got, ok := OpenAIWSIngressTurnPartialResult(turnErr) + require.True(t, ok) + require.Equal(t, "resp_interrupted", got.RequestID) + require.Equal(t, 500, got.Usage.InputTokens) + require.Equal(t, 200, got.Usage.OutputTokens) +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:clientMsgCh 关闭 → nil channel +// --------------------------------------------------------------------------- + +func TestSelectLoop_ClientMsgChClosed_NilChannelPreventsReselect(t *testing.T) { + t.Parallel() + + // clientMsgCh 关闭后被设为 nil,后续 select 不应再选中它 + var clientMsgCh chan []byte + clientMsgCh = make(chan []byte, 1) + close(clientMsgCh) + + // 第一次读取:closed + _, ok := <-clientMsgCh + require.False(t, ok) + + // 设为 nil + clientMsgCh = nil + + // nil channel 上的 select 永远不会被选中(不会 panic) + select { + case <-clientMsgCh: + t.Fatal("nil channel should never be selected") + default: + // expected: nil channel 不参与 select + } +} + +// --------------------------------------------------------------------------- +// sendAndRelay select 路径:clientReadErrCh 客户端断连 +// --------------------------------------------------------------------------- + +func TestSelectLoop_ClientReadErr_DisconnectSetsDrain(t *testing.T) { + t.Parallel() + + // 模拟客户端断连读取错误的分类行为 + disconnectErrors := []error{ + io.EOF, + coderws.CloseError{Code: coderws.StatusNormalClosure}, + coderws.CloseError{Code: coderws.StatusGoingAway}, + } + + for _, readErr := range disconnectErrors { + require.True(t, isOpenAIWSClientDisconnectError(readErr), + "error %v should be classified as client disconnect", readErr) + } +} + +func TestSelectLoop_ClientReadErr_NonDisconnect(t *testing.T) { + t.Parallel() + + // 非断连错误不应触发 drain + nonDisconnectErrors := []error{ + errors.New("tls handshake timeout"), + coderws.CloseError{Code: coderws.StatusPolicyViolation}, + } + + for _, readErr := range nonDisconnectErrors { + require.False(t, 
isOpenAIWSClientDisconnectError(readErr), + "error %v should not be classified as client disconnect", readErr) + } +} + +func TestSelectLoop_ClientReadErr_NilChannelsAfterError(t *testing.T) { + t.Parallel() + + // 模拟收到 clientReadErrCh 后将两个 channel 置 nil + clientMsgCh := make(chan []byte, 1) + clientReadErrCh := make(chan error, 1) + + clientReadErrCh <- io.EOF + + // 消费错误 + readErr := <-clientReadErrCh + require.Error(t, readErr) + + // 模拟置空(实际代码中 select case 后的操作) + var nilMsgCh chan []byte + var nilErrCh chan error + nilMsgCh = nil + nilErrCh = nil + + // 验证 nil channel 行为 + _ = clientMsgCh // unused in this test + + select { + case <-nilMsgCh: + t.Fatal("nil channel should never be selected") + case <-nilErrCh: + t.Fatal("nil channel should never be selected") + default: + // expected + } +} + +func TestAdvanceConsumePendingClientReadErr(t *testing.T) { + t.Parallel() + + require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(nil)) + + var pendingErr error + require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr)) + + sourceErr := errors.New("custom read error") + pendingErr = sourceErr + + gotErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingErr) + require.Error(t, gotErr) + require.ErrorIs(t, gotErr, sourceErr) + require.Nil(t, pendingErr, "pending error should be consumed once") + require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr)) +} + +func TestAdvanceClientReadUnavailable(t *testing.T) { + t.Parallel() + + var nilMsgCh chan []byte + var nilErrCh chan error + require.True(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, nilErrCh)) + + msgCh := make(chan []byte, 1) + require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, nilErrCh)) + + errCh := make(chan error, 1) + require.False(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, errCh)) + require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, errCh)) +} + +// --------------------------------------------------------------------------- +// 
advanceToNextClientTurn channel 读取路径测试 +// --------------------------------------------------------------------------- + +func TestAdvance_ClientMsgCh_ClosedReturnsExit(t *testing.T) { + t.Parallel() + + // clientMsgCh 关闭意味着客户端读取 goroutine 已退出,应返回 exit=true + ch := make(chan []byte, 1) + close(ch) + + _, ok := <-ch + require.False(t, ok, "should signal goroutine exit") +} + +func TestAdvance_ClientReadErrCh_DisconnectReturnsExit(t *testing.T) { + t.Parallel() + + // 断连错误应返回 exit=true + ch := make(chan error, 1) + ch <- io.EOF + + readErr := <-ch + require.True(t, isOpenAIWSClientDisconnectError(readErr)) +} + +func TestAdvance_ClientReadErrCh_NonDisconnectReturnsError(t *testing.T) { + t.Parallel() + + // 非断连错误应返回 error + ch := make(chan error, 1) + errCustom := errors.New("custom read error") + ch <- errCustom + + readErr := <-ch + require.False(t, isOpenAIWSClientDisconnectError(readErr)) + require.Equal(t, errCustom, readErr) +} + +// --------------------------------------------------------------------------- +// 持久客户端读取 goroutine 行为测试 +// --------------------------------------------------------------------------- + +func TestPersistentReader_NormalMessage(t *testing.T) { + t.Parallel() + + // 模拟正常消息的推送和消费 + clientMsgCh := make(chan []byte, 1) + + // 模拟 goroutine 写入 + go func() { + clientMsgCh <- []byte(`{"type":"response.create"}`) + }() + + msg := <-clientMsgCh + require.Equal(t, `{"type":"response.create"}`, string(msg)) +} + +func TestPersistentReader_ErrorSendsToErrCh(t *testing.T) { + t.Parallel() + + clientReadErrCh := make(chan error, 1) + + // 模拟 goroutine 发送错误 + go func() { + clientReadErrCh <- io.EOF + }() + + readErr := <-clientReadErrCh + require.Equal(t, io.EOF, readErr) +} + +func TestPersistentReader_ContextCancel(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithCancel(context.Background()) + clientMsgCh := make(chan []byte, 1) + + // 填满 buffer + clientMsgCh <- []byte("first") + + // 模拟 goroutine 尝试写入已满的 channel + done := make(chan 
struct{}) + go func() { + defer close(done) + select { + case clientMsgCh <- []byte("second"): + // 不应到达 + case <-ctx.Done(): + // 正确退出 + return + } + }() + + // 取消 context + cancel() + <-done +} + +func TestPersistentReader_ClosesMsgChOnExit(t *testing.T) { + t.Parallel() + + clientMsgCh := make(chan []byte, 1) + + // 模拟 goroutine 退出时关闭 channel + go func() { + defer close(clientMsgCh) + // 模拟读取错误后退出 + }() + + // 等待 channel 关闭 + _, ok := <-clientMsgCh + require.False(t, ok, "channel should be closed when goroutine exits") +} + +// --------------------------------------------------------------------------- +// client_preempted 与其他 abort reason 的正交性验证 +// --------------------------------------------------------------------------- + +func TestClientPreempted_OrthogonalWithPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + prevErr := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("not found"), + false, + ) + + // client_preempted 不会被误判为 previous_response_not_found + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(preemptErr)) + // previous_response_not_found 不会被误判为 client_preempted + require.False(t, errors.Is(prevErr, errOpenAIWSClientPreempted)) +} + +func TestClientPreempted_OrthogonalWithToolOutputNotFound(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + toolErr := wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("tool output not found"), + false, + ) + + require.False(t, isOpenAIWSIngressToolOutputNotFound(preemptErr)) + require.False(t, errors.Is(toolErr, errOpenAIWSClientPreempted)) +} + +func TestClientPreempted_OrthogonalWithUpstreamError(t *testing.T) { + t.Parallel() + + preemptErr := 
wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + upstreamErr := wrapOpenAIWSIngressTurnError( + "upstream_error_event", + errors.New("upstream error"), + false, + ) + + require.False(t, isOpenAIWSIngressUpstreamErrorEvent(preemptErr)) + require.False(t, errors.Is(upstreamErr, errOpenAIWSClientPreempted)) +} + +func TestClientPreempted_OrthogonalWithContinuationUnavailable(t *testing.T) { + t.Parallel() + + preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + require.False(t, isOpenAIWSContinuationUnavailableCloseError(preemptErr)) +} + +func TestClientPreempted_NotClientDisconnect(t *testing.T) { + t.Parallel() + + require.False(t, isOpenAIWSClientDisconnectError(errOpenAIWSClientPreempted), + "client_preempted should not be classified as client disconnect") +} + +// --------------------------------------------------------------------------- +// recordOpenAIWSTurnAbort 指标兼容性测试 +// --------------------------------------------------------------------------- + +func TestClientPreempted_RecordAbortArgs(t *testing.T) { + t.Parallel() + + // 验证 classify 返回的 (reason, expected) 值与 recordOpenAIWSTurnAbort 兼容 + turnErr := wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", + errOpenAIWSClientPreempted, + false, + nil, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr) + require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason) + require.True(t, expected) + + // expected=true 表示这是预期行为,不应触发告警 + assert.True(t, expected, "client_preempted 应标记为 expected,不触发告警") +} + +// --------------------------------------------------------------------------- +// shouldFlushOpenAIWSBufferedEventsOnError 与 client_preempted 场景 +// --------------------------------------------------------------------------- + +func TestShouldFlushBufferedEvents_ClientPreempted(t *testing.T) { + t.Parallel() + + // 
client_preempted 场景下 clientDisconnected=false(客户端仍在), + // 是否 flush 取决于 reqStream 和 wroteDownstream + tests := []struct { + name string + reqStream bool + wroteDownstream bool + wantFlush bool + }{ + {"stream_wrote", true, true, true}, + {"stream_not_wrote", true, false, false}, + {"not_stream_wrote", false, true, false}, + {"not_stream_not_wrote", false, false, false}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldFlushOpenAIWSBufferedEventsOnError(tt.reqStream, tt.wroteDownstream, false) + require.Equal(t, tt.wantFlush, got) + }) + } +} diff --git a/backend/internal/service/openai_ws_client_test.go b/backend/internal/service/openai_ws_client_test.go index a88d62665..fd8d8c76f 100644 --- a/backend/internal/service/openai_ws_client_test.go +++ b/backend/internal/service/openai_ws_client_test.go @@ -1,11 +1,15 @@ package service import ( + "context" "fmt" "net/http" + "net/http/httptest" + "strings" "testing" "time" + coderws "github.com/coder/websocket" "github.com/stretchr/testify/require" ) @@ -110,3 +114,146 @@ func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing. 
require.NotNil(t, transport) require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) } + +func TestCoderOpenAIWSClientDialer_Dial_EmptyURL(t *testing.T) { + t.Parallel() + + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + conn, status, headers, err := impl.Dial(context.Background(), " ", nil, "") + require.Error(t, err) + require.Nil(t, conn) + require.Equal(t, 0, status) + require.Nil(t, headers) +} + +func TestCoderOpenAIWSClientDialer_Dial_InvalidProxyURL(t *testing.T) { + t.Parallel() + + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + conn, status, headers, err := impl.Dial(context.Background(), "ws://example.com", nil, "://bad-proxy") + require.Error(t, err) + require.Nil(t, conn) + require.Equal(t, 0, status) + require.Nil(t, headers) +} + +func TestCoderOpenAIWSClientConn_NilGuards(t *testing.T) { + t.Parallel() + + var nilConn *coderOpenAIWSClientConn + require.ErrorIs(t, nilConn.WriteJSON(context.Background(), map[string]any{"a": 1}), errOpenAIWSConnClosed) + _, err := nilConn.ReadMessage(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, _, err = nilConn.ReadFrame(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, nilConn.WriteFrame(context.Background(), coderws.MessageText, []byte("x")), errOpenAIWSConnClosed) + require.ErrorIs(t, nilConn.Ping(context.Background()), errOpenAIWSConnClosed) + require.NoError(t, nilConn.Close()) + + empty := &coderOpenAIWSClientConn{} + var nilCtx context.Context + require.ErrorIs(t, empty.WriteJSON(nilCtx, map[string]any{"a": 1}), errOpenAIWSConnClosed) + _, err = empty.ReadMessage(nilCtx) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, _, err = empty.ReadFrame(nilCtx) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, empty.WriteFrame(nilCtx, coderws.MessageText, []byte("x")), 
errOpenAIWSConnClosed) + require.ErrorIs(t, empty.Ping(nilCtx), errOpenAIWSConnClosed) + require.NoError(t, empty.Close()) +} + +func TestCoderOpenAIWSClientDialer_DialAndConnWrappers(t *testing.T) { + t.Parallel() + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + _, _, err = conn.Read(readCtx) + cancelRead() + if err != nil { + serverErrCh <- err + return + } + + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + if err = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.output_text.delta","delta":"ok"}`)); err != nil { + cancelWrite() + serverErrCh <- err + return + } + if err = conn.Write(writeCtx, coderws.MessageBinary, []byte{0x01, 0x02, 0x03}); err != nil { + cancelWrite() + serverErrCh <- err + return + } + cancelWrite() + + readCtx2, cancelRead2 := context.WithTimeout(r.Context(), 3*time.Second) + _, payload, err := conn.Read(readCtx2) + cancelRead2() + if err != nil { + serverErrCh <- err + return + } + if len(payload) == 0 { + serverErrCh <- fmt.Errorf("expected client payload") + return + } + serverErrCh <- nil + })) + defer wsServer.Close() + + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + conn, status, headers, err := impl.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + http.Header{"User-Agent": []string{"unit-test-agent/1.0"}}, + "", + ) + cancelDial() + require.NoError(t, err) + require.NotNil(t, conn) + require.Equal(t, 0, status) + _ = headers // 成功建连时状态码为 0;headers 仅用于握手失败诊断。 + + ctx, cancel := context.WithTimeout(context.Background(), 
3*time.Second) + defer cancel() + require.NoError(t, conn.WriteJSON(ctx, map[string]any{"type": "response.create"})) + + msg, err := conn.ReadMessage(ctx) + require.NoError(t, err) + require.JSONEq(t, `{"type":"response.output_text.delta","delta":"ok"}`, string(msg)) + + typedConn, ok := conn.(*coderOpenAIWSClientConn) + require.True(t, ok) + msgType, payload, err := typedConn.ReadFrame(ctx) + require.NoError(t, err) + require.Equal(t, coderws.MessageBinary, msgType) + require.Equal(t, []byte{0x01, 0x02, 0x03}, payload) + + require.NoError(t, typedConn.WriteFrame(ctx, coderws.MessageText, []byte(`{"client":"ack"}`))) + pingCtx, cancelPing := context.WithTimeout(context.Background(), 100*time.Millisecond) + _ = typedConn.Ping(pingCtx) + cancelPing() + require.NoError(t, typedConn.Close()) + require.NoError(t, <-serverErrCh) +} diff --git a/backend/internal/service/openai_ws_common.go b/backend/internal/service/openai_ws_common.go new file mode 100644 index 000000000..d3ce904f9 --- /dev/null +++ b/backend/internal/service/openai_ws_common.go @@ -0,0 +1,54 @@ +package service + +import ( + "errors" + "fmt" + "net/http" + "time" +) + +var ( + errOpenAIWSConnClosed = errors.New("openai ws connection closed") + errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full") + errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") +) + +const ( + openAIWSConnHealthCheckTO = 2 * time.Second +) + +type openAIWSDialError struct { + StatusCode int + ResponseHeaders http.Header + Err error +} + +func (e *openAIWSDialError) Error() string { + if e == nil { + return "" + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err) + } + return fmt.Sprintf("openai ws dial failed: %v", e.Err) +} + +func (e *openAIWSDialError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func cloneHeader(h http.Header) http.Header { + if h == nil { + return nil + } + cloned 
:= make(http.Header, len(h)) + for k, values := range h { + copied := make([]string, len(values)) + copy(copied, values) + cloned[k] = copied + } + return cloned +} diff --git a/backend/internal/service/openai_ws_common_test.go b/backend/internal/service/openai_ws_common_test.go new file mode 100644 index 000000000..211883cb1 --- /dev/null +++ b/backend/internal/service/openai_ws_common_test.go @@ -0,0 +1,44 @@ +package service + +import ( + "errors" + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSDialError_Behavior(t *testing.T) { + t.Parallel() + + var nilErr *openAIWSDialError + require.Equal(t, "", nilErr.Error()) + require.Nil(t, nilErr.Unwrap()) + + err := &openAIWSDialError{ + StatusCode: 429, + Err: errors.New("too many requests"), + } + require.Contains(t, err.Error(), "status=429") + require.ErrorIs(t, err.Unwrap(), err.Err) +} + +func TestCloneHeader_DeepCopy(t *testing.T) { + t.Parallel() + + require.Nil(t, cloneHeader(nil)) + + origin := http.Header{ + "X-Request-Id": []string{"req-1"}, + "Set-Cookie": []string{"a=1", "b=2"}, + } + cloned := cloneHeader(origin) + require.Equal(t, origin.Get("X-Request-Id"), cloned.Get("X-Request-Id")) + require.Equal(t, origin.Values("Set-Cookie"), cloned.Values("Set-Cookie")) + + // 修改拷贝不应污染原 header。 + cloned.Set("X-Request-Id", "req-2") + cloned["Set-Cookie"][0] = "a=9" + require.Equal(t, "req-1", origin.Get("X-Request-Id")) + require.Equal(t, "a=1", origin.Values("Set-Cookie")[0]) +} diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go index ce06f6a21..f840034e2 100644 --- a/backend/internal/service/openai_ws_fallback_test.go +++ b/backend/internal/service/openai_ws_fallback_test.go @@ -3,12 +3,15 @@ package service import ( "context" "errors" + "io" "net/http" + "net/http/httptest" "testing" "time" "github.com/Wei-Shaw/sub2api/internal/config" coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" 
"github.com/stretchr/testify/require" ) @@ -108,6 +111,154 @@ func TestClassifyOpenAIWSErrorEvent(t *testing.T) { reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`)) require.Equal(t, "previous_response_not_found", reason) require.True(t, recoverable) + + // tool_output_not_found: 用户按 ESC 取消 function_call 后重新发送消息 + reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool output found for function call call_zXKPiNecBmIAoKeW9o2pNMvo.","param":"input"}}`)) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEventFromRaw("", "invalid_request_error", "No tool output found for function call call_abc123.") + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEventFromRaw( + "", + "invalid_request_error", + "No tool call found for function call output with call_id call_abc123.", + ) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + // reasoning orphaned items should reuse tool_output_not_found recovery path. 
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw( + "", + "invalid_request_error", + "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.", + ) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEventFromRaw( + "", + "invalid_request_error", + "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.", + ) + require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason) + require.True(t, recoverable) +} + +func TestClassifyOpenAIWSErrorEventFromRaw_AllBranches(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + codeRaw string + errTypeRaw string + msgRaw string + wantReason string + wantRecover bool + }{ + { + name: "code_upgrade_required", + codeRaw: "upgrade_required", + wantReason: "upgrade_required", + wantRecover: true, + }, + { + name: "code_ws_unsupported", + codeRaw: "websocket_not_supported", + wantReason: "ws_unsupported", + wantRecover: true, + }, + { + name: "code_ws_connection_limit", + codeRaw: "websocket_connection_limit_reached", + wantReason: "ws_connection_limit_reached", + wantRecover: true, + }, + { + name: "msg_upgrade_required", + msgRaw: "status 426 upgrade required", + wantReason: "upgrade_required", + wantRecover: true, + }, + { + name: "err_type_upgrade", + errTypeRaw: "gateway_upgrade_error", + wantReason: "upgrade_required", + wantRecover: true, + }, + { + name: "msg_ws_unsupported", + msgRaw: "websocket is unsupported in this region", + wantReason: "ws_unsupported", + wantRecover: true, + }, + { + name: "msg_ws_connection_limit", + msgRaw: "websocket connection limit exceeded", + wantReason: "ws_connection_limit_reached", + wantRecover: true, + }, + { + name: "msg_previous_response_not_found_variant", + msgRaw: "previous response is not found", + wantReason: "previous_response_not_found", + wantRecover: true, + }, + { + name: "msg_no_tool_output", + msgRaw: 
"No tool output found for function call call_abc.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "msg_no_tool_call_for_function_call_output", + msgRaw: "No tool call found for function call output with call_id call_abc.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "msg_reasoning_missing_following", + msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "msg_reasoning_missing_preceding", + msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.", + wantReason: openAIWSIngressStageToolOutputNotFound, + wantRecover: true, + }, + { + name: "server_error_by_type", + errTypeRaw: "server_error", + wantReason: "upstream_error_event", + wantRecover: true, + }, + { + name: "server_error_by_code", + codeRaw: "server_error", + wantReason: "upstream_error_event", + wantRecover: true, + }, + { + name: "unknown_event_error", + codeRaw: "other", + errTypeRaw: "other", + msgRaw: "other", + wantReason: "event_error", + wantRecover: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reason, recoverable := classifyOpenAIWSErrorEventFromRaw(tt.codeRaw, tt.errTypeRaw, tt.msgRaw) + require.Equal(t, tt.wantReason, reason) + require.Equal(t, tt.wantRecover, recoverable) + }) + } } func TestClassifyOpenAIWSReconnectReason(t *testing.T) { @@ -197,12 +348,39 @@ func TestOpenAIWSRetryTotalBudget(t *testing.T) { require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget()) } +func TestOpenAIWSRetryContextError(t *testing.T) { + require.NoError(t, openAIWSRetryContextError(context.TODO())) + require.NoError(t, openAIWSRetryContextError(context.Background())) + + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + err := openAIWSRetryContextError(canceledCtx) + 
require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + var fallbackErr *openAIWSFallbackError + require.ErrorAs(t, err, &fallbackErr) + require.Equal(t, "retry_context_canceled", fallbackErr.Reason) +} + func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) { + require.Equal(t, "service_restart", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusServiceRestart})) + require.Equal(t, "try_again_later", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusTryAgainLater})) require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation})) require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig})) require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io"))) } +func TestClassifyOpenAIWSIngressReadErrorClass(t *testing.T) { + require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(nil)) + require.Equal(t, "context_canceled", classifyOpenAIWSIngressReadErrorClass(context.Canceled)) + require.Equal(t, "deadline_exceeded", classifyOpenAIWSIngressReadErrorClass(context.DeadlineExceeded)) + require.Equal(t, "service_restart", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusServiceRestart})) + require.Equal(t, "try_again_later", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusTryAgainLater})) + require.Equal(t, "upstream_closed", classifyOpenAIWSIngressReadErrorClass(io.EOF)) + require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(errors.New("tls handshake timeout"))) +} + func TestOpenAIWSStoreDisabledConnMode(t *testing.T) { svc := &OpenAIGatewayService{cfg: &config.Config{}} svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true @@ -239,6 +417,119 @@ func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) { require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal) } +func 
TestWriteOpenAIWSV1UnsupportedResponse_TracksOps(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + svc := &OpenAIGatewayService{} + account := &Account{ + ID: 42, + Name: "acc-ws-v1", + Platform: PlatformOpenAI, + } + + err := svc.writeOpenAIWSV1UnsupportedResponse(c, account) + require.Error(t, err) + require.Contains(t, err.Error(), "openai ws v1 is temporarily unsupported") + + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "invalid_request_error") + require.Contains(t, rec.Body.String(), "temporarily unsupported") + + rawStatus, ok := c.Get(OpsUpstreamStatusCodeKey) + require.True(t, ok) + require.Equal(t, http.StatusBadRequest, rawStatus) + + rawMsg, ok := c.Get(OpsUpstreamErrorMessageKey) + require.True(t, ok) + require.Equal(t, "openai ws v1 is temporarily unsupported; use ws v2", rawMsg) + + rawEvents, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := rawEvents.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, account.ID, events[0].AccountID) + require.Equal(t, account.Platform, events[0].Platform) + require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode) + require.Equal(t, "ws_error", events[0].Kind) +} + +func TestIsOpenAIWSStreamWriteDisconnectError(t *testing.T) { + require.False(t, isOpenAIWSStreamWriteDisconnectError(nil, nil)) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("writer failed"), ctx)) + + require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("broken pipe"), context.Background())) + require.True(t, isOpenAIWSStreamWriteDisconnectError(io.EOF, context.Background())) + + require.False(t, isOpenAIWSStreamWriteDisconnectError(errors.New("template execute failed"), 
context.Background())) +} + +func TestShouldFlushOpenAIWSBufferedEventsOnError(t *testing.T) { + require.True(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, false)) + require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, false, false)) + require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, true)) + require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(false, true, false)) +} + +func TestCloneOpenAIWSJSONRawString(t *testing.T) { + require.Nil(t, cloneOpenAIWSJSONRawString("")) + require.Nil(t, cloneOpenAIWSJSONRawString(" ")) + + raw := `{"id":"resp_1","type":"response"}` + cloned := cloneOpenAIWSJSONRawString(raw) + require.Equal(t, raw, string(cloned)) + require.Equal(t, len(raw), len(cloned)) +} + +func TestOpenAIWSAbortMetricsSnapshot(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true) + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true) + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonWriteUpstream, false) + svc.recordOpenAIWSTurnAbortRecovered() + + snapshot := svc.SnapshotOpenAIWSAbortMetrics() + require.Equal(t, int64(1), snapshot.TurnAbortRecoveredTotal) + + getTotal := func(reason string, expected bool) int64 { + for _, point := range snapshot.TurnAbortTotal { + if point.Reason == reason && point.Expected == expected { + return point.Total + } + } + return 0 + } + require.Equal(t, int64(2), getTotal(string(openAIWSIngressTurnAbortReasonUpstreamError), true)) + require.Equal(t, int64(1), getTotal(string(openAIWSIngressTurnAbortReasonWriteUpstream), false)) +} + +func TestOpenAIWSPerformanceMetricsSnapshot_ContainsAbortMetrics(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonClientClosed, true) + svc.recordOpenAIWSTurnAbortRecovered() + + snapshot := svc.SnapshotOpenAIWSPerformanceMetrics() + require.Equal(t, int64(1), 
snapshot.Abort.TurnAbortRecoveredTotal) + require.Equal(t, int64(0), snapshot.Passthrough.SemanticMutationTotal) + require.GreaterOrEqual(t, snapshot.Passthrough.UsageParseFailureTotal, int64(0)) + + found := false + for _, point := range snapshot.Abort.TurnAbortTotal { + if point.Reason == string(openAIWSIngressTurnAbortReasonClientClosed) && point.Expected { + require.Equal(t, int64(1), point.Total) + found = true + break + } + } + require.True(t, found) +} + func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) { svc := &OpenAIGatewayService{cfg: &config.Config{}} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go index 74ba472f4..dd91055ac 100644 --- a/backend/internal/service/openai_ws_forwarder.go +++ b/backend/internal/service/openai_ws_forwarder.go @@ -6,17 +6,18 @@ import ( "encoding/json" "errors" "fmt" - "io" "math/rand" - "net" "net/http" "net/url" + "runtime/debug" "sort" + "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" @@ -26,8 +27,10 @@ import ( ) const ( - openAIWSBetaV1Value = "responses_websockets=2026-02-04" - openAIWSBetaV2Value = "responses_websockets=2026-02-06" + openAIWSBetaV1Value = "responses_websockets=2026-02-04" + openAIWSBetaV2Value = "responses_websockets=2026-02-06" + openAIWSConnIDPrefixLegacy = "oa_ws_" + openAIWSConnIDPrefixCtx = "ctxws_" openAIWSTurnStateHeader = "x-codex-turn-state" openAIWSTurnMetadataHeader = "x-codex-turn-metadata" @@ -55,879 +58,100 @@ const ( openAIWSStoreDisabledConnModeOff = "off" openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" + openAIWSIngressStageToolOutputNotFound = "tool_output_not_found" openAIWSMaxPrevResponseIDDeletePasses = 8 + 
openAIWSIngressReplayInputMaxBytes = 512 * 1024 + openAIWSContinuationUnavailableReason = "upstream continuation connection is unavailable; please restart the conversation" + openAIWSAutoAbortedToolOutputValue = `{"error":"tool call aborted by gateway"}` + openAIWSClientReadIdleTimeoutDefault = 30 * time.Minute + openAIWSPassthroughIdleTimeoutDefault = time.Hour + openAIWSIngressClientDisconnectDrainTimeout = 5 * time.Second + openAIWSUpstreamPumpInfoMinAlive = 100 * time.Millisecond ) -var openAIWSLogValueReplacer = strings.NewReplacer( - "error", "err", - "fallback", "fb", - "warning", "warnx", - "failed", "fail", -) - -var openAIWSIngressPreflightPingIdle = 20 * time.Second - -// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 -type openAIWSFallbackError struct { - Reason string - Err error -} - -func (e *openAIWSFallbackError) Error() string { - if e == nil { - return "" - } - if e.Err == nil { - return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) - } - return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) -} - -func (e *openAIWSFallbackError) Unwrap() error { - if e == nil { - return nil - } - return e.Err -} - -func wrapOpenAIWSFallback(reason string, err error) error { - return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} -} - -// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 -type OpenAIWSClientCloseError struct { - statusCode coderws.StatusCode - reason string - err error -} - -type openAIWSIngressTurnError struct { - stage string - cause error - wroteDownstream bool -} - -func (e *openAIWSIngressTurnError) Error() string { - if e == nil { - return "" - } - if e.cause == nil { - return strings.TrimSpace(e.stage) - } - return e.cause.Error() -} - -func (e *openAIWSIngressTurnError) Unwrap() error { - if e == nil { - return nil - } - return e.cause -} - -func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { - if cause == 
nil { - return nil - } - return &openAIWSIngressTurnError{ - stage: strings.TrimSpace(stage), - cause: cause, - wroteDownstream: wroteDownstream, - } -} - -func isOpenAIWSIngressTurnRetryable(err error) bool { - var turnErr *openAIWSIngressTurnError - if !errors.As(err, &turnErr) || turnErr == nil { - return false - } - if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { - return false - } - if turnErr.wroteDownstream { - return false - } - switch turnErr.stage { - case "write_upstream", "read_upstream": - return true - default: - return false - } -} - -func openAIWSIngressTurnRetryReason(err error) string { - var turnErr *openAIWSIngressTurnError - if !errors.As(err, &turnErr) || turnErr == nil { - return "unknown" - } - if turnErr.stage == "" { - return "unknown" - } - return turnErr.stage -} - -func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { - var turnErr *openAIWSIngressTurnError - if !errors.As(err, &turnErr) || turnErr == nil { - return false - } - if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { - return false - } - return !turnErr.wroteDownstream -} - -// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 -func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { - return &OpenAIWSClientCloseError{ - statusCode: statusCode, - reason: strings.TrimSpace(reason), - err: err, - } -} - -func (e *OpenAIWSClientCloseError) Error() string { - if e == nil { - return "" - } - if e.err == nil { - return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) - } - return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) -} - -func (e *OpenAIWSClientCloseError) Unwrap() error { - if e == nil { - return nil - } - return e.err -} - -func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { - if e == nil { - return 
coderws.StatusInternalError - } - return e.statusCode -} - -func (e *OpenAIWSClientCloseError) Reason() string { - if e == nil { - return "" - } - return strings.TrimSpace(e.reason) -} - -// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 -type OpenAIWSIngressHooks struct { - BeforeTurn func(turn int) error - AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) -} - -func normalizeOpenAIWSLogValue(value string) string { - trimmed := strings.TrimSpace(value) - if trimmed == "" { - return "-" - } - return openAIWSLogValueReplacer.Replace(trimmed) -} - -func truncateOpenAIWSLogValue(value string, maxLen int) string { - normalized := normalizeOpenAIWSLogValue(value) - if normalized == "-" || maxLen <= 0 { - return normalized - } - if len(normalized) <= maxLen { - return normalized - } - return normalized[:maxLen] + "..." -} - -func openAIWSHeaderValueForLog(headers http.Header, key string) string { - if headers == nil { - return "-" - } - return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) -} - -func hasOpenAIWSHeader(headers http.Header, key string) bool { - if headers == nil { - return false - } - return strings.TrimSpace(headers.Get(key)) != "" -} - -type openAIWSSessionHeaderResolution struct { - SessionID string - ConversationID string - SessionSource string - ConversationSource string -} - -func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { - resolution := openAIWSSessionHeaderResolution{ - SessionSource: "none", - ConversationSource: "none", - } - if c != nil && c.Request != nil { - if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { - resolution.SessionID = sessionID - resolution.SessionSource = "header_session_id" - } - if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { - resolution.ConversationID = conversationID - resolution.ConversationSource = "header_conversation_id" - if 
resolution.SessionID == "" { - resolution.SessionID = conversationID - resolution.SessionSource = "header_conversation_id" - } - } - } - - cacheKey := strings.TrimSpace(promptCacheKey) - if cacheKey != "" { - if resolution.SessionID == "" { - resolution.SessionID = cacheKey - resolution.SessionSource = "prompt_cache_key" - } - } - return resolution -} - -func shouldLogOpenAIWSEvent(idx int, eventType string) bool { - if idx <= openAIWSEventLogHeadLimit { - return true - } - if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { - return true - } - if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { - return true - } - return false -} - -func shouldLogOpenAIWSBufferedEvent(idx int) bool { - if idx <= openAIWSBufferLogHeadLimit { - return true - } - if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { - return true - } - return false -} - -func openAIWSEventMayContainModel(eventType string) bool { - switch eventType { - case "response.created", - "response.in_progress", - "response.completed", - "response.done", - "response.failed", - "response.incomplete", - "response.cancelled", - "response.canceled": - return true - default: - trimmed := strings.TrimSpace(eventType) - if trimmed == eventType { - return false - } - switch trimmed { - case "response.created", - "response.in_progress", - "response.completed", - "response.done", - "response.failed", - "response.incomplete", - "response.cancelled", - "response.canceled": - return true - default: - return false - } - } -} - -func openAIWSEventMayContainToolCalls(eventType string) bool { - eventType = strings.TrimSpace(eventType) - if eventType == "" { - return false - } - if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { - return true - } - switch eventType { - case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": - return true - default: - return false - } -} - -func 
openAIWSEventShouldParseUsage(eventType string) bool { - return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" -} - -func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { - if len(message) == 0 { - return "", "", gjson.Result{} - } - values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") - eventType = strings.TrimSpace(values[0].String()) - if id := strings.TrimSpace(values[1].String()); id != "" { - responseID = id - } else { - responseID = strings.TrimSpace(values[2].String()) - } - return eventType, responseID, values[3] -} - -func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { - if len(message) == 0 { - return false - } - return bytes.Contains(message, []byte(`"tool_calls"`)) || - bytes.Contains(message, []byte(`"tool_call"`)) || - bytes.Contains(message, []byte(`"function_call"`)) -} - -func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { - if usage == nil || len(message) == 0 { - return - } - values := gjson.GetManyBytes( - message, - "response.usage.input_tokens", - "response.usage.output_tokens", - "response.usage.input_tokens_details.cached_tokens", - ) - usage.InputTokens = int(values[0].Int()) - usage.OutputTokens = int(values[1].Int()) - usage.CacheReadInputTokens = int(values[2].Int()) -} - -func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { - if len(message) == 0 { - return "", "", "" - } - values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") - return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) -} - -func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { - code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) - errType = 
truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) - errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) - return code, errType, errMessage -} - -func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { - if len(message) == 0 { - return "-", "-", "-" - } - return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) -} - -func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { - if len(payload) == 0 { - return "-" - } - type keySize struct { - Key string - Size int - } - sizes := make([]keySize, 0, len(payload)) - for key, value := range payload { - size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) - sizes = append(sizes, keySize{Key: key, Size: size}) - } - sort.Slice(sizes, func(i, j int) bool { - if sizes[i].Size == sizes[j].Size { - return sizes[i].Key < sizes[j].Key - } - return sizes[i].Size > sizes[j].Size - }) - - if topN <= 0 || topN > len(sizes) { - topN = len(sizes) - } - parts := make([]string, 0, topN) - for idx := 0; idx < topN; idx++ { - item := sizes[idx] - parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) - } - return strings.Join(parts, ",") -} - -func estimateOpenAIWSPayloadValueSize(value any, depth int) int { - if depth <= 0 { - return -1 - } - switch v := value.(type) { - case nil: - return 0 - case string: - return len(v) - case []byte: - return len(v) - case bool: - return 1 - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: - return 8 - case float32, float64: - return 8 - case map[string]any: - if len(v) == 0 { - return 2 - } - total := 2 - count := 0 - for key, item := range v { - count++ - if count > openAIWSPayloadSizeEstimateMaxItems { - return -1 - } - itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) - if itemSize < 0 { - return -1 - } - total += len(key) + itemSize + 3 - if total > openAIWSPayloadSizeEstimateMaxBytes { 
- return -1 - } - } - return total - case []any: - if len(v) == 0 { - return 2 - } - total := 2 - limit := len(v) - if limit > openAIWSPayloadSizeEstimateMaxItems { - return -1 - } - for i := 0; i < limit; i++ { - itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) - if itemSize < 0 { - return -1 - } - total += itemSize + 1 - if total > openAIWSPayloadSizeEstimateMaxBytes { - return -1 - } - } - return total - default: - raw, err := json.Marshal(v) - if err != nil { - return -1 - } - if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { - return -1 - } - return len(raw) - } -} - -func openAIWSPayloadString(payload map[string]any, key string) string { - if len(payload) == 0 { - return "" - } - raw, ok := payload[key] - if !ok { - return "" - } - switch v := raw.(type) { - case nil: - return "" - case string: - return strings.TrimSpace(v) - case []byte: - return strings.TrimSpace(string(v)) - default: - return "" - } -} - -func openAIWSPayloadStringFromRaw(payload []byte, key string) string { - if len(payload) == 0 || strings.TrimSpace(key) == "" { - return "" - } - return strings.TrimSpace(gjson.GetBytes(payload, key).String()) -} - -func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { - if len(payload) == 0 || strings.TrimSpace(key) == "" { - return defaultValue - } - value := gjson.GetBytes(payload, key) - if !value.Exists() { - return defaultValue - } - if value.Type != gjson.True && value.Type != gjson.False { - return defaultValue - } - return value.Bool() -} - -func openAIWSSessionHashesFromID(sessionID string) (string, string) { - return deriveOpenAISessionHashes(sessionID) -} - -func extractOpenAIWSImageURL(value any) string { - switch v := value.(type) { - case string: - return strings.TrimSpace(v) - case map[string]any: - if raw, ok := v["url"].(string); ok { - return strings.TrimSpace(raw) - } - } - return "" -} - -func summarizeOpenAIWSInput(input any) string { - items, ok := input.([]any) - if !ok || len(items) == 
0 { - return "-" - } - - itemCount := len(items) - textChars := 0 - imageDataURLs := 0 - imageDataURLChars := 0 - imageRemoteURLs := 0 - - handleContentItem := func(contentItem map[string]any) { - contentType, _ := contentItem["type"].(string) - switch strings.TrimSpace(contentType) { - case "input_text", "output_text", "text": - if text, ok := contentItem["text"].(string); ok { - textChars += len(text) - } - case "input_image": - imageURL := extractOpenAIWSImageURL(contentItem["image_url"]) - if imageURL == "" { - return - } - if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { - imageDataURLs++ - imageDataURLChars += len(imageURL) - return - } - imageRemoteURLs++ - } - } - - handleInputItem := func(inputItem map[string]any) { - if content, ok := inputItem["content"].([]any); ok { - for _, rawContent := range content { - contentItem, ok := rawContent.(map[string]any) - if !ok { - continue - } - handleContentItem(contentItem) - } - return - } - - itemType, _ := inputItem["type"].(string) - switch strings.TrimSpace(itemType) { - case "input_text", "output_text", "text": - if text, ok := inputItem["text"].(string); ok { - textChars += len(text) - } - case "input_image": - imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) - if imageURL == "" { - return - } - if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { - imageDataURLs++ - imageDataURLChars += len(imageURL) - return - } - imageRemoteURLs++ - } - } - - for _, rawItem := range items { - inputItem, ok := rawItem.(map[string]any) - if !ok { - continue - } - handleInputItem(inputItem) - } - - return fmt.Sprintf( - "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", - itemCount, - textChars, - imageDataURLs, - imageDataURLChars, - imageRemoteURLs, - ) -} - -func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { - if len(payload) == 0 || strings.TrimSpace(key) == "" { - return - } - if _, exists := payload[key]; 
!exists { - return - } - delete(payload, key) - *removed = append(*removed, key) -} - -// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, -// 避免重试成功却改变原始请求语义。 -// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 -func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { - if len(payload) == 0 { - return "empty", nil - } - if attempt <= 1 { - return "full", nil - } - - removed := make([]string, 0, 2) - if attempt >= 2 { - dropOpenAIWSPayloadKey(payload, "include", &removed) - } - - if len(removed) == 0 { - return "full", nil - } - sort.Strings(removed) - return "trim_optional_fields", removed -} - -func logOpenAIWSModeInfo(format string, args ...any) { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) -} - -func isOpenAIWSModeDebugEnabled() bool { - return logger.L().Core().Enabled(zap.DebugLevel) -} - -func logOpenAIWSModeDebug(format string, args ...any) { - if !isOpenAIWSModeDebugEnabled() { - return - } - logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) 
-} - -func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { - if err == nil { - return - } - logger.L().Warn( - "openai.ws_bind_response_account_failed", - zap.Int64("group_id", groupID), - zap.Int64("account_id", accountID), - zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), - zap.Error(err), - ) -} - -func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { - if err == nil { - return "-", "-" - } - statusCode := coderws.CloseStatus(err) - if statusCode == -1 { - return "-", "-" - } - closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) - closeReason := "-" - var closeErr coderws.CloseError - if errors.As(err, &closeErr) { - reasonText := strings.TrimSpace(closeErr.Reason) - if reasonText != "" { - closeReason = normalizeOpenAIWSLogValue(reasonText) - } - } - return normalizeOpenAIWSLogValue(closeStatus), closeReason -} +type openAIWSIngressTurnAbortReason string -func unwrapOpenAIWSDialBaseError(err error) error { - if err == nil { - return nil - } - var dialErr *openAIWSDialError - if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { - return dialErr.Err - } - return err -} +const ( + openAIWSIngressTurnAbortReasonUnknown openAIWSIngressTurnAbortReason = "unknown" + + openAIWSIngressTurnAbortReasonClientClosed openAIWSIngressTurnAbortReason = "client_closed" + openAIWSIngressTurnAbortReasonContextCanceled openAIWSIngressTurnAbortReason = "ctx_canceled" + openAIWSIngressTurnAbortReasonContextDeadline openAIWSIngressTurnAbortReason = "ctx_deadline_exceeded" + openAIWSIngressTurnAbortReasonPreviousResponse openAIWSIngressTurnAbortReason = openAIWSIngressStagePreviousResponseNotFound + openAIWSIngressTurnAbortReasonToolOutput openAIWSIngressTurnAbortReason = openAIWSIngressStageToolOutputNotFound + openAIWSIngressTurnAbortReasonUpstreamError openAIWSIngressTurnAbortReason = "upstream_error_event" + 
openAIWSIngressTurnAbortReasonWriteUpstream openAIWSIngressTurnAbortReason = "write_upstream" + openAIWSIngressTurnAbortReasonReadUpstream openAIWSIngressTurnAbortReason = "read_upstream" + openAIWSIngressTurnAbortReasonWriteClient openAIWSIngressTurnAbortReason = "write_client" + openAIWSIngressTurnAbortReasonContinuationUnavailable openAIWSIngressTurnAbortReason = "continuation_unavailable" + openAIWSIngressTurnAbortReasonClientPreempted openAIWSIngressTurnAbortReason = "client_preempted" + openAIWSIngressTurnAbortReasonUpstreamRestart openAIWSIngressTurnAbortReason = "upstream_restart" +) -func openAIWSDialRespHeaderForLog(err error, key string) string { - var dialErr *openAIWSDialError - if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { - return "-" - } - return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) -} +type openAIWSIngressTurnAbortDisposition string -func classifyOpenAIWSDialError(err error) string { - if err == nil { - return "-" - } - baseErr := unwrapOpenAIWSDialBaseError(err) - if baseErr == nil { - return "-" - } - if errors.Is(baseErr, context.DeadlineExceeded) { - return "ctx_deadline_exceeded" - } - if errors.Is(baseErr, context.Canceled) { - return "ctx_canceled" - } - var netErr net.Error - if errors.As(baseErr, &netErr) && netErr.Timeout() { - return "net_timeout" - } - if status := coderws.CloseStatus(baseErr); status != -1 { - return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) - } - message := strings.ToLower(strings.TrimSpace(baseErr.Error())) - switch { - case strings.Contains(message, "handshake not finished"): - return "handshake_not_finished" - case strings.Contains(message, "bad handshake"): - return "bad_handshake" - case strings.Contains(message, "connection refused"): - return "connection_refused" - case strings.Contains(message, "no such host"): - return "dns_not_found" - case strings.Contains(message, "tls"): - return "tls_error" - 
case strings.Contains(message, "i/o timeout"): - return "io_timeout" - case strings.Contains(message, "context deadline exceeded"): - return "ctx_deadline_exceeded" - default: - return "dial_error" - } -} +const ( + openAIWSIngressTurnAbortDispositionFailRequest openAIWSIngressTurnAbortDisposition = "fail_request" + openAIWSIngressTurnAbortDispositionContinueTurn openAIWSIngressTurnAbortDisposition = "continue_turn" + openAIWSIngressTurnAbortDispositionCloseGracefully openAIWSIngressTurnAbortDisposition = "close_gracefully" +) -func summarizeOpenAIWSDialError(err error) ( - statusCode int, - dialClass string, - closeStatus string, - closeReason string, - respServer string, - respVia string, - respCFRay string, - respRequestID string, -) { - dialClass = "-" - closeStatus = "-" - closeReason = "-" - respServer = "-" - respVia = "-" - respCFRay = "-" - respRequestID = "-" - if err == nil { - return - } - var dialErr *openAIWSDialError - if errors.As(err, &dialErr) && dialErr != nil { - statusCode = dialErr.StatusCode - respServer = openAIWSDialRespHeaderForLog(err, "server") - respVia = openAIWSDialRespHeaderForLog(err, "via") - respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") - respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") - } - dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) - closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) - return +// openAIWSUpstreamPumpEvent 是上游事件读取泵传递给主 goroutine 的消息载体。 +type openAIWSUpstreamPumpEvent struct { + message []byte + err error } -func isOpenAIWSClientDisconnectError(err error) bool { - if err == nil { - return false - } - if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { - return true - } - switch coderws.CloseStatus(err) { - case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: - return true - } - message := 
strings.ToLower(strings.TrimSpace(err.Error())) - if message == "" { - return false - } - return strings.Contains(message, "failed to read frame header: eof") || - strings.Contains(message, "unexpected eof") || - strings.Contains(message, "use of closed network connection") || - strings.Contains(message, "connection reset by peer") || - strings.Contains(message, "broken pipe") -} +const ( + // openAIWSUpstreamPumpBufferSize 是上游事件读取泵的缓冲 channel 大小。 + // 缓冲允许上游读取和客户端写入并发执行,吸收客户端写入延迟波动。 + openAIWSUpstreamPumpBufferSize = 16 +) -func classifyOpenAIWSReadFallbackReason(err error) string { - if err == nil { - return "read_event" - } - switch coderws.CloseStatus(err) { - case coderws.StatusPolicyViolation: - return "policy_violation" - case coderws.StatusMessageTooBig: - return "message_too_big" - default: - return "read_event" - } -} +var openAIWSLogValueReplacer = strings.NewReplacer( + "error", "err", + "fallback", "fb", + "warning", "warnx", + "failed", "fail", +) -func sortedKeys(m map[string]any) []string { - if len(m) == 0 { - return nil - } - keys := make([]string, 0, len(m)) - for k := range m { - keys = append(keys, k) - } - sort.Strings(keys) - return keys -} +var openAIWSIngressPreflightPingIdle = 20 * time.Second -func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { +func (s *OpenAIGatewayService) getOpenAIWSIngressContextPool() *openAIWSIngressContextPool { if s == nil { return nil } - s.openaiWSPoolOnce.Do(func() { - if s.openaiWSPool == nil { - s.openaiWSPool = newOpenAIWSConnPool(s.cfg) + s.openaiWSIngressCtxOnce.Do(func() { + if s.openaiWSIngressCtxPool == nil { + pool := newOpenAIWSIngressContextPool(s.cfg) + // Ensure the scheduler (and its runtime stats) are initialized + // before wiring load-aware signals into the context pool. 
+ _ = s.getOpenAIAccountScheduler() + pool.schedulerStats = s.openaiAccountStats + s.openaiWSIngressCtxPool = pool } }) - return s.openaiWSPool -} - -func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { - pool := s.getOpenAIWSConnPool() - if pool == nil { - return OpenAIWSPoolMetricsSnapshot{} - } - return pool.SnapshotMetrics() + return s.openaiWSIngressCtxPool } type OpenAIWSPerformanceMetricsSnapshot struct { - Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` - Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` - Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` + Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` + Abort OpenAIWSAbortMetricsSnapshot `json:"abort"` + Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` + Passthrough openaiwsv2.MetricsSnapshot `json:"passthrough"` } func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { - pool := s.getOpenAIWSConnPool() + ingressPool := s.getOpenAIWSIngressContextPool() snapshot := OpenAIWSPerformanceMetricsSnapshot{ - Retry: s.SnapshotOpenAIWSRetryMetrics(), + Retry: s.SnapshotOpenAIWSRetryMetrics(), + Abort: s.SnapshotOpenAIWSAbortMetrics(), + Passthrough: openaiwsv2.SnapshotMetrics(), } - if pool == nil { + if ingressPool == nil { return snapshot } - snapshot.Pool = pool.SnapshotMetrics() - snapshot.Transport = pool.SnapshotTransportMetrics() + snapshot.Transport = ingressPool.SnapshotTransportMetrics() return snapshot } @@ -943,6 +167,18 @@ func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { return s.openaiWSStateStore } +func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer { + if s == nil { + return nil + } + s.openaiWSPassthroughDialerOnce.Do(func() { + if s.openaiWSPassthroughDialer == nil { + s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer() + } + }) + return s.openaiWSPassthroughDialer +} + func (s *OpenAIGatewayService) 
openAIWSResponseStickyTTL() time.Duration { if s != nil && s.cfg != nil { seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds @@ -967,6 +203,20 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { return 15 * time.Minute } +func (s *OpenAIGatewayService) openAIWSClientReadIdleTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds) * time.Second + } + return openAIWSClientReadIdleTimeoutDefault +} + +func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds) * time.Second + } + return openAIWSPassthroughIdleTimeoutDefault +} + func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second @@ -1147,8 +397,14 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders( if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { headers.Set("user-agent", codexCLIUserAgent) } - if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { + userAgentIsCodexCLI := openai.IsCodexCLIRequest(headers.Get("user-agent")) + if account != nil && account.Type == AccountTypeOAuth && !userAgentIsCodexCLI { headers.Set("user-agent", codexCLIUserAgent) + userAgentIsCodexCLI = true + } + if account != nil && account.Type == AccountTypeOAuth && userAgentIsCodexCLI { + // 保持 OAuth 握手头的一致性:Codex 风格 UA 必须搭配 codex_cli_rs originator。 + headers.Set("originator", "codex_cli_rs") } return headers, sessionResolution @@ -1202,448 +458,85 @@ func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { } } -func (s 
*OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { - if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { - return true - } - if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { - return true - } - return false -} - -func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { - if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { - return true - } - if len(reqBody) == 0 { - return false - } - rawStore, ok := reqBody["store"] - if !ok { - return false - } - storeEnabled, ok := rawStore.(bool) - if !ok { - return false - } - return !storeEnabled -} - -func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { - if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { - return true - } - if len(reqBody) == 0 { - return false - } - storeValue := gjson.GetBytes(reqBody, "store") - if !storeValue.Exists() { - return false - } - if storeValue.Type != gjson.True && storeValue.Type != gjson.False { - return false - } - return !storeValue.Bool() -} - -func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { - if s == nil || s.cfg == nil { - return openAIWSStoreDisabledConnModeStrict - } - mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) - switch mode { - case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: - return mode - case "": - // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 - if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { - return openAIWSStoreDisabledConnModeStrict - } - return openAIWSStoreDisabledConnModeOff - default: - return openAIWSStoreDisabledConnModeStrict - } -} - -func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { - switch mode { - case openAIWSStoreDisabledConnModeOff: - 
return false - case openAIWSStoreDisabledConnModeAdaptive: - reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") - switch reason { - case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": - return true - default: - return false - } - default: - return true - } -} - -func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { - return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) -} - -func dropPreviousResponseIDFromRawPayloadWithDeleteFn( - payload []byte, - deleteFn func([]byte, string) ([]byte, error), -) ([]byte, bool, error) { - if len(payload) == 0 { - return payload, false, nil - } - if !gjson.GetBytes(payload, "previous_response_id").Exists() { - return payload, false, nil - } - if deleteFn == nil { - deleteFn = sjson.DeleteBytes - } - - updated := payload - for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && - gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { - next, err := deleteFn(updated, "previous_response_id") - if err != nil { - return payload, false, err - } - updated = next - } - return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil -} - -func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { - normalizedPrevID := strings.TrimSpace(previousResponseID) - if len(payload) == 0 || normalizedPrevID == "" { - return payload, nil - } - updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) - if err == nil { - return updated, nil - } - - var reqBody map[string]any - if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { - return nil, err - } - reqBody["previous_response_id"] = normalizedPrevID - rebuilt, marshalErr := json.Marshal(reqBody) - if marshalErr != nil { - return nil, marshalErr - } - return rebuilt, nil -} - -func shouldInferIngressFunctionCallOutputPreviousResponseID( - storeDisabled bool, - turn int, - 
hasFunctionCallOutput bool, - currentPreviousResponseID string, - expectedPreviousResponseID string, -) bool { - if !storeDisabled || turn <= 1 || !hasFunctionCallOutput { - return false - } - if strings.TrimSpace(currentPreviousResponseID) != "" { - return false - } - return strings.TrimSpace(expectedPreviousResponseID) != "" -} - -func alignStoreDisabledPreviousResponseID( - payload []byte, - expectedPreviousResponseID string, -) ([]byte, bool, error) { - if len(payload) == 0 { - return payload, false, nil - } - expected := strings.TrimSpace(expectedPreviousResponseID) - if expected == "" { - return payload, false, nil - } - current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") - if current == "" || current == expected { - return payload, false, nil - } - - withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) - if dropErr != nil { - return payload, false, dropErr - } - if !removed { - return payload, false, nil - } - updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) - if setErr != nil { - return payload, false, setErr - } - return updated, true, nil -} - -func cloneOpenAIWSPayloadBytes(payload []byte) []byte { - if len(payload) == 0 { - return nil - } - cloned := make([]byte, len(payload)) - copy(cloned, payload) - return cloned -} - -func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { - if items == nil { - return nil - } - cloned := make([]json.RawMessage, 0, len(items)) - for idx := range items { - cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) - } - return cloned -} - -func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { - trimmed := bytes.TrimSpace(raw) - if len(trimmed) == 0 { - return nil, errors.New("json is empty") - } - var decoded any - if err := json.Unmarshal(trimmed, &decoded); err != nil { - return nil, err - } - return json.Marshal(decoded) -} - -func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { - 
normalized, err := normalizeOpenAIWSJSONForCompare(raw) - if err != nil { - return bytes.TrimSpace(raw) - } - return normalized -} - -func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) { - if len(payload) == 0 { - return nil, errors.New("payload is empty") - } - var decoded map[string]any - if err := json.Unmarshal(payload, &decoded); err != nil { - return nil, err - } - delete(decoded, "input") - delete(decoded, "previous_response_id") - return json.Marshal(decoded) -} - -func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { - if len(payload) == 0 { - return nil, false, nil - } - inputValue := gjson.GetBytes(payload, "input") - if !inputValue.Exists() { - return nil, false, nil - } - if inputValue.Type == gjson.JSON { - raw := strings.TrimSpace(inputValue.Raw) - if strings.HasPrefix(raw, "[") { - var items []json.RawMessage - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return nil, true, err - } - return items, true, nil - } - return []json.RawMessage{json.RawMessage(raw)}, true, nil - } - if inputValue.Type == gjson.String { - encoded, _ := json.Marshal(inputValue.String()) - return []json.RawMessage{encoded}, true, nil - } - return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil -} - -func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { - previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) - if prevErr != nil { - return false, prevErr - } - currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) - if currentErr != nil { - return false, currentErr - } - if !previousExists && !currentExists { - return true, nil - } - if !previousExists { - return len(currentItems) == 0, nil - } - if !currentExists { - return len(previousItems) == 0, nil - } - if len(currentItems) < len(previousItems) { - return false, nil - } - - for idx := 
range previousItems { - previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) - currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) - if !bytes.Equal(previousNormalized, currentNormalized) { - return false, nil - } - } - return true, nil -} - -func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { - if len(prefix) == 0 { - return true - } - if len(items) < len(prefix) { - return false - } - for idx := range prefix { - previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) - currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) - if !bytes.Equal(previousNormalized, currentNormalized) { - return false - } - } - return true -} - -func buildOpenAIWSReplayInputSequence( - previousFullInput []json.RawMessage, - previousFullInputExists bool, - currentPayload []byte, - hasPreviousResponseID bool, -) ([]json.RawMessage, bool, error) { - currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) - if currentErr != nil { - return nil, false, currentErr - } - if !hasPreviousResponseID { - return cloneOpenAIWSRawMessages(currentItems), currentExists, nil - } - if !previousFullInputExists { - return cloneOpenAIWSRawMessages(currentItems), currentExists, nil - } - if !currentExists || len(currentItems) == 0 { - return cloneOpenAIWSRawMessages(previousFullInput), true, nil - } - if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { - return cloneOpenAIWSRawMessages(currentItems), true, nil - } - merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) - merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...) - merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) 
- return merged, true, nil -} - -func setOpenAIWSPayloadInputSequence( - payload []byte, - fullInput []json.RawMessage, - fullInputExists bool, -) ([]byte, error) { - if !fullInputExists { - return payload, nil - } - // Preserve [] vs null semantics when input exists but is empty. - inputForMarshal := fullInput - if inputForMarshal == nil { - inputForMarshal = []json.RawMessage{} +func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool { + if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { + return true } - inputRaw, marshalErr := json.Marshal(inputForMarshal) - if marshalErr != nil { - return nil, marshalErr + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { + return true } - return sjson.SetRawBytes(payload, "input", inputRaw) + return false } -func shouldKeepIngressPreviousResponseID( - previousPayload []byte, - currentPayload []byte, - lastTurnResponseID string, - hasFunctionCallOutput bool, -) (bool, string, error) { - if hasFunctionCallOutput { - return true, "has_function_call_output", nil - } - currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) - if currentPreviousResponseID == "" { - return false, "missing_previous_response_id", nil +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true } - expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) - if expectedPreviousResponseID == "" { - return false, "missing_last_turn_response_id", nil + if len(reqBody) == 0 { + return false } - if currentPreviousResponseID != expectedPreviousResponseID { - return false, "previous_response_id_mismatch", nil + rawStore, ok := reqBody["store"] + if !ok { + return false } - if len(previousPayload) == 0 { - return false, "missing_previous_turn_payload", 
nil + storeEnabled, ok := rawStore.(bool) + if !ok { + return false } + return !storeEnabled +} - previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) - if previousComparableErr != nil { - return false, "non_input_compare_error", previousComparableErr - } - currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) - if currentComparableErr != nil { - return false, "non_input_compare_error", currentComparableErr +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true } - if !bytes.Equal(previousComparable, currentComparable) { - return false, "non_input_changed", nil + if len(reqBody) == 0 { + return false } - return true, "strict_incremental_ok", nil -} - -type openAIWSIngressPreviousTurnStrictState struct { - nonInputComparable []byte -} - -func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { - if len(payload) == 0 { - return nil, nil + storeValue := gjson.GetBytes(reqBody, "store") + if !storeValue.Exists() { + return false } - nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) - if nonInputErr != nil { - return nil, nonInputErr + if storeValue.Type != gjson.True && storeValue.Type != gjson.False { + return false } - return &openAIWSIngressPreviousTurnStrictState{ - nonInputComparable: nonInputComparable, - }, nil + return !storeValue.Bool() } -func shouldKeepIngressPreviousResponseIDWithStrictState( - previousState *openAIWSIngressPreviousTurnStrictState, - currentPayload []byte, - lastTurnResponseID string, - hasFunctionCallOutput bool, -) (bool, string, error) { - if hasFunctionCallOutput { - return true, "has_function_call_output", nil - } - 
currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) - if currentPreviousResponseID == "" { - return false, "missing_previous_response_id", nil - } - expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) - if expectedPreviousResponseID == "" { - return false, "missing_last_turn_response_id", nil - } - if currentPreviousResponseID != expectedPreviousResponseID { - return false, "previous_response_id_mismatch", nil +func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { + if s == nil || s.cfg == nil { + return openAIWSStoreDisabledConnModeStrict } - if previousState == nil { - return false, "missing_previous_turn_payload", nil + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) + switch mode { + case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: + return mode + case "": + // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 + if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + return openAIWSStoreDisabledConnModeStrict + } + return openAIWSStoreDisabledConnModeOff + default: + return openAIWSStoreDisabledConnModeStrict } +} - currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) - if currentComparableErr != nil { - return false, "non_input_compare_error", currentComparableErr - } - if !bytes.Equal(previousState.nonInputComparable, currentComparable) { - return false, "non_input_changed", nil +func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { + switch mode { + case openAIWSStoreDisabledConnModeOff: + return false + case openAIWSStoreDisabledConnModeAdaptive: + reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_") + switch reason { + case "policy_violation", "message_too_big", "auth_failed", "write_request", "write": + return true + default: + return false + } + default: + return 
true } - return true, "strict_incremental_ok", nil } func (s *OpenAIGatewayService) forwardOpenAIWSV2( @@ -1660,7 +553,20 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( startTime time.Time, attempt int, lastFailureReason string, -) (*OpenAIForwardResult, error) { +) (result *OpenAIForwardResult, err error) { + defer func() { + if recovered := recover(); recovered != nil { + logger.LegacyPrintf( + "service.openai_ws_forwarder", + "[OpenAIWS] recovered panic in forwardOpenAIWSV2: panic=%v stack=%s", + recovered, + string(debug.Stack()), + ) + err = fmt.Errorf("openai ws panic recovered: %v", recovered) + result = nil + } + }() + if s == nil || account == nil { return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) } @@ -1669,16 +575,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if err != nil { return nil, wrapOpenAIWSFallback("build_ws_url", err) } - wsHost := "-" - wsPath := "-" - if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil { - if h := strings.TrimSpace(parsed.Host); h != "" { - wsHost = normalizeOpenAIWSLogValue(h) - } - if p := strings.TrimSpace(parsed.Path); p != "" { - wsPath = normalizeOpenAIWSLogValue(p) - } - } + wsHost, wsPath := openAIWSHostPathForLogFromURL(wsURL) logOpenAIWSModeDebug( "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s", account.ID, @@ -1738,7 +635,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( stateStore := s.getOpenAIWSStateStore() groupID := getOpenAIGroupIDFromContext(c) - sessionHash := s.GenerateSessionHash(c, nil) + sessionHash := s.GenerateSessionHashWithFallback(c, nil, openAIWSIngressFallbackSessionSeedFromContext(c)) if sessionHash == "" { var legacySessionHash string sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey) @@ -1751,9 +648,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } preferredConnID := "" if stateStore != nil && previousResponseID != "" { - if connID, ok := 
stateStore.GetResponseConn(previousResponseID); ok { - preferredConnID = connID - } + preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, previousResponseID) } storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account) if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" { @@ -1800,18 +695,35 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) defer acquireCancel() - lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{ - Account: account, - WSURL: wsURL, - Headers: wsHeaders, - PreferredConnID: preferredConnID, - ForceNewConn: forceNewConn, + ingressCtxPool := s.getOpenAIWSIngressContextPool() + if ingressCtxPool == nil { + return nil, wrapOpenAIWSFallback("ctx_pool_unavailable", errors.New("openai ws ingress context pool is nil")) + } + sessionHashForCtx := strings.TrimSpace(sessionHash) + if sessionHashForCtx == "" { + sessionHashForCtx = fmt.Sprintf("httpws:%d:%d", account.ID, startTime.UnixNano()) + } + if forceNewConn { + sessionHashForCtx = fmt.Sprintf("%s:retry:%d", sessionHashForCtx, attempt) + } + ownerID := fmt.Sprintf("httpws_%d_%d", account.ID, attempt) + lease, err := ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: groupID, + SessionHash: sessionHashForCtx, + OwnerID: ownerID, + WSURL: wsURL, + Headers: cloneHeader(wsHeaders), ProxyURL: func() string { if account.ProxyID != nil && account.Proxy != nil { return account.Proxy.URL() } return "" }(), + Turn: 1, + HasPreviousResponseID: previousResponseID != "", + StrictAffinity: previousResponseID != "", + StoreDisabled: storeDisabled, }) if err != nil { dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err) @@ -1971,6 +883,11 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } 
clientDisconnected := false + var downstreamWriteErr error + var requestCtx context.Context + if c != nil && c.Request != nil { + requestCtx = c.Request.Context() + } flushBatchSize := s.openAIWSEventFlushBatchSize() flushInterval := s.openAIWSEventFlushInterval() pendingFlushEvents := 0 @@ -1988,23 +905,41 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( pendingFlushEvents = 0 lastFlushAt = time.Now() } + var sseFrameBuf []byte emitStreamMessage := func(message []byte, forceFlush bool) { - if clientDisconnected { + if clientDisconnected || downstreamWriteErr != nil { return } - frame := make([]byte, 0, len(message)+8) - frame = append(frame, "data: "...) - frame = append(frame, message...) - frame = append(frame, '\n', '\n') - _, wErr := c.Writer.Write(frame) + sseFrameBuf = sseFrameBuf[:0] + sseFrameBuf = append(sseFrameBuf, "data: "...) + sseFrameBuf = append(sseFrameBuf, message...) + sseFrameBuf = append(sseFrameBuf, '\n', '\n') + _, wErr := c.Writer.Write(sseFrameBuf) if wErr == nil { wroteDownstream = true pendingFlushEvents++ flushStreamWriter(forceFlush) return } - clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID) + if isOpenAIWSStreamWriteDisconnectError(wErr, requestCtx) { + clientDisconnected = true + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d conn_id=%s", + account.ID, + connID, + ) + return + } + downstreamWriteErr = wErr + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(wErr.Error()), "") + logOpenAIWSModeInfo( + "stream_write_fail account_id=%d conn_id=%s wrote_downstream=%v cause=%s", + account.ID, + connID, + wroteDownstream, + truncateOpenAIWSLogValue(wErr.Error(), openAIWSLogValueMaxLen), + ) } flushBufferedStreamEvents := func(reason string) { if len(bufferedStreamEvents) == 0 { @@ -2013,6 +948,9 @@ func (s *OpenAIGatewayService) 
forwardOpenAIWSV2( flushed := len(bufferedStreamEvents) for _, buffered := range bufferedStreamEvents { emitStreamMessage(buffered, false) + if downstreamWriteErr != nil { + break + } } bufferedStreamEvents = bufferedStreamEvents[:0] flushStreamWriter(true) @@ -2170,8 +1108,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw) setOpsUpstreamError(c, statusCode, errMsg, "") if reqStream && !clientDisconnected { - flushBufferedStreamEvents("error_event") + if shouldFlushOpenAIWSBufferedEventsOnError(reqStream, wroteDownstream, clientDisconnected) { + flushBufferedStreamEvents("error_event") + } else { + bufferedStreamEvents = bufferedStreamEvents[:0] + } emitStreamMessage(message, true) + if downstreamWriteErr != nil { + lease.MarkBroken() + return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr) + } } if !reqStream { c.JSON(statusCode, gin.H{ @@ -2207,10 +1153,14 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( } else { flushBufferedStreamEvents(eventType) emitStreamMessage(message, isTerminalEvent) + if downstreamWriteErr != nil { + lease.MarkBroken() + return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr) + } } } else { if responseField.Exists() && responseField.Type == gjson.JSON { - finalResponse = []byte(responseField.Raw) + finalResponse = cloneOpenAIWSJSONRawString(responseField.Raw) } } @@ -2253,7 +1203,14 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( if responseID != "" && stateStore != nil { ttl := s.openAIWSResponseStickyTTL() logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) - stateStore.BindResponseConn(responseID, lease.ConnID(), ttl) + if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok { + stateStore.BindResponseConn(responseID, connID, ttl) + } + if sessionHash != "" && shouldPersistOpenAIWSLastResponseID(lastEventType) { + 
stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL()) + } else if sessionHash != "" { + stateStore.DeleteSessionLastResponseID(groupID, sessionHash) + } } if stateStore != nil && storeDisabled && sessionHash != "" { stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL()) @@ -2282,14 +1239,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2( ) return &OpenAIForwardResult{ - RequestID: responseID, - Usage: *usage, - Model: originalModel, - ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), - Stream: reqStream, - OpenAIWSMode: true, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, + RequestID: responseID, + Usage: *usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModeCtxPool, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + TerminalEventType: strings.TrimSpace(lastEventType), }, nil } @@ -2301,9 +1260,27 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( clientConn *coderws.Conn, account *Account, token string, + firstClientMessageType coderws.MessageType, firstClientMessage []byte, hooks *OpenAIWSIngressHooks, -) error { +) (err error) { + defer func() { + if recovered := recover(); recovered != nil { + const panicCloseReason = "internal websocket proxy panic" + logger.LegacyPrintf( + "service.openai_ws_forwarder", + "[OpenAIWS] recovered panic in ProxyResponsesWebSocketFromClient: panic=%v stack=%s", + recovered, + string(debug.Stack()), + ) + err = NewOpenAIWSClientCloseError( + coderws.StatusInternalError, + panicCloseReason, + fmt.Errorf("panic recovered: %v", recovered), + ) + } + }() + if s == nil { return errors.New("service is nil") } @@ -2322,33 +1299,78 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsDecision := 
s.getOpenAIWSProtocolResolver().Resolve(account) modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled - ingressMode := OpenAIWSIngressModeShared - if modeRouterV2Enabled { - ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) - if ingressMode == OpenAIWSIngressModeOff { - return NewOpenAIWSClientCloseError( - coderws.StatusPolicyViolation, - "websocket mode is disabled for this account", - nil, - ) - } + if !modeRouterV2Enabled { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode requires mode_router_v2 with ctx_pool/passthrough", + nil, + ) + } + ingressMode := account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) + logOpenAIWSModeInfo( + "ingress_ws_validate account_id=%d ingress_mode=%s transport=%s", + account.ID, + normalizeOpenAIWSLogValue(string(ingressMode)), + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + ) + if ingressMode == OpenAIWSIngressModeOff { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode is disabled for this account", + nil, + ) + } + if ingressMode != OpenAIWSIngressModeCtxPool && ingressMode != OpenAIWSIngressModePassthrough { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode only supports ctx_pool/passthrough", + nil, + ) } if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) } - dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated - + firstModel, firstPreviousResponseID, firstPreviousResponseIDKind := ResolveOpenAIWSFirstMessageMeta(c, firstClientMessage) + if firstModel == "" { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in first response.create payload", + nil, + ) + } + if ingressMode == 
OpenAIWSIngressModeCtxPool && + firstPreviousResponseID != "" && + firstPreviousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "previous_response_id must be a response.id (resp_*), not a message id", + nil, + ) + } + if ingressMode == OpenAIWSIngressModePassthrough { + return s.proxyResponsesWebSocketV2Passthrough(ctx, c, clientConn, account, token, firstClientMessageType, firstClientMessage, hooks, wsDecision) + } + // Ingress ws_v2 请求天然是 Codex 会话语义,ctx_pool 是否启用仅由账号 mode 决定。 + ctxPoolMode := ingressMode == OpenAIWSIngressModeCtxPool + ctxPoolSessionScope := "" + if ctxPoolMode { + ctxPoolSessionScope = openAIWSIngressSessionScopeFromContext(c) + } wsURL, err := s.buildOpenAIResponsesWSURL(account) if err != nil { return fmt.Errorf("build ws url: %w", err) } - wsHost := "-" - wsPath := "-" - if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { - wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) - wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) - } + wsHost, wsPath := openAIWSHostPathForLogFromURL(wsURL) debugEnabled := isOpenAIWSModeDebugEnabled() + logOpenAIWSModeInfo( + "ingress_ws_session_init account_id=%d ws_host=%s ws_path=%s ctx_pool=%v session_scope=%s debug=%v", + account.ID, + wsHost, + wsPath, + ctxPoolMode, + truncateOpenAIWSLogValue(ctxPoolSessionScope, openAIWSIDValueMaxLen), + debugEnabled, + ) type openAIWSClientPayload struct { payloadRaw []byte @@ -2475,32 +1497,64 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) stateStore := s.getOpenAIWSStateStore() groupID := getOpenAIGroupIDFromContext(c) - sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) - if turnState == "" && stateStore != nil && sessionHash != "" { + fallbackSessionSeed := openAIWSIngressFallbackSessionSeedFromContext(c) + legacySessionHash := 
strings.TrimSpace(s.GenerateSessionHashWithFallback(c, firstPayload.rawForHash, fallbackSessionSeed)) + sessionHash := legacySessionHash + if ctxPoolMode { + sessionHash = openAIWSApplySessionScope(legacySessionHash, ctxPoolSessionScope) + } + resolveSessionTurnState := func() (string, bool) { + if stateStore == nil || sessionHash == "" { + return "", false + } if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + return savedTurnState, true + } + if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash { + return "", false + } + return stateStore.GetSessionTurnState(groupID, legacySessionHash) + } + resolveSessionLastResponseID := func() (string, bool) { + if stateStore == nil || sessionHash == "" { + return "", false + } + if savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, sessionHash); ok { + return strings.TrimSpace(savedResponseID), true + } + if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash { + return "", false + } + savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, legacySessionHash) + return strings.TrimSpace(savedResponseID), ok + } + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := resolveSessionTurnState(); ok { turnState = savedTurnState } } + sessionLastResponseID := "" + if stateStore != nil && sessionHash != "" { + if savedResponseID, ok := resolveSessionLastResponseID(); ok { + sessionLastResponseID = savedResponseID + } + } preferredConnID := "" if stateStore != nil && firstPayload.previousResponseID != "" { - if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { - preferredConnID = connID - } + preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, firstPayload.previousResponseID) } storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() - if stateStore != 
nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { - if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { - preferredConnID = connID - } - } isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) - baseAcquireReq := openAIWSAcquireRequest{ - Account: account, + baseAcquireReq := struct { + WSURL string + Headers http.Header + ProxyURL string + }{ WSURL: wsURL, Headers: wsHeaders, ProxyURL: func() string { @@ -2509,21 +1563,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } return "" }(), - ForceNewConn: false, } - pool := s.getOpenAIWSConnPool() - if pool == nil { - return errors.New("openai ws conn pool is nil") + + ingressCtxPool := s.getOpenAIWSIngressContextPool() + if ingressCtxPool == nil { + return errors.New("openai ws ingress context pool is nil") } logOpenAIWSModeInfo( - "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v", + "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_session_hash=%v has_previous_response_id=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(wsDecision.Transport)), wsHost, wsPath, normalizeOpenAIWSLogValue(ingressMode), + ctxPoolMode, storeDisabled, sessionHash != "", firstPayload.previousResponseID != "", @@ -2531,7 +1586,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if debugEnabled { logOpenAIWSModeDebug( - "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v", + 
"ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v ctx_pool_mode=%v", account.ID, account.Type, normalizeOpenAIWSLogValue(string(wsDecision.Transport)), @@ -2540,6 +1595,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( sessionHash != "", firstPayload.previousResponseID != "", storeDisabled, + ctxPoolMode, ) } if firstPayload.previousResponseID != "" { @@ -2566,15 +1622,37 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( acquireTimeout = 30 * time.Second } - acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) { - req := cloneOpenAIWSAcquireRequest(baseAcquireReq) - req.PreferredConnID = strings.TrimSpace(preferred) - req.ForcePreferredConn = forcePreferredConn - // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。 - req.ForceNewConn = dedicatedMode + ownerID := fmt.Sprintf("cliws_%p", clientConn) + acquireTurnLease := func( + turn int, + preferred string, + forcePreferredConn bool, + hasPreviousResponseID bool, + ) (openAIWSIngressUpstreamLease, error) { acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) - lease, acquireErr := pool.Acquire(acquireCtx, req) - acquireCancel() + defer acquireCancel() + + var ( + lease openAIWSIngressUpstreamLease + acquireErr error + ) + sessionHashForCtx := strings.TrimSpace(sessionHash) + if sessionHashForCtx == "" { + sessionHashForCtx = fmt.Sprintf("conn:%s", ownerID) + } + lease, acquireErr = ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: groupID, + SessionHash: sessionHashForCtx, + OwnerID: ownerID, + WSURL: baseAcquireReq.WSURL, + Headers: cloneHeader(baseAcquireReq.Headers), + ProxyURL: baseAcquireReq.ProxyURL, + Turn: turn, + HasPreviousResponseID: hasPreviousResponseID, + StrictAffinity: forcePreferredConn, + StoreDisabled: storeDisabled, + }) if acquireErr != nil 
{ dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) logOpenAIWSModeInfo( @@ -2597,14 +1675,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( wsPath, account.ProxyID != nil && account.Proxy != nil, ) - if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { - return nil, NewOpenAIWSClientCloseError( - coderws.StatusPolicyViolation, - "upstream continuation connection is unavailable; please restart the conversation", - acquireErr, - ) - } - if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) { + if errors.Is(acquireErr, context.DeadlineExceeded) || + errors.Is(acquireErr, errOpenAIWSConnQueueFull) || + errors.Is(acquireErr, errOpenAIWSIngressContextBusy) { return nil, NewOpenAIWSClientCloseError( coderws.StatusTryAgainLater, "upstream websocket is busy, please retry later", @@ -2627,7 +1700,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( baseAcquireReq.Headers = updatedHeaders } logOpenAIWSModeInfo( - "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s", + "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s ctx_pool_mode=%v schedule_layer=%s stickiness_level=%s migration_used=%v", account.ID, turn, truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), @@ -2635,6 +1708,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( lease.ConnPickDuration().Milliseconds(), lease.QueueWaitDuration().Milliseconds(), truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + ctxPoolMode, + normalizeOpenAIWSLogValue(lease.ScheduleLayer()), + normalizeOpenAIWSLogValue(lease.StickinessLevel()), + lease.MigrationUsed(), ) return lease, nil } @@ -2646,8 +1723,21 @@ func (s *OpenAIGatewayService) 
ProxyResponsesWebSocketFromClient( } readClientMessage := func() ([]byte, error) { - msgType, payload, readErr := clientConn.Read(ctx) + readCtx := ctx + if idleTimeout := s.openAIWSClientReadIdleTimeout(); idleTimeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(ctx, idleTimeout) + defer cancel() + } + msgType, payload, readErr := clientConn.Read(readCtx) if readErr != nil { + if readCtx != nil && readCtx.Err() == context.DeadlineExceeded && (ctx == nil || ctx.Err() == nil) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "client websocket idle timeout", + readErr, + ) + } return nil, readErr } if msgType != coderws.MessageText && msgType != coderws.MessageBinary { @@ -2660,12 +1750,73 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return payload, nil } - sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + // 持久客户端读取 goroutine:将客户端消息推送到 channel, + // 使 sendAndRelay 可以通过 select 同时监听上游事件和客户端新请求。 + var nextClientPreemptedPayload []byte + var pendingClientReadErr error + clientMsgCh := make(chan []byte, 1) + clientReadErrCh := make(chan error, 1) + go func() { + defer close(clientMsgCh) + for { + msg, err := readClientMessage() + if err != nil { + select { + case clientReadErrCh <- err: + case <-ctx.Done(): + } + return + } + select { + case clientMsgCh <- msg: + case <-ctx.Done(): + return + } + } + }() + + sendAndRelay := func(turn int, lease openAIWSIngressUpstreamLease, payload []byte, payloadBytes int, originalModel string, expectedPreviousResponseID string) (*OpenAIForwardResult, error) { if lease == nil { return nil, errors.New("upstream websocket lease is nil") } turnStart := time.Now() wroteDownstream := false + reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + 
turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) + turnExpectedPreviousResponseID := strings.TrimSpace(expectedPreviousResponseID) + turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") + turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) + turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload) + turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0 + turnHasToolOutputContext := openAIWSHasToolCallContextInPayload(payload) || + openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload, turnFunctionCallOutputCallIDs) + + // 预防性检测:必须在发往上游之前执行,避免“先写上游再本地失败”造成 turn 状态错位。 + // 在 store_disabled 模式下,若 function_call_output 既没有 previous_response_id, + // 也没有可关联的 tool_call/item_reference 上下文,则必然会触发 tool_output_not_found。 + // 提前返回可恢复错误,由外层 recoverIngressPrevResponseNotFound 执行 context replay 重试。 + if shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + turnStoreDisabled, + turnHasFunctionCallOutput, + turnPreviousResponseID, + turnHasToolOutputContext, + ) { + logOpenAIWSModeInfo( + "ingress_ws_proactive_tool_output_reject account_id=%d turn=%d conn_id=%s has_tool_output_context=%v previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + turnHasToolOutputContext, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + ) + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + openAIWSIngressStageToolOutputNotFound, + errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"), + false, + nil, + ) + } if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { return nil, wrapOpenAIWSIngressTurnError( "write_upstream", @@ -2686,12 +1837,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( 
responseID := "" usage := OpenAIUsage{} var firstTokenMs *int - reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) - turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") - turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) - turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") - turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) - turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + + turnPendingFunctionCallIDSet := make(map[string]struct{}, 4) eventCount := 0 tokenEventCount := 0 terminalEventCount := 0 @@ -2699,8 +1846,31 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( lastEventType := "" needModelReplace := false clientDisconnected := false + clientDisconnectDrainDeadline := time.Time{} + terminateOnErrorEvent := false + terminateOnErrorMessage := "" mappedModel := "" var mappedModelBytes []byte + buildPartialResult := func(terminalEventType string) *OpenAIForwardResult { + if usage.InputTokens <= 0 && + usage.OutputTokens <= 0 && + usage.CacheCreationInputTokens <= 0 && + usage.CacheReadInputTokens <= 0 { + return nil + } + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModeCtxPool, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + TerminalEventType: strings.TrimSpace(terminalEventType), + } + } if originalModel != "" { mappedModel = account.GetMappedModel(originalModel) if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { @@ -2711,18 +1881,190 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( mappedModelBytes = []byte(mappedModel) } } + // 启动上游事件读取泵:解耦上游读取和客户端写入,允许二者并发执行。 + // 读取 
goroutine 将上游事件推送到缓冲 channel,主 goroutine 从 channel 消费并处理/转发。 + // 缓冲 channel 允许上游在客户端写入阻塞时继续读取后续事件,降低端到端延迟。 + pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize) + pumpCtx, pumpCancel := context.WithCancel(ctx) + defer pumpCancel() + pumpStartedAt := time.Now() + go func() { + defer func() { + close(pumpEventCh) + if pumpCtx.Err() == nil { + return + } + pumpAlive := time.Since(pumpStartedAt) + if pumpAlive >= openAIWSUpstreamPumpInfoMinAlive { + logOpenAIWSModeInfo( + "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + pumpAlive.Milliseconds(), + ) + return + } + logOpenAIWSModeDebug( + "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + pumpAlive.Milliseconds(), + ) + }() + for { + msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, s.openAIWSReadTimeout()) + select { + case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}: + case <-pumpCtx.Done(): + return + } + if readErr != nil { + return + } + // 检测终端/错误事件以终止读取泵。 + evtType, _ := parseOpenAIWSEventType(msg) + if isOpenAIWSTerminalEvent(evtType) || evtType == "error" { + return + } + } + }() + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() for { - upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) - if readErr != nil { + var evt openAIWSUpstreamPumpEvent + var evtOk bool + select { + case evt, evtOk = <-pumpEventCh: + if !evtOk { + goto pumpClosed + } + case preemptMsg, ok := <-clientMsgCh: + if !ok { + // 客户端读取 goroutine 退出,置空 channel 防止再次 select + clientMsgCh = nil + continue + } + // 客户端抢占:暂存新请求,取消上游转发,返回让外层切换到下一 turn + nextClientPreemptedPayload = preemptMsg + 
logOpenAIWSModeInfo( + "ingress_ws_client_preempt account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + ) + pumpCancel() + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_preempted", errOpenAIWSClientPreempted, + wroteDownstream, buildPartialResult("client_preempted")) + case readErr := <-clientReadErrCh: + // 客户端断连:立即取消上游 pump 并释放连接。 + // Codex CLI 在 ESC 取消后会关闭旧 WebSocket 并新建连接发送下一条消息, + // 继续排水只会延迟新连接获取上游 lease,因此这里直接终止。 + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + pumpCancel() + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_immediate", + fmt.Errorf("client disconnected (read): %w", readErr), + wroteDownstream, + buildPartialResult("client_disconnected"), + ) + } + pendingClientReadErr = readErr + cause := "-" + if readErr != nil { + cause = truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen) + } + logOpenAIWSModeInfo( + "ingress_ws_client_read_error_deferred account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + cause, + ) + clientMsgCh = nil + clientReadErrCh = nil + continue + } + // 排水超时检查:客户端已断连且排水截止时间已过,终止读取。 + if clientDisconnected && !clientDisconnectDrainDeadline.IsZero() && time.Now().After(clientDisconnectDrainDeadline) { + pumpCancel() + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), 
openAIWSIDValueMaxLen), + openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(), + ) + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_drain_timeout", + openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout), + wroteDownstream, + buildPartialResult("client_disconnected_drain_timeout"), + ) + } + upstreamMessage := evt.message + if evt.err != nil { + readErr := evt.err + if clientDisconnected { + // 排水期间读取失败(上游关闭或读取泵被取消),按排水超时处理。 + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(), + ) + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_drain_timeout", + openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout), + wroteDownstream, + buildPartialResult("client_disconnected_drain_timeout"), + ) + } + readErrClass := classifyOpenAIWSIngressReadErrorClass(readErr) + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_upstream_read_error account_id=%d turn=%d conn_id=%s class=%s close_status=%s close_reason=%s events_received=%d wrote_downstream=%v response_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(readErrClass), + closeStatus, + closeReason, + eventCount, + wroteDownstream, + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + ) lease.MarkBroken() - return nil, wrapOpenAIWSIngressTurnError( + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( "read_upstream", fmt.Errorf("read upstream websocket event: %w", readErr), wroteDownstream, + 
buildPartialResult("read_upstream"), ) } - eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) + eventType, eventResponseID := parseOpenAIWSEventType(upstreamMessage) if responseID == "" && eventResponseID != "" { responseID = eventResponseID } @@ -2737,15 +2079,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + recoveryEnabled := s.openAIWSIngressPreviousResponseRecoveryEnabled() recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && + recoveryEnabled && + (turnPreviousResponseID != "" || (turnHasFunctionCallOutput && turnExpectedPreviousResponseID != "")) && + !wroteDownstream + // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call + // (用户在 Codex CLI 按 ESC 取消后重新发送消息),需要移除 previous_response_id 后重放。 + recoverableToolOutputNotFound := fallbackReason == openAIWSIngressStageToolOutputNotFound && + recoveryEnabled && turnPreviousResponseID != "" && - !turnHasFunctionCallOutput && - s.openAIWSIngressPreviousResponseRecoveryEnabled() && !wroteDownstream - if recoverablePrevNotFound { + recoverableContextMismatch := recoverablePrevNotFound || recoverableToolOutputNotFound + if recoverableContextMismatch { // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 logOpenAIWSModeInfo( - "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s 
response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), @@ -2757,12 +2106,17 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ingressMode), + ctxPoolMode, turnStoreDisabled, turnPromptCacheKey != "", + turnHasFunctionCallOutput, + recoveryEnabled, + wroteDownstream, ) } else { logOpenAIWSModeInfo( - "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), @@ -2774,24 +2128,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ingressMode), + ctxPoolMode, turnStoreDisabled, turnPromptCacheKey != "", + turnHasFunctionCallOutput, + recoveryEnabled, + wroteDownstream, ) } - // previous_response_not_found 在 ingress 模式支持单次恢复重试: + // previous_response_not_found / tool_output_not_found 在 ingress 模式支持单次恢复重试: // 不把该 error 直接下发客户端,而是由上层去掉 
previous_response_id 后重放当前 turn。 - if recoverablePrevNotFound { + if recoverableContextMismatch { lease.MarkBroken() errMsg := strings.TrimSpace(errMsgRaw) if errMsg == "" { - errMsg = "previous response not found" + if fallbackReason == openAIWSIngressStageToolOutputNotFound { + errMsg = "no tool output found for function call" + } else { + errMsg = "previous response not found" + } } - return nil, wrapOpenAIWSIngressTurnError( - openAIWSIngressStagePreviousResponseNotFound, + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + fallbackReason, errors.New(errMsg), false, + buildPartialResult(fallbackReason), ) } + terminateOnErrorEvent = true + terminateOnErrorMessage = strings.TrimSpace(errMsgRaw) + if terminateOnErrorMessage == "" { + terminateOnErrorMessage = "upstream websocket error" + } } isTokenEvent := isOpenAIWSTokenEvent(eventType) if isTokenEvent { @@ -2808,6 +2177,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if openAIWSEventShouldParseUsage(eventType) { parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + for _, callID := range openAIWSExtractPendingFunctionCallIDsFromEvent(upstreamMessage) { + turnPendingFunctionCallIDSet[callID] = struct{}{} + } + } if !clientDisconnected { if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { @@ -2820,27 +2194,47 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } if err := writeClientMessage(upstreamMessage); err != nil { if isOpenAIWSClientDisconnectError(err) { - clientDisconnected = true + // 客户端断连:立即取消上游 pump 并释放连接。 + // 不再排水等待,以便新连接能尽快获取上游 lease。 closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) logOpenAIWSModeInfo( - "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + 
"ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), closeStatus, truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), ) + pumpCancel() + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_immediate", + fmt.Errorf("client disconnected (write): %w", err), + wroteDownstream, + buildPartialResult("client_disconnected"), + ) } else { - return nil, wrapOpenAIWSIngressTurnError( + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( "write_client", fmt.Errorf("write client websocket event: %w", err), wroteDownstream, + buildPartialResult("write_client"), ) } } else { wroteDownstream = true } } + if terminateOnErrorEvent { + // WS ingress 中的 error 事件应立即终止当前 turn,避免继续阻塞在下一次上游 read。 + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnErrorWithPartial( + "upstream_error_event", + errors.New(terminateOnErrorMessage), + wroteDownstream, + buildPartialResult("upstream_error_event"), + ) + } if isTerminalEvent { // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 if clientDisconnected { @@ -2852,7 +2246,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } if debugEnabled { logOpenAIWSModeDebug( - "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v has_function_call_output=%v pending_function_call_ids=%d", account.ID, turn, truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), @@ -2865,20 +2259,47 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( 
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), firstTokenMsValue, clientDisconnected, + turnHasFunctionCallOutput, + len(turnPendingFunctionCallIDSet), ) } + pendingFunctionCallIDs := make([]string, 0, len(turnPendingFunctionCallIDSet)) + for callID := range turnPendingFunctionCallIDSet { + pendingFunctionCallIDs = append(pendingFunctionCallIDs, callID) + } + sort.Strings(pendingFunctionCallIDs) return &OpenAIForwardResult{ - RequestID: responseID, - Usage: usage, - Model: originalModel, - ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), - Stream: reqStream, - OpenAIWSMode: true, - Duration: time.Since(turnStart), - FirstTokenMs: firstTokenMs, + RequestID: responseID, + Usage: usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModeCtxPool, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + TerminalEventType: strings.TrimSpace(eventType), + PendingFunctionCallIDs: pendingFunctionCallIDs, }, nil } } + pumpClosed: + // 读取泵 channel 关闭但未收到终端事件: + // - 客户端已断连:按排水超时收尾,避免误判为 read_upstream。 + // - 其他场景:按上游读取异常处理。 + lease.MarkBroken() + if clientDisconnected { + return nil, openAIWSIngressPumpClosedTurnError( + true, + wroteDownstream, + buildPartialResult("client_disconnected_drain_timeout"), + ) + } + return nil, openAIWSIngressPumpClosedTurnError( + false, + wroteDownstream, + buildPartialResult("read_upstream"), + ) } currentPayload := firstPayload.payloadRaw @@ -2890,50 +2311,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" } - var sessionLease *openAIWSConnLease + var sessionLease openAIWSIngressUpstreamLease sessionConnID := "" - pinnedSessionConnID := "" - unpinSessionConn := func(connID string) { - connID = strings.TrimSpace(connID) - if connID == "" || 
pinnedSessionConnID != connID { - return - } - pool.UnpinConn(account.ID, connID) - pinnedSessionConnID = "" - } - pinSessionConn := func(connID string) { - if !storeDisabled { - return - } - connID = strings.TrimSpace(connID) - if connID == "" || pinnedSessionConnID == connID { + unpinSessionConn := func(_ string) {} + pinSessionConn := func(_ string) {} + releaseSessionLease := func() { + if sessionLease == nil { return } - if pinnedSessionConnID != "" { - pool.UnpinConn(account.ID, pinnedSessionConnID) - pinnedSessionConnID = "" - } - if pool.PinConn(account.ID, connID) { - pinnedSessionConnID = connID + unpinSessionConn(sessionConnID) + sessionLease.Release() + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_upstream_released account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) } } - releaseSessionLease := func() { + yieldSessionLease := func() { if sessionLease == nil { return } - if dedicatedMode { - // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 - sessionLease.MarkBroken() - } unpinSessionConn(sessionConnID) - sessionLease.Release() + sessionLease.Yield() if debugEnabled { logOpenAIWSModeDebug( - "ingress_ws_upstream_released account_id=%d conn_id=%s", + "ingress_ws_upstream_yielded account_id=%d conn_id=%s", account.ID, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), ) } + sessionLease = nil + sessionConnID = "" } defer releaseSessionLease() @@ -2941,7 +2351,17 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( turnRetry := 0 turnPrevRecoveryTried := false lastTurnFinishedAt := time.Time{} - lastTurnResponseID := "" + lastTurnResponseID := sessionLastResponseID + clearSessionLastResponseID := func() { + lastTurnResponseID = "" + if stateStore == nil || sessionHash == "" { + return + } + stateStore.DeleteSessionLastResponseID(groupID, sessionHash) + if ctxPoolMode && legacySessionHash != "" && legacySessionHash != sessionHash { + 
stateStore.DeleteSessionLastResponseID(groupID, legacySessionHash) + } + } lastTurnPayload := []byte(nil) var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState lastTurnReplayInput := []json.RawMessage(nil) @@ -2953,6 +2373,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if sessionLease == nil { return } + resetStart := time.Now() + resetConnID := sessionConnID if markBroken { sessionLease.MarkBroken() } @@ -2960,12 +2382,194 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( sessionLease = nil sessionConnID = "" preferredConnID = "" + if elapsed := time.Since(resetStart); elapsed > 100*time.Millisecond { + logOpenAIWSModeInfo( + "ingress_ws_reset_session_lease_slow account_id=%d conn_id=%s mark_broken=%v elapsed_ms=%d", + account.ID, + truncateOpenAIWSLogValue(resetConnID, openAIWSIDValueMaxLen), + markBroken, + elapsed.Milliseconds(), + ) + } } recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool { - if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) { + isPrevNotFound := isOpenAIWSIngressPreviousResponseNotFound(relayErr) + isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(relayErr) + if !isPrevNotFound && !isToolOutputMissing { return false } if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + skipReason := "already_tried" + if !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + skipReason = "recovery_disabled" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skipped account_id=%d turn=%d conn_id=%s reason=%s is_prev_not_found=%v is_tool_output_missing=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(skipReason), + isPrevNotFound, + isToolOutputMissing, + ) + return false + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + // tool_output_not_found: previous_response_id 指向的 
response 包含未完成的 function_call + // (用户在 Codex CLI 按 ESC 取消了 function_call 后重新发送消息)。 + // 对齐/保持 previous_response_id 无法解决问题,直接跳到 drop 分支移除后重放。 + if isToolOutputMissing { + logOpenAIWSModeInfo( + "ingress_ws_tool_output_not_found_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + turnPrevRecoveryTried = true + updatedPayload := currentPayload + if currentPreviousResponseID != "" { + dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=drop_error", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false + } + if removed { + updatedPayload = dropped + } + } + // previous_response_id 已不存在或已移除,继续执行 setOpenAIWSPayloadInputSequence + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + clearSessionLastResponseID() + resetSessionLease(true) + skipBeforeTurn = true + return true + } + hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + if hasFunctionCallOutput { + turnPrevRecoveryTried = true + expectedPrev := strings.TrimSpace(lastTurnResponseID) + if currentPreviousResponseID == "" 
&& expectedPrev != "" { + updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if setPrevErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=set_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + } + } + alignedPayload, aligned, alignErr := alignStoreDisabledPreviousResponseID(currentPayload, expectedPrev) + if alignErr != nil { + logOpenAIWSModeInfo( + 
"ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=align_previous_response_id_error cause=%s previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(alignErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else if aligned { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + alignedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=align_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + } + // function_call_output 与 previous_response_id 语义绑定: + // function_call_output 引用了前一个 response 中的 call_id, + // 移除 previous_response_id 但保留 function_call_output 会导致上游报错 + // "No tool call found for function call output with call_id ..."。 + // 
此场景在网关层不可恢复,返回 false 走 abort 路径通知客户端, + // 客户端收到错误后会重置并发送完整请求(不带 previous_response_id)。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=abort_function_call_unrecoverable previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) return false } if isStrictAffinityTurn(currentPayload) { @@ -3018,12 +2622,26 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) currentPayload = updatedWithInput currentPayloadBytes = len(updatedWithInput) + clearSessionLastResponseID() resetSessionLease(true) skipBeforeTurn = true return true } retryIngressTurn := func(relayErr error, turn int, connID string) bool { if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { + retrySkipReason := "not_retryable" + if turnRetry >= 1 { + retrySkipReason = "retry_exhausted" + } + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skipped account_id=%d turn=%d conn_id=%s reason=%s retry_count=%d err_stage=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(retrySkipReason), + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + ) return false } if isStrictAffinityTurn(currentPayload) { @@ -3048,6 +2666,160 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( skipBeforeTurn = true return true } + advanceToNextClientTurn := func(turn int, connID string) (bool, error) { + logOpenAIWSModeInfo( + "ingress_ws_advance_wait_client account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + var nextClientMessage []byte + if nextClientPreemptedPayload != nil { + nextClientMessage = nextClientPreemptedPayload + nextClientPreemptedPayload = nil + logOpenAIWSModeInfo( + "ingress_ws_advance_use_preempted_payload 
account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + len(nextClientMessage), + ) + } else { + if pendingReadErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingClientReadErr); pendingReadErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pendingReadErr.Error(), openAIWSLogValueMaxLen), + ) + return false, pendingReadErr + } + if openAIWSAdvanceClientReadUnavailable(clientMsgCh, clientReadErrCh) { + logOpenAIWSModeInfo( + "ingress_ws_advance_read_unavailable account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false, fmt.Errorf("read client websocket request: %w", errOpenAIWSAdvanceClientReadUnavailable) + } + select { + case msg, ok := <-clientMsgCh: + if !ok { + // 客户端读取 goroutine 已退出 + return true, nil + } + nextClientMessage = msg + case readErr := <-clientReadErrCh: + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return true, nil + } + logOpenAIWSModeInfo( + "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + ) + return false, fmt.Errorf("read client websocket request: %w", readErr) + } + } + + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return false, parseErr + } + nextStoreDisabled 
:= s.isOpenAIWSStoreDisabledInRequestRaw(nextPayload.payloadRaw, account) + nextLegacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, nextPayload.rawForHash, fallbackSessionSeed)) + nextSessionHash := nextLegacySessionHash + if ctxPoolMode { + nextSessionHash = openAIWSApplySessionScope(nextLegacySessionHash, ctxPoolSessionScope) + } + if sessionHash == "" && nextSessionHash != "" { + sessionHash = nextSessionHash + legacySessionHash = nextLegacySessionHash + if stateStore != nil { + if turnState == "" { + if savedTurnState, ok := resolveSessionTurnState(); ok { + turnState = savedTurnState + } + } + if lastTurnResponseID == "" { + if savedResponseID, ok := resolveSessionLastResponseID(); ok { + lastTurnResponseID = savedResponseID + } + } + } + logOpenAIWSModeInfo( + "ingress_ws_session_hash_backfill account_id=%d turn=%d next_turn=%d conn_id=%s session_hash=%s has_turn_state=%v has_last_response_id=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + turnState != "", + strings.TrimSpace(lastTurnResponseID) != "", + nextStoreDisabled, + ) + } + if nextPayload.promptCacheKey != "" { + // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; + // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 + updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) + baseAcquireReq.Headers = updatedHeaders + } + if nextPayload.previousResponseID != "" { + expectedPrev := strings.TrimSpace(lastTurnResponseID) + chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev + nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s 
last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + nextPayload.promptCacheKey != "", + storeDisabled, + ) + } + if stateStore != nil && nextPayload.previousResponseID != "" { + if stickyConnID := openAIWSPreferredConnIDFromResponse(stateStore, nextPayload.previousResponseID); stickyConnID != "" { + if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { + logOpenAIWSModeInfo( + "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + ) + } else { + preferredConnID = stickyConnID + } + } + } + currentPayload = nextPayload.payloadRaw + currentOriginalModel = nextPayload.originalModel + currentPayloadBytes = nextPayload.payloadBytes + storeDisabled = nextStoreDisabled + if !storeDisabled { + unpinSessionConn(sessionConnID) + } + return false, nil + } for { if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { if err := hooks.BeforeTurn(turn); err != nil { @@ -3057,39 +2829,78 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( skipBeforeTurn = false currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") expectedPrev := strings.TrimSpace(lastTurnResponseID) - hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() - // store=false + function_call_output 场景必须有续链锚点。 - // 若客户端未传 
previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 - if shouldInferIngressFunctionCallOutputPreviousResponseID( - storeDisabled, + if expectedPrev == "" && stateStore != nil && sessionHash != "" { + if savedResponseID, ok := resolveSessionLastResponseID(); ok { + expectedPrev = savedResponseID + if expectedPrev != "" { + lastTurnResponseID = expectedPrev + } + } + } + logOpenAIWSModeInfo( + "ingress_ws_turn_begin account_id=%d turn=%d conn_id=%s previous_response_id=%s expected_previous_response_id=%s store_disabled=%v has_session_lease=%v", + account.ID, turn, - hasFunctionCallOutput, - currentPreviousResponseID, - expectedPrev, - ) { - updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) - if setPrevErr != nil { + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + storeDisabled, + sessionLease != nil, + ) + pendingExpectedCallIDs := []string(nil) + if storeDisabled && expectedPrev != "" && stateStore != nil { + if pendingCallIDs, ok := stateStore.GetResponsePendingToolCalls(groupID, expectedPrev); ok { + pendingExpectedCallIDs = openAIWSNormalizeCallIDs(pendingCallIDs) + } + } + normalized := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: account.ID, + turn: turn, + connID: sessionConnID, + currentPayload: currentPayload, + currentPayloadBytes: currentPayloadBytes, + currentPreviousResponseID: currentPreviousResponseID, + expectedPreviousResponse: expectedPrev, + pendingExpectedCallIDs: pendingExpectedCallIDs, + }) + currentPayload = normalized.currentPayload + currentPayloadBytes = normalized.currentPayloadBytes + currentPreviousResponseID = normalized.currentPreviousResponseID + expectedPrev = normalized.expectedPreviousResponseID + pendingExpectedCallIDs = normalized.pendingExpectedCallIDs + 
currentFunctionCallOutputCallIDs := normalized.functionCallOutputCallIDs + hasFunctionCallOutput := normalized.hasFunctionCallOutputCallID + + // 当客户端发送 function_call_output 但未携带 previous_response_id 时, + // 主动注入 Gateway 跟踪的 lastTurnResponseID。 + // 在 store_disabled 模式下,上游需要 previous_response_id 来关联 function_call_output 与 response, + // 否则会返回 "No tool call found for function call output" 错误。 + if shouldInferIngressFunctionCallOutputPreviousResponseID(storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev) { + injectedPayload, injectErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if injectErr != nil { logOpenAIWSModeInfo( - "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + "ingress_ws_inject_prev_response_id_fail account_id=%d turn=%d conn_id=%s cause=%s expected_previous_response_id=%s", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(injectErr.Error(), openAIWSLogValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), ) } else { - currentPayload = updatedPayload - currentPayloadBytes = len(updatedPayload) - currentPreviousResponseID = expectedPrev logOpenAIWSModeInfo( - "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s", + "ingress_ws_inject_prev_response_id account_id=%d turn=%d conn_id=%s injected_previous_response_id=%s has_function_call_output=%v", account.ID, turn, truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, ) + currentPayload = injectedPayload + currentPayloadBytes = len(injectedPayload) + currentPreviousResponseID = expectedPrev } } + nextReplayInput, 
nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( lastTurnReplayInput, lastTurnReplayInputExists, @@ -3120,6 +2931,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentPayload, lastTurnResponseID, hasFunctionCallOutput, + pendingExpectedCallIDs, + currentFunctionCallOutputCallIDs, ) } else { shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( @@ -3127,6 +2940,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( currentPayload, lastTurnResponseID, hasFunctionCallOutput, + pendingExpectedCallIDs, + currentFunctionCallOutputCallIDs, ) } if strictErr != nil { @@ -3196,9 +3011,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( } } forcePreferredConn := isStrictAffinityTurn(currentPayload) + hasPreviousResponseIDForAcquire := currentPreviousResponseID != "" if sessionLease == nil { - acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + acquiredLease, acquireErr := acquireTurnLease( + turn, + preferredConnID, + forcePreferredConn, + hasPreviousResponseIDForAcquire, + ) if acquireErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_acquire_lease_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + ) return fmt.Errorf("acquire upstream websocket: %w", acquireErr) } sessionLease = acquiredLease @@ -3224,64 +3052,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), ) - if forcePreferredConn { - if !turnPrevRecoveryTried && currentPreviousResponseID != "" { - updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) - if dropErr != nil || !removed { - reason := "not_removed" - if 
dropErr != nil { - reason = "drop_error" - } - logOpenAIWSModeInfo( - "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s", - account.ID, - turn, - truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), - normalizeOpenAIWSLogValue(reason), - truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), - ) - } else { - updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( - updatedPayload, - currentTurnReplayInput, - currentTurnReplayInputExists, - ) - if setInputErr != nil { - logOpenAIWSModeInfo( - "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s", - account.ID, - turn, - truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), - ) - } else { - logOpenAIWSModeInfo( - "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", - account.ID, - turn, - truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), - ) - turnPrevRecoveryTried = true - currentPayload = updatedWithInput - currentPayloadBytes = len(updatedWithInput) - resetSessionLease(true) - skipBeforeTurn = true - continue - } - } - } - resetSessionLease(true) - return NewOpenAIWSClientCloseError( - coderws.StatusPolicyViolation, - "upstream continuation connection is unavailable; please restart the conversation", - pingErr, - ) - } + // preflight ping 失败:直接重连,不修改 payload resetSessionLease(true) - - acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + acquiredLease, acquireErr := acquireTurnLease( + turn, + preferredConnID, + forcePreferredConn, + currentPreviousResponseID != "", + ) 
if acquireErr != nil { return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) } @@ -3315,7 +3093,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( ) } - result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, expectedPrev) if relayErr != nil { if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { continue @@ -3327,11 +3105,108 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { finalErr = unwrapped } + abortReason, abortExpected := classifyOpenAIWSIngressTurnAbortReason(relayErr) + s.recordOpenAIWSTurnAbort(abortReason, abortExpected) + logOpenAIWSIngressTurnAbort(account.ID, turn, connID, abortReason, abortExpected, finalErr) if hooks != nil && hooks.AfterTurn != nil { hooks.AfterTurn(turn, nil, finalErr) } - sessionLease.MarkBroken() - return finalErr + switch openAIWSIngressTurnAbortDispositionForReason(abortReason) { + case openAIWSIngressTurnAbortDispositionContinueTurn: + switch abortReason { + case openAIWSIngressTurnAbortReasonClientPreempted: + // 客户端抢占:不通知 error(客户端已发出新请求,不需要旧 turn 的错误事件), + // 保留上一轮 response_id(被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链)。 + preemptRecoverStart := time.Now() + resetSessionLease(true) + logOpenAIWSModeInfo( + "ingress_ws_client_preempt_recover account_id=%d turn=%d conn_id=%s reset_elapsed_ms=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + time.Since(preemptRecoverStart).Milliseconds(), + ) + case openAIWSIngressTurnAbortReasonUpstreamRestart: + // 上游重启(1012/1013):连接级关闭,客户端未收到任何终端事件, + // 始终补发 error 事件(无论 wroteDownstream 状态),避免客户端永远等待响应。 + abortMessage := "upstream service restarting, please retry" + if finalErr != nil { + abortMessage = finalErr.Error() + } + errorEvent := 
[]byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + if writeErr := writeClientMessage(errorEvent); writeErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_restart_notify_failed account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_upstream_restart_notified account_id=%d turn=%d conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + } + resetSessionLease(true) + clearSessionLastResponseID() + default: + // turn 级终止:当前 turn 失败,但客户端 WS 会话继续。 + // 这样可与 Codex 客户端语义对齐:后续 turn 允许在新上游连接上继续进行。 + // + // 关键修复:若未向客户端写入过任何数据 (wroteDownstream=false), + // 必须补发 error 事件通知客户端本轮失败,否则客户端会一直等待响应, + // 而服务端在 advanceToNextClientTurn 中等待客户端下一条消息 → 双向死锁。 + if !openAIWSIngressTurnWroteDownstream(relayErr) { + abortMessage := "turn failed: " + string(abortReason) + if finalErr != nil { + abortMessage = finalErr.Error() + } + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + if writeErr := writeClientMessage(errorEvent); writeErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_turn_abort_notify_failed account_id=%d turn=%d conn_id=%s reason=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(string(abortReason)), + truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_turn_abort_notified account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(string(abortReason)), + ) + } + } + resetSessionLease(true) + 
clearSessionLastResponseID() + } + turnRetry = 0 + turnPrevRecoveryTried = false + exit, advanceErr := advanceToNextClientTurn(turn, connID) + if advanceErr != nil { + return advanceErr + } + if exit { + return nil + } + s.recordOpenAIWSTurnAbortRecovered() + turn++ + continue + case openAIWSIngressTurnAbortDispositionCloseGracefully: + resetSessionLease(true) + clearSessionLastResponseID() + return nil + case openAIWSIngressTurnAbortDispositionFailRequest: + sessionLease.MarkBroken() + return finalErr + } } turnRetry = 0 turnPrevRecoveryTried = false @@ -3343,7 +3218,27 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( return errors.New("websocket turn result is nil") } responseID := strings.TrimSpace(result.RequestID) - lastTurnResponseID = responseID + persistLastResponseID := responseID != "" && shouldPersistOpenAIWSLastResponseID(result.TerminalEventType) + logOpenAIWSModeInfo( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d persist_response_id=%v has_function_call_output=%v pending_function_calls=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + result.Duration.Milliseconds(), + persistLastResponseID, + hasFunctionCallOutput, + len(result.PendingFunctionCallIDs), + ) + if persistLastResponseID { + lastTurnResponseID = responseID + } else if responseID != "" && len(result.PendingFunctionCallIDs) > 0 { + // response 未 completed/done(如 incomplete/failed/cancelled),但包含未完成的 function_call。 + // 保留 response_id 以便下一个 turn 的 function_call_output 能够关联。 + lastTurnResponseID = responseID + } else { + clearSessionLastResponseID() + } lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) lastTurnReplayInputExists = currentTurnReplayInputExists @@ -3361,84 +3256,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( 
lastTurnStrictState = nextStrictState } + if stateStore != nil && + expectedPrev != "" && + currentPreviousResponseID == expectedPrev && + (hasFunctionCallOutput || len(pendingExpectedCallIDs) > 0) { + stateStore.DeleteResponsePendingToolCalls(groupID, expectedPrev) + } + if responseID != "" && stateStore != nil { ttl := s.openAIWSResponseStickyTTL() logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) - stateStore.BindResponseConn(responseID, connID, ttl) - } - if stateStore != nil && storeDisabled && sessionHash != "" { - stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL()) + if poolConnID, ok := normalizeOpenAIWSPreferredConnID(connID); ok { + stateStore.BindResponseConn(responseID, poolConnID, ttl) + } + if pendingFunctionCallIDs := openAIWSNormalizeCallIDs(result.PendingFunctionCallIDs); len(pendingFunctionCallIDs) > 0 { + stateStore.BindResponsePendingToolCalls(groupID, responseID, pendingFunctionCallIDs, ttl) + } else { + stateStore.DeleteResponsePendingToolCalls(groupID, responseID) + } + if sessionHash != "" && persistLastResponseID { + stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL()) + } } if connID != "" { preferredConnID = connID } + yieldSessionLease() - nextClientMessage, readErr := readClientMessage() - if readErr != nil { - if isOpenAIWSClientDisconnectError(readErr) { - closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) - logOpenAIWSModeInfo( - "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", - account.ID, - truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), - closeStatus, - truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), - ) - return nil - } - return fmt.Errorf("read client websocket request: %w", readErr) - } - - nextPayload, parseErr := parseClientPayload(nextClientMessage) - if parseErr != nil { 
- return parseErr - } - if nextPayload.promptCacheKey != "" { - // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; - // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 - updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) - baseAcquireReq.Headers = updatedHeaders - } - if nextPayload.previousResponseID != "" { - expectedPrev := strings.TrimSpace(lastTurnResponseID) - chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev - nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) - logOpenAIWSModeInfo( - "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", - account.ID, - turn, - turn+1, - truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), - normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), - truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), - chainedFromLast, - nextPayload.promptCacheKey != "", - storeDisabled, - ) - } - if stateStore != nil && nextPayload.previousResponseID != "" { - if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok { - if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { - logOpenAIWSModeInfo( - "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", - account.ID, - turn, - truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), - truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), - ) - } else { - preferredConnID = stickyConnID - } - } + exit, advanceErr := 
advanceToNextClientTurn(turn, connID) + if advanceErr != nil { + return advanceErr } - currentPayload = nextPayload.payloadRaw - currentOriginalModel = nextPayload.originalModel - currentPayloadBytes = nextPayload.payloadBytes - storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) - if !storeDisabled { - unpinSessionConn(sessionConnID) + if exit { + return nil } turn++ } @@ -3452,7 +3302,7 @@ func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { // 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。 func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( ctx context.Context, - lease *openAIWSConnLease, + lease openAIWSIngressUpstreamLease, decision OpenAIWSProtocolDecision, payload map[string]any, previousResponseID string, @@ -3540,7 +3390,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) } - eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) + eventType, eventResponseID := parseOpenAIWSEventType(message) if eventType == "" { continue } @@ -3595,7 +3445,9 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( if prewarmResponseID != "" && stateStore != nil { ttl := s.openAIWSResponseStickyTTL() logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) - stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) + if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok { + stateStore.BindResponseConn(prewarmResponseID, connID, ttl) + } } logOpenAIWSModeInfo( "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", @@ -3609,111 +3461,6 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( return nil } -func payloadAsJSON(payload map[string]any) string { - return string(payloadAsJSONBytes(payload)) -} - -func 
payloadAsJSONBytes(payload map[string]any) []byte { - if len(payload) == 0 { - return []byte("{}") - } - body, err := json.Marshal(payload) - if err != nil { - return []byte("{}") - } - return body -} - -func isOpenAIWSTerminalEvent(eventType string) bool { - switch strings.TrimSpace(eventType) { - case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": - return true - default: - return false - } -} - -func isOpenAIWSTokenEvent(eventType string) bool { - eventType = strings.TrimSpace(eventType) - if eventType == "" { - return false - } - switch eventType { - case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": - return false - } - if strings.Contains(eventType, ".delta") { - return true - } - if strings.HasPrefix(eventType, "response.output_text") { - return true - } - if strings.HasPrefix(eventType, "response.output") { - return true - } - return eventType == "response.completed" || eventType == "response.done" -} - -func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { - if len(message) == 0 { - return message - } - if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { - return message - } - if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { - return message - } - modelValues := gjson.GetManyBytes(message, "model", "response.model") - replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel - replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel - if !replaceModel && !replaceResponseModel { - return message - } - updated := message - if replaceModel { - if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { - updated = next - } - } - if replaceResponseModel { - if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { - updated = 
next - } - } - return updated -} - -func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { - if usage == nil || len(body) == 0 { - return - } - values := gjson.GetManyBytes( - body, - "usage.input_tokens", - "usage.output_tokens", - "usage.input_tokens_details.cached_tokens", - ) - usage.InputTokens = int(values[0].Int()) - usage.OutputTokens = int(values[1].Int()) - usage.CacheReadInputTokens = int(values[2].Int()) -} - -func getOpenAIGroupIDFromContext(c *gin.Context) int64 { - if c == nil { - return 0 - } - value, exists := c.Get("api_key") - if !exists { - return 0 - } - apiKey, ok := value.(*APIKey) - if !ok || apiKey == nil || apiKey.GroupID == nil { - return 0 - } - return *apiKey.GroupID -} - // SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 // 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( @@ -3792,164 +3539,3 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( } return nil, nil } - -func classifyOpenAIWSAcquireError(err error) string { - if err == nil { - return "acquire_conn" - } - var dialErr *openAIWSDialError - if errors.As(err, &dialErr) { - switch dialErr.StatusCode { - case 426: - return "upgrade_required" - case 401, 403: - return "auth_failed" - case 429: - return "upstream_rate_limited" - } - if dialErr.StatusCode >= 500 { - return "upstream_5xx" - } - return "dial_failed" - } - if errors.Is(err, errOpenAIWSConnQueueFull) { - return "conn_queue_full" - } - if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { - return "preferred_conn_unavailable" - } - if errors.Is(err, context.DeadlineExceeded) { - return "acquire_timeout" - } - return "acquire_conn" -} - -func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { - code := strings.ToLower(strings.TrimSpace(codeRaw)) - errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) - msg := strings.ToLower(strings.TrimSpace(msgRaw)) - - switch code { - 
case "upgrade_required": - return "upgrade_required", true - case "websocket_not_supported", "websocket_unsupported": - return "ws_unsupported", true - case "websocket_connection_limit_reached": - return "ws_connection_limit_reached", true - case "previous_response_not_found": - return "previous_response_not_found", true - } - if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { - return "upgrade_required", true - } - if strings.Contains(errType, "upgrade") { - return "upgrade_required", true - } - if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { - return "ws_unsupported", true - } - if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { - return "ws_connection_limit_reached", true - } - if strings.Contains(msg, "previous_response_not_found") || - (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { - return "previous_response_not_found", true - } - if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { - return "upstream_error_event", true - } - return "event_error", false -} - -func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { - if len(message) == 0 { - return "event_error", false - } - return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) -} - -func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { - code := strings.ToLower(strings.TrimSpace(codeRaw)) - errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) - switch { - case strings.Contains(errType, "invalid_request"), - strings.Contains(code, "invalid_request"), - strings.Contains(code, "bad_request"), - code == "previous_response_not_found": - return http.StatusBadRequest - case strings.Contains(errType, "authentication"), - strings.Contains(code, "invalid_api_key"), - strings.Contains(code, "unauthorized"): - return http.StatusUnauthorized - case strings.Contains(errType, "permission"), - 
strings.Contains(code, "forbidden"): - return http.StatusForbidden - case strings.Contains(errType, "rate_limit"), - strings.Contains(code, "rate_limit"), - strings.Contains(code, "insufficient_quota"): - return http.StatusTooManyRequests - default: - return http.StatusBadGateway - } -} - -func openAIWSErrorHTTPStatus(message []byte) int { - if len(message) == 0 { - return http.StatusBadGateway - } - codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) - return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) -} - -func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { - if s == nil || s.cfg == nil { - return 30 * time.Second - } - seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds - if seconds <= 0 { - return 0 - } - return time.Duration(seconds) * time.Second -} - -func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { - if s == nil || accountID <= 0 { - return false - } - cooldown := s.openAIWSFallbackCooldown() - if cooldown <= 0 { - return false - } - rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) - if !ok || rawUntil == nil { - return false - } - until, ok := rawUntil.(time.Time) - if !ok || until.IsZero() { - s.openaiWSFallbackUntil.Delete(accountID) - return false - } - if time.Now().Before(until) { - return true - } - s.openaiWSFallbackUntil.Delete(accountID) - return false -} - -func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { - if s == nil || accountID <= 0 { - return - } - cooldown := s.openAIWSFallbackCooldown() - if cooldown <= 0 { - return - } - s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) -} - -func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { - if s == nil || accountID <= 0 { - return - } - s.openaiWSFallbackUntil.Delete(accountID) -} diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go index 
bd03ab5a6..0bc2114fc 100644 --- a/backend/internal/service/openai_ws_forwarder_benchmark_test.go +++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go @@ -1,8 +1,10 @@ package service import ( + "context" "fmt" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" ) @@ -125,3 +127,146 @@ func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) { benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") } } + +// --- Optimization benchmarks --- + +var benchmarkOpenAIWSConnSink openAIWSClientConn + +func BenchmarkTouchLease_Full(b *testing.B) { + ctx := &openAIWSIngressContext{} + ttl := 10 * time.Minute + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx.touchLease(time.Now(), ttl) + } +} + +func BenchmarkMaybeTouchLease_Throttled(b *testing.B) { + ctx := &openAIWSIngressContext{} + ttl := 10 * time.Minute + ctx.touchLease(time.Now(), ttl) // seed the initial touch + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx.maybeTouchLease(ttl) + } +} + +func BenchmarkActiveConn_CachedPath(b *testing.B) { + conn := &benchmarkOpenAIWSNoopConn{} + ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn} + lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"} + // Prime the cache + _, _ = lease.activeConn() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSConnSink, _ = lease.activeConn() + } +} + +func BenchmarkActiveConn_MutexPath(b *testing.B) { + conn := &benchmarkOpenAIWSNoopConn{} + ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn} + lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + lease.cachedConn = nil // force mutex path each iteration + benchmarkOpenAIWSConnSink, _ = lease.activeConn() + } +} + +func BenchmarkParseOpenAIWSEventType_Lightweight(b *testing.B) { + event := 
[]byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + et, rid := parseOpenAIWSEventType(event) + benchmarkOpenAIWSStringSink = et + benchmarkOpenAIWSStringSink = rid + } +} + +func BenchmarkParseOpenAIWSEventEnvelope_Full(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + et, rid, resp := parseOpenAIWSEventEnvelope(event) + benchmarkOpenAIWSStringSink = et + benchmarkOpenAIWSStringSink = rid + benchmarkOpenAIWSBoolSink = resp.Exists() + } +} + +func BenchmarkSessionTurnStateKey_Strconv(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = openAIWSSessionTurnStateKey(int64(i%1000+1), "session_hash_bench") + } +} + +func BenchmarkResponseAccountCacheKey_XXHash(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = openAIWSResponseAccountCacheKey(fmt.Sprintf("resp_bench_%d", i%1000)) + } +} + +func BenchmarkIsOpenAIWSTerminalEvent(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBoolSink = isOpenAIWSTerminalEvent("response.completed") + } +} + +func BenchmarkIsOpenAIWSTokenEvent(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBoolSink = isOpenAIWSTokenEvent("response.output_text.delta") + } +} + +func BenchmarkStateStore_ShardedBindGet(b *testing.B) { + storeAny := NewOpenAIWSStateStore(nil) + store, ok := storeAny.(*defaultOpenAIWSStateStore) + if !ok { + b.Fatal("expected *defaultOpenAIWSStateStore") + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + key := fmt.Sprintf("resp_%d", i%1000) + store.BindResponseConn(key, "conn_bench", time.Minute) + 
benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = store.GetResponseConn(key) + } +} + +func BenchmarkDeriveOpenAISessionHash(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = deriveOpenAISessionHash("session-id-benchmark-value") + } +} + +func BenchmarkDeriveOpenAILegacySessionHash(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSStringSink = deriveOpenAILegacySessionHash("session-id-benchmark-value") + } +} + +type benchmarkOpenAIWSNoopConn struct{} + +func (c *benchmarkOpenAIWSNoopConn) WriteJSON(_ context.Context, _ any) error { return nil } +func (c *benchmarkOpenAIWSNoopConn) ReadMessage(_ context.Context) ([]byte, error) { return nil, nil } +func (c *benchmarkOpenAIWSNoopConn) Ping(_ context.Context) error { return nil } +func (c *benchmarkOpenAIWSNoopConn) Close() error { return nil } diff --git a/backend/internal/service/openai_ws_forwarder_headers_test.go b/backend/internal/service/openai_ws_forwarder_headers_test.go new file mode 100644 index 000000000..30b7216fe --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_headers_test.go @@ -0,0 +1,119 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestBuildOpenAIWSHeaders_OAuthNormalization(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "chatgpt_account_id": "chatgpt_acc_1", + }, + } + + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/openai/v1/responses", nil) + c.Request.Header.Set("accept-language", "zh-CN") + c.Request.Header.Set("User-Agent", "custom-client/1.0") + c.Request.Header.Set("session_id", "sess_hdr") + 
c.Request.Header.Set("conversation_id", "conv_hdr") + + headers, resolution := svc.buildOpenAIWSHeaders( + c, + account, + "test_token", + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + false, + "turn_state_1", + "turn_meta_1", + "prompt_cache_fallback", + ) + + require.Equal(t, "Bearer test_token", headers.Get("authorization")) + require.Equal(t, "zh-CN", headers.Get("accept-language")) + require.Equal(t, "sess_hdr", headers.Get("session_id")) + require.Equal(t, "conv_hdr", headers.Get("conversation_id")) + require.Equal(t, "turn_state_1", headers.Get(openAIWSTurnStateHeader)) + require.Equal(t, "turn_meta_1", headers.Get(openAIWSTurnMetadataHeader)) + require.Equal(t, "chatgpt_acc_1", headers.Get("chatgpt-account-id")) + require.Equal(t, openAIWSBetaV2Value, headers.Get("OpenAI-Beta")) + require.Equal(t, codexCLIUserAgent, headers.Get("user-agent")) + require.Equal(t, "codex_cli_rs", headers.Get("originator")) + + require.Equal(t, "sess_hdr", resolution.SessionID) + require.Equal(t, "conv_hdr", resolution.ConversationID) + require.Equal(t, "header_session_id", resolution.SessionSource) + require.Equal(t, "header_conversation_id", resolution.ConversationSource) +} + +func TestBuildOpenAIWSHeaders_APIKeyForceCodexAndV1(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Gateway.ForceCodexCLI = true + svc := &OpenAIGatewayService{cfg: cfg} + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Credentials: map[string]any{ + "user_agent": "my-custom-ua/1.0", + }, + } + + headers, resolution := svc.buildOpenAIWSHeaders( + nil, + account, + "token_apikey", + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocket}, + false, + "", + "", + "pcache_1", + ) + + require.Equal(t, "Bearer token_apikey", headers.Get("authorization")) + require.Equal(t, openAIWSBetaV1Value, headers.Get("OpenAI-Beta")) + require.Equal(t, codexCLIUserAgent, headers.Get("user-agent")) + 
require.Equal(t, "", headers.Get("originator")) + require.Equal(t, "pcache_1", headers.Get("session_id")) + require.Equal(t, "pcache_1", resolution.SessionID) + require.Equal(t, "prompt_cache_key", resolution.SessionSource) +} + +func TestBuildOpenAIWSHeaders_OAuthCodexCLIInputKeepsOriginator(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", codexCLIUserAgent) + + headers, _ := svc.buildOpenAIWSHeaders( + c, + account, + "token_oauth", + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + true, + "", + "", + "", + ) + + require.Equal(t, codexCLIUserAgent, headers.Get("user-agent")) + require.Equal(t, "codex_cli_rs", headers.Get("originator")) +} diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go index 761676038..807e97255 100644 --- a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -2,9 +2,12 @@ package service import ( "net/http" + "net/http/httptest" "testing" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestParseOpenAIWSEventEnvelope(t *testing.T) { @@ -31,6 +34,15 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { require.Equal(t, 3, usage.CacheReadInputTokens) } +func TestOpenAIWSEventShouldParseUsage_TerminalEvents(t *testing.T) { + require.True(t, openAIWSEventShouldParseUsage("response.completed")) + require.True(t, openAIWSEventShouldParseUsage("response.done")) + require.True(t, openAIWSEventShouldParseUsage("response.failed")) + // After 
removing TrimSpace, callers must provide pre-trimmed input. + require.False(t, openAIWSEventShouldParseUsage(" response.done ")) + require.False(t, openAIWSEventShouldParseUsage("response.in_progress")) +} + func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) @@ -53,11 +65,62 @@ func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { } func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) { + require.False(t, openAIWSMessageLikelyContainsToolCalls(nil)) require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`))) require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"tool_call"}}`))) require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`))) } +func TestOpenAIWSExtractPendingFunctionCallIDsFromEvent(t *testing.T) { + callIDs := openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{ + "type":"response.output_item.added", + "response":{"id":"resp_1"}, + "item":{"type":"function_call","call_id":"call_a"} + }`)) + require.Equal(t, []string{"call_a"}, callIDs) + + callIDs = openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{ + "type":"response.completed", + "response":{ + "id":"resp_2", + "output":[ + {"type":"function_call","call_id":"call_b"}, + {"type":"message","content":[{"type":"output_text","text":"ok"}]}, + {"type":"function_call","call_id":"call_c"} + ] + } + }`)) + require.Equal(t, []string{"call_b", "call_c"}, callIDs) +} + +func 
TestOpenAIWSExtractFunctionCallOutputCallIDsFromPayload(t *testing.T) { + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload([]byte(`{ + "input":[ + {"type":"input_text","text":"hi"}, + {"type":"function_call_output","call_id":"call_2","output":"ok"}, + {"type":"function_call_output","call_id":"call_1","output":"ok"}, + {"type":"function_call_output","call_id":"call_2","output":"dup"} + ] + }`)) + require.Equal(t, []string{"call_1", "call_2"}, callIDs) +} + +func TestOpenAIWSInjectFunctionCallOutputItems(t *testing.T) { + updatedPayload, injected, err := openAIWSInjectFunctionCallOutputItems( + []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`), + []string{"call_1", "call_2", "call_1"}, + openAIWSAutoAbortedToolOutputValue, + ) + require.NoError(t, err) + require.Equal(t, 2, injected) + require.Equal(t, "input_text", gjson.GetBytes(updatedPayload, "input.0.type").String()) + require.Equal(t, "function_call_output", gjson.GetBytes(updatedPayload, "input.1.type").String()) + require.Equal(t, "call_1", gjson.GetBytes(updatedPayload, "input.1.call_id").String()) + require.Equal(t, openAIWSAutoAbortedToolOutputValue, gjson.GetBytes(updatedPayload, "input.1.output").String()) + require.Equal(t, "call_2", gjson.GetBytes(updatedPayload, "input.2.call_id").String()) +} + func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) { noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`) require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model"))) @@ -71,3 +134,45 @@ func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) { both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`) require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model"))) } + +func TestResolveOpenAIWSFirstMessageMeta_ContextPreferred(t *testing.T) { + 
gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + SetOpenAIWSFirstMessageMeta(c, "ctx_model", "resp_ctx", OpenAIPreviousResponseIDKindResponseID) + + model, prevID, prevKind := ResolveOpenAIWSFirstMessageMeta( + c, + []byte(`{"model":"payload_model","previous_response_id":"resp_payload"}`), + ) + require.Equal(t, "ctx_model", model) + require.Equal(t, "resp_ctx", prevID) + require.Equal(t, OpenAIPreviousResponseIDKindResponseID, prevKind) + + require.NotPanics(t, func() { + SetOpenAIWSFirstMessageMeta(nil, "m", "resp_x", OpenAIPreviousResponseIDKindResponseID) + }) +} + +func TestResolveOpenAIWSFirstMessageMeta_FallbackParse(t *testing.T) { + model, prevID, prevKind := ResolveOpenAIWSFirstMessageMeta( + nil, + []byte(`{"model":"payload_model","previous_response_id":"resp_payload"}`), + ) + require.Equal(t, "payload_model", model) + require.Equal(t, "resp_payload", prevID) + require.Equal(t, OpenAIPreviousResponseIDKindResponseID, prevKind) +} + +func TestOpenAIWSHostPathForLogFromURL(t *testing.T) { + host, path := openAIWSHostPathForLogFromURL("wss://api.openai.com/v1/responses?stream=true") + require.Equal(t, "api.openai.com", host) + require.Equal(t, "/v1/responses", path) + + host, path = openAIWSHostPathForLogFromURL("api.openai.com/v1/responses") + require.Equal(t, "api.openai.com", host) + require.Equal(t, "/v1/responses", path) + + host, path = openAIWSHostPathForLogFromURL(" ") + require.Equal(t, "-", host) + require.Equal(t, "-", path) +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go new file mode 100644 index 000000000..149c5765b --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go @@ -0,0 +1,384 @@ +package service + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + 
coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func runIngressProxyWithFirstPayload( + t *testing.T, + svc *OpenAIGatewayService, + account *Account, + firstPayload string, +) error { + t.Helper() + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, message, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", msgType, message, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(firstPayload)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + return serverErr + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + return nil + } +} + +func 
buildIngressPolicyTestConfig() *config.Config { + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool + return cfg +} + +func buildIngressPolicyTestService(cfg *config.Config) *OpenAIGatewayService { + return &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } +} + +func buildIngressPolicyTestAccount(extra map[string]any) *Account { + return &Account{ + ID: 442, + Name: "openai-ingress-policy", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: extra, + } +} + +type openAIWSPassthroughProbeStateStore struct { + mu sync.Mutex + calls []string +} + +func newOpenAIWSPassthroughProbeStateStore() *openAIWSPassthroughProbeStateStore { + return &openAIWSPassthroughProbeStateStore{ + calls: make([]string, 0, 4), + } +} + +func (s *openAIWSPassthroughProbeStateStore) record(method string) { + s.mu.Lock() + s.calls = append(s.calls, method) + s.mu.Unlock() +} + +func (s *openAIWSPassthroughProbeStateStore) unexpectedErr(method string) error { + s.record(method) + return errors.New("passthrough must not call OpenAIWSStateStore." 
+ method) +} + +func (s *openAIWSPassthroughProbeStateStore) Calls() []string { + s.mu.Lock() + defer s.mu.Unlock() + out := make([]string, len(s.calls)) + copy(out, s.calls) + return out +} + +func (s *openAIWSPassthroughProbeStateStore) BindResponseAccount(context.Context, int64, string, int64, time.Duration) error { + return s.unexpectedErr("BindResponseAccount") +} + +func (s *openAIWSPassthroughProbeStateStore) GetResponseAccount(context.Context, int64, string) (int64, error) { + return 0, s.unexpectedErr("GetResponseAccount") +} + +func (s *openAIWSPassthroughProbeStateStore) DeleteResponseAccount(context.Context, int64, string) error { + return s.unexpectedErr("DeleteResponseAccount") +} + +func (s *openAIWSPassthroughProbeStateStore) BindResponseConn(string, string, time.Duration) { + s.record("BindResponseConn") +} + +func (s *openAIWSPassthroughProbeStateStore) GetResponseConn(string) (string, bool) { + s.record("GetResponseConn") + return "", false +} + +func (s *openAIWSPassthroughProbeStateStore) DeleteResponseConn(string) { + s.record("DeleteResponseConn") +} + +func (s *openAIWSPassthroughProbeStateStore) BindResponsePendingToolCalls(int64, string, []string, time.Duration) { + s.record("BindResponsePendingToolCalls") +} + +func (s *openAIWSPassthroughProbeStateStore) GetResponsePendingToolCalls(int64, string) ([]string, bool) { + s.record("GetResponsePendingToolCalls") + return nil, false +} + +func (s *openAIWSPassthroughProbeStateStore) DeleteResponsePendingToolCalls(int64, string) { + s.record("DeleteResponsePendingToolCalls") +} + +func (s *openAIWSPassthroughProbeStateStore) BindSessionTurnState(int64, string, string, time.Duration) { + s.record("BindSessionTurnState") +} + +func (s *openAIWSPassthroughProbeStateStore) GetSessionTurnState(int64, string) (string, bool) { + s.record("GetSessionTurnState") + return "", false +} + +func (s *openAIWSPassthroughProbeStateStore) DeleteSessionTurnState(int64, string) { + 
s.record("DeleteSessionTurnState") +} + +func (s *openAIWSPassthroughProbeStateStore) BindSessionLastResponseID(int64, string, string, time.Duration) { + s.record("BindSessionLastResponseID") +} + +func (s *openAIWSPassthroughProbeStateStore) GetSessionLastResponseID(int64, string) (string, bool) { + s.record("GetSessionLastResponseID") + return "", false +} + +func (s *openAIWSPassthroughProbeStateStore) DeleteSessionLastResponseID(int64, string) { + s.record("DeleteSessionLastResponseID") +} + +func (s *openAIWSPassthroughProbeStateStore) BindSessionConn(int64, string, string, time.Duration) { + s.record("BindSessionConn") +} + +func (s *openAIWSPassthroughProbeStateStore) GetSessionConn(int64, string) (string, bool) { + s.record("GetSessionConn") + return "", false +} + +func (s *openAIWSPassthroughProbeStateStore) DeleteSessionConn(int64, string) { + s.record("DeleteSessionConn") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }) + + serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeRouterDisabledReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = false + svc := buildIngressPolicyTestService(cfg) + account := 
buildIngressPolicyTestAccount(map[string]any{ + "responses_websockets_v2_enabled": true, + }) + + serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode requires mode_router_v2 with ctx_pool/passthrough", closeErr.Reason()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_CtxPoolRejectsMessageIDPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, + }) + + serverErr := runIngressProxyWithFirstPayload( + t, + svc, + account, + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + ) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughDoesNotRejectMessageIDPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + dialer := &openAIWSAlwaysFailDialer{} + svc.openaiWSPassthroughDialer = dialer + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + serverErr := runIngressProxyWithFirstPayload( + t, + svc, + account, + 
`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + ) + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "openai ws passthrough dial") + require.NotContains(t, serverErr.Error(), "previous_response_id must be a response.id") + require.Equal(t, 1, dialer.DialCount()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughEmptyModelFailsBeforeDial(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + dialer := &openAIWSAlwaysFailDialer{} + svc.openaiWSPassthroughDialer = dialer + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + serverErr := runIngressProxyWithFirstPayload( + t, + svc, + account, + `{"type":"response.create","stream":false,"input":[]}`, + ) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "model is required") + require.Equal(t, 0, dialer.DialCount()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughFunctionCallOutputNoRecoveryReject(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + dialer := &openAIWSAlwaysFailDialer{} + svc.openaiWSPassthroughDialer = dialer + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + serverErr := runIngressProxyWithFirstPayload( + t, + svc, + account, + `{"type":"response.create","model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`, + ) + 
require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "openai ws passthrough dial") + require.NotContains(t, serverErr.Error(), "tool_output_not_found") + require.NotContains(t, serverErr.Error(), "previous_response_not_found") + require.Equal(t, 1, dialer.DialCount()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughDoesNotTouchStateStore(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + dialer := &openAIWSAlwaysFailDialer{} + svc.openaiWSPassthroughDialer = dialer + storeProbe := newOpenAIWSPassthroughProbeStateStore() + svc.openaiWSStateStore = storeProbe + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + serverErr := runIngressProxyWithFirstPayload( + t, + svc, + account, + `{"type":"response.create","model":"gpt-5.1","stream":false}`, + ) + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "openai ws passthrough dial") + require.Equal(t, 1, dialer.DialCount()) + require.Empty(t, storeProbe.Calls(), "passthrough 路径不应访问 StateStore") +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go deleted file mode 100644 index 5a3c12c39..000000000 --- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go +++ /dev/null @@ -1,2483 +0,0 @@ -package service - -import ( - "context" - "errors" - "io" - "net/http" - "net/http/httptest" - "strings" - "sync" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - coderws "github.com/coder/websocket" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/require" - "github.com/tidwall/gjson" -) - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) { - 
gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 114, - Name: "openai-ingress-session-lease", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - turnWSModeCh := make(chan bool, 2) - hooks := &OpenAIWSIngressHooks{ - AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { - if turnErr == nil && result != nil { - turnWSModeCh <- result.OpenAIWSMode - } - }, - } - 
wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) - firstTurnEvent := readMessage() - require.Equal(t, "response.completed", 
gjson.GetBytes(firstTurnEvent, "type").String()) - require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`) - secondTurnEvent := readMessage() - require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String()) - require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String()) - require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") - require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - metrics := svc.SnapshotOpenAIWSPoolMetrics() - require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease") - require.Equal(t, 1, captureDialer.DialCount(), "同一 ingress 会话应保持同一上游连接") - require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - 
cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - upstreamConn1 := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - upstreamConn2 := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{upstreamConn1, upstreamConn2}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 441, - Name: "openai-ingress-dedicated", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, - }, - } - - serverErrCh := make(chan error, 2) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr 
- return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - runSingleTurnSession := func(expectedResponseID string) { - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) - cancelWrite() - require.NoError(t, err) - - readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) - msgType, event, readErr := clientConn.Read(readCtx) - cancelRead() - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - } - - runSingleTurnSession("resp_dedicated_1") - runSingleTurnSession("resp_dedicated_2") - - require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - 
cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: newOpenAIWSConnPool(cfg), - } - - account := &Account{ - ID: 442, - Name: "openai-ingress-off", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - 
clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) - cancelWrite() - require.NoError(t, err) - - select { - case serverErr := <-serverErrCh: - var closeErr *OpenAIWSClientCloseError - require.ErrorAs(t, serverErr, &closeErr) - require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) - require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason()) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - captureDialer := 
&openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 140, - Name: "openai-ingress-prev-preflight-rewrite", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, 
err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - firstTurn := readMessage() - require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create") - require.Len(t, captureConn.writes, 2) - secondWrite := requestToJSONString(captureConn.writes[1]) - require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create") - require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文") - require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) - require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) -} - -func 
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) { - gin.SetMode(gin.TestMode) - prevPreflightPingIdle := openAIWSIngressPreflightPingIdle - openAIWSIngressPreflightPingIdle = 0 - defer func() { - openAIWSIngressPreflightPingIdle = prevPreflightPingIdle - }() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSPreflightFailConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 142, - Name: "openai-ingress-prev-strict-drop-before-ping", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - 
Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, 
coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 严格降级后预检换连超时") - } - - require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连") - require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮") - require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") - secondConn.mu.Lock() - secondWrites := append([]map[string]any(nil), secondConn.writes...) 
- secondConn.mu.Unlock() - require.Len(t, secondWrites, 1) - secondWrite := requestToJSONString(secondWrites[0]) - require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id") - require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array())) - require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) - require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 
143, - Name: "openai-ingress-store-enabled-skip-strict", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 
3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`) - firstTurn := readMessage() - require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`) - secondTurn := readMessage() - require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 store=true 场景 websocket 结束超时") - } - - require.Equal(t, 1, captureDialer.DialCount()) - require.Len(t, captureConn.writes, 2) - require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := 
&openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 141, - Name: "openai-ingress-prev-preflight-skip-fco", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - 
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`) - firstTurn := readMessage() - require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - require.Equal(t, 1, captureDialer.DialCount()) - require.Len(t, captureConn.writes, 2) - require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 
previous_response_id") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 143, - Name: "openai-ingress-fco-auto-prev", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - firstTurn := readMessage() - require.Equal(t, "resp_auto_prev_1", 
gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - require.Equal(t, 1, captureDialer.DialCount()) - require.Len(t, captureConn.writes, 2) - require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - 
}, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 144, - Name: "openai-ingress-fco-auto-prev-skip", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - 
cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - firstTurn := readMessage() - require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String()) - require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点") - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - require.Equal(t, 1, captureDialer.DialCount()) - require.Len(t, captureConn.writes, 2) - require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { - gin.SetMode(gin.TestMode) - prevPreflightPingIdle := openAIWSIngressPreflightPingIdle - openAIWSIngressPreflightPingIdle 
= 0 - defer func() { - openAIWSIngressPreflightPingIdle = prevPreflightPingIdle - }() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSPreflightFailConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 116, - Name: "openai-ingress-preflight-ping", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String()) - - 
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`) - secondTurn := readMessage() - require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连") - require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn") - require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) { - gin.SetMode(gin.TestMode) - prevPreflightPingIdle := openAIWSIngressPreflightPingIdle - openAIWSIngressPreflightPingIdle = 0 - defer func() { - openAIWSIngressPreflightPingIdle = prevPreflightPingIdle - }() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSPreflightFailConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - 
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 121, - Name: "openai-ingress-preflight-ping-strict-affinity", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - 
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时") - } - - require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放") - require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮") - require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") - secondConn.mu.Lock() - secondWrites := append([]map[string]any(nil), secondConn.writes...) 
- secondConn.mu.Unlock() - require.Len(t, secondWrites, 1) - secondWrite := requestToJSONString(secondWrites[0]) - require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id") - require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文") - require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) - require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSWriteFailAfterFirstTurnConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: 
NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 117, - Name: "openai-ingress-write-retry", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - var hooksMu sync.Mutex - beforeTurnCalls := make(map[int]int) - afterTurnCalls := make(map[int]int) - hooks := &OpenAIWSIngressHooks{ - BeforeTurn: func(turn int) error { - hooksMu.Lock() - beforeTurnCalls[turn]++ - hooksMu.Unlock() - return nil - }, - AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) { - hooksMu.Lock() - afterTurnCalls[turn]++ - hooksMu.Unlock() - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, 
_, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`) - secondTurn := readMessage() - require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连") - hooksMu.Lock() - beforeTurn1 := beforeTurnCalls[1] - beforeTurn2 := beforeTurnCalls[2] - afterTurn1 := afterTurnCalls[1] - afterTurn2 := afterTurnCalls[2] - hooksMu.Unlock() - require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次") - require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn") - require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次") - require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次") -} - -func 
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 118, - Name: "openai-ingress-prev-recovery", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, 
- Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return 
message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`) - secondTurn := readMessage() - require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String()) - require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试") - - firstConn.mu.Lock() - firstWrites := append([]map[string]any(nil), firstConn.writes...) - firstConn.mu.Unlock() - require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求") - require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id") - - secondConn.mu.Lock() - secondWrites := append([]map[string]any(nil), secondConn.writes...) 
- secondConn.mu.Unlock() - require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求") - require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: 
NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 122, - Name: "openai-ingress-prev-strict-layer2", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, 
coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String()) - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`) - secondTurn := readMessage() - require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时") - } - - require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试") - - firstConn.mu.Lock() - firstWrites := append([]map[string]any(nil), firstConn.writes...) - firstConn.mu.Unlock() - require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求") - require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists()) - - secondConn.mu.Lock() - secondWrites := append([]map[string]any(nil), secondConn.writes...) 
- secondConn.mu.Unlock() - require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次") - secondWrite := requestToJSONString(secondWrites[0]) - require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id") - require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志") - require.False(t, gjson.Get(secondWrite, "store").Bool()) - require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文") - require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) - require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - firstConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`), - }, - } - secondConn := &openAIWSCaptureConn{ - events: [][]byte{ - 
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - dialer := &openAIWSQueueDialer{ - conns: []openAIWSClientConn{firstConn, secondConn}, - } - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(dialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 120, - Name: "openai-ingress-prev-recovery-once", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial 
:= context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeMessage := func(payload string) { - writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) - } - readMessage := func() []byte { - readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() - msgType, message, readErr := clientConn.Read(readCtx) - require.NoError(t, readErr) - require.Equal(t, coderws.MessageText, msgType) - return message - } - - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) - firstTurn := readMessage() - require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String()) - - // duplicate previous_response_id: 恢复重试时应删除所有重复键,避免再次 previous_response_not_found。 - writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`) - secondTurn := readMessage() - require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String()) - - require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr) - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次") - - firstConn.mu.Lock() - firstWrites := append([]map[string]any(nil), firstConn.writes...) 
- firstConn.mu.Unlock() - require.Len(t, firstWrites, 2) - require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists()) - - secondConn.mu.Lock() - secondWrites := append([]map[string]any(nil), secondConn.writes...) - secondConn.mu.Unlock() - require.Len(t, secondWrites, 1) - require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id") -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - - account := &Account{ - ID: 119, - Name: "openai-ingress-prev-validation", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") 
- ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - defer func() { - _ = clientConn.CloseNow() - }() - - writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`)) - cancelWrite() - require.NoError(t, err) - - select { - case serverErr := <-serverErrCh: - require.Error(t, serverErr) - var closeErr *OpenAIWSClientCloseError - require.ErrorAs(t, serverErr, &closeErr) - require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) - require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id") - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } -} - -type openAIWSQueueDialer struct { - mu sync.Mutex - conns []openAIWSClientConn - dialCount int -} - -func (d *openAIWSQueueDialer) Dial( - ctx context.Context, - wsURL string, - headers http.Header, - proxyURL string, -) (openAIWSClientConn, int, http.Header, error) { - _ = ctx - _ = wsURL - _ = headers - _ = proxyURL - d.mu.Lock() - defer d.mu.Unlock() - d.dialCount++ - if len(d.conns) == 0 { - return nil, 503, nil, errors.New("no test conn") - } - conn := d.conns[0] - if len(d.conns) 
> 1 { - d.conns = d.conns[1:] - } - return conn, 0, nil, nil -} - -func (d *openAIWSQueueDialer) DialCount() int { - d.mu.Lock() - defer d.mu.Unlock() - return d.dialCount -} - -type openAIWSPreflightFailConn struct { - mu sync.Mutex - events [][]byte - pingFails bool - writeCount int - pingCount int -} - -func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error { - c.mu.Lock() - c.writeCount++ - c.mu.Unlock() - return nil -} - -func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) { - c.mu.Lock() - defer c.mu.Unlock() - if len(c.events) == 0 { - return nil, io.EOF - } - event := c.events[0] - c.events = c.events[1:] - if len(c.events) == 0 { - c.pingFails = true - } - return event, nil -} - -func (c *openAIWSPreflightFailConn) Ping(context.Context) error { - c.mu.Lock() - defer c.mu.Unlock() - c.pingCount++ - if c.pingFails { - return errors.New("preflight ping failed") - } - return nil -} - -func (c *openAIWSPreflightFailConn) Close() error { - return nil -} - -func (c *openAIWSPreflightFailConn) WriteCount() int { - c.mu.Lock() - defer c.mu.Unlock() - return c.writeCount -} - -func (c *openAIWSPreflightFailConn) PingCount() int { - c.mu.Lock() - defer c.mu.Unlock() - return c.pingCount -} - -type openAIWSWriteFailAfterFirstTurnConn struct { - mu sync.Mutex - events [][]byte - failOnWrite bool -} - -func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error { - c.mu.Lock() - defer c.mu.Unlock() - if c.failOnWrite { - return errors.New("write failed on stale conn") - } - return nil -} - -func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) { - c.mu.Lock() - defer c.mu.Unlock() - if len(c.events) == 0 { - return nil, io.EOF - } - event := c.events[0] - c.events = c.events[1:] - if len(c.events) == 0 { - c.failOnWrite = true - } - return event, nil -} - -func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error { - return nil -} - -func (c 
*openAIWSWriteFailAfterFirstTurnConn) Close() error { - return nil -} - -func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - // 多个上游事件:前几个为非 terminal 事件,最后一个为 terminal。 - // 第一个事件延迟 250ms 让客户端 RST 有时间传播,使 writeClientMessage 可靠失败。 - captureConn := &openAIWSCaptureConn{ - readDelays: []time.Duration{250 * time.Millisecond, 0, 0}, - events: [][]byte{ - []byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`), - []byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 115, - Name: "openai-ingress-client-disconnect", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 
1, - Credentials: map[string]any{ - "api_key": "sk-test", - "model_mapping": map[string]any{ - "custom-original-model": "gpt-5.1", - }, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - serverErrCh := make(chan error, 1) - resultCh := make(chan *OpenAIForwardResult, 1) - hooks := &OpenAIWSIngressHooks{ - AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { - if turnErr == nil && result != nil { - resultCh <- result - } - }, - } - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ - CompressionMode: coderws.CompressionContextTakeover, - }) - if err != nil { - serverErrCh <- err - return - } - defer func() { - _ = conn.CloseNow() - }() - - rec := httptest.NewRecorder() - ginCtx, _ := gin.CreateTestContext(rec) - req := r.Clone(r.Context()) - req.Header = req.Header.Clone() - req.Header.Set("User-Agent", "unit-test-agent/1.0") - ginCtx.Request = req - - readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) - msgType, firstMessage, readErr := conn.Read(readCtx) - cancel() - if readErr != nil { - serverErrCh <- readErr - return - } - if msgType != coderws.MessageText && msgType != coderws.MessageBinary { - serverErrCh <- errors.New("unsupported websocket client message type") - return - } - - serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) - })) - defer wsServer.Close() - - dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) - clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) - cancelDial() - require.NoError(t, err) - - writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) - err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) - cancelWrite() - 
require.NoError(t, err) - // 立即关闭客户端,模拟客户端在 relay 期间断连。 - require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连") - - select { - case serverErr := <-serverErrCh: - require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束") - case <-time.After(5 * time.Second): - t.Fatal("等待 ingress websocket 结束超时") - } - - select { - case result := <-resultCh: - require.Equal(t, "resp_ingress_disconnect", result.RequestID) - require.Equal(t, 2, result.Usage.InputTokens) - require.Equal(t, 1, result.Usage.OutputTokens) - case <-time.After(2 * time.Second): - t.Fatal("未收到断连后的 turn 结果回调") - } -} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go index ff35cb01d..3129dae27 100644 --- a/backend/internal/service/openai_ws_forwarder_ingress_test.go +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -6,10 +6,13 @@ import ( "errors" "io" "net" + "strings" "testing" + "time" "github.com/Wei-Shaw/sub2api/internal/config" coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" ) @@ -34,6 +37,8 @@ func TestIsOpenAIWSClientDisconnectError(t *testing.T) { {name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true}, {name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true}, {name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true}, + {name: "blank_message", err: errors.New(" "), want: false}, + {name: "unmatched_message", err: errors.New("tls handshake timeout"), want: false}, } for _, tt := range tests { @@ -45,6 +50,423 @@ func TestIsOpenAIWSClientDisconnectError(t *testing.T) { } } +func TestOpenAIWSIngressFallbackSessionSeedFromContext(t *testing.T) { + t.Parallel() + + 
require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(nil)) + + gin.SetMode(gin.TestMode) + c, _ := gin.CreateTestContext(nil) + require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c)) + + c.Set("api_key", "not_api_key") + require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c)) + + groupID := int64(99) + c.Set("api_key", &APIKey{ + ID: 101, + GroupID: &groupID, + User: &User{ID: 202}, + }) + require.Equal(t, "openai_ws_ingress:99:202:101", openAIWSIngressFallbackSessionSeedFromContext(c)) + + c.Set("api_key", &APIKey{ + ID: 303, + User: nil, + }) + require.Equal(t, "openai_ws_ingress:0:0:303", openAIWSIngressFallbackSessionSeedFromContext(c)) +} + +func TestClassifyOpenAIWSIngressTurnAbortReason(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + wantReason openAIWSIngressTurnAbortReason + wantExpected bool + }{ + { + name: "nil", + err: nil, + wantReason: openAIWSIngressTurnAbortReasonUnknown, + wantExpected: false, + }, + { + name: "context canceled", + err: context.Canceled, + wantReason: openAIWSIngressTurnAbortReasonContextCanceled, + wantExpected: true, + }, + { + name: "context deadline", + err: context.DeadlineExceeded, + wantReason: openAIWSIngressTurnAbortReasonContextDeadline, + wantExpected: false, + }, + { + name: "client close", + err: coderws.CloseError{Code: coderws.StatusNormalClosure}, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + { + name: "client close by eof", + err: io.EOF, + wantReason: openAIWSIngressTurnAbortReasonClientClosed, + wantExpected: true, + }, + { + name: "previous response not found", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonPreviousResponse, + wantExpected: true, + }, + { + name: "tool output not found", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, 
+ errors.New("no tool output found"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonToolOutput, + wantExpected: true, + }, + { + name: "upstream error event", + err: wrapOpenAIWSIngressTurnError( + "upstream_error_event", + errors.New("upstream error event"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamError, + wantExpected: true, + }, + { + name: "write upstream", + err: wrapOpenAIWSIngressTurnError( + "write_upstream", + errors.New("write upstream fail"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonWriteUpstream, + wantExpected: false, + }, + { + name: "read upstream", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + errors.New("read upstream fail"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonReadUpstream, + wantExpected: false, + }, + { + name: "idle timeout stage", + err: wrapOpenAIWSIngressTurnError( + "idle_timeout", + errors.New("relay idle timeout"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonContextDeadline, + wantExpected: false, + }, + { + name: "write client", + err: wrapOpenAIWSIngressTurnError( + "write_client", + errors.New("write client fail"), + true, + ), + wantReason: openAIWSIngressTurnAbortReasonWriteClient, + wantExpected: false, + }, + { + name: "unknown turn stage", + err: wrapOpenAIWSIngressTurnError( + "some_unknown_stage", + errors.New("unknown stage fail"), + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUnknown, + wantExpected: false, + }, + { + name: "continuation unavailable close", + err: NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + openAIWSContinuationUnavailableReason, + nil, + ), + wantReason: openAIWSIngressTurnAbortReasonContinuationUnavailable, + wantExpected: true, + }, + { + name: "upstream restart 1012", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"}, + false, + ), + wantReason: 
openAIWSIngressTurnAbortReasonUpstreamRestart, + wantExpected: true, + }, + { + name: "upstream try again later 1013", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + coderws.CloseError{Code: coderws.StatusTryAgainLater, Reason: "try again later"}, + false, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart, + wantExpected: true, + }, + { + name: "upstream restart 1012 with wroteDownstream", + err: wrapOpenAIWSIngressTurnError( + "read_upstream", + coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"}, + true, + ), + wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart, + wantExpected: true, + }, + { + name: "1012 on non-read_upstream stage should not match", + err: wrapOpenAIWSIngressTurnError( + "write_upstream", + coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"}, + false, + ), + wantReason: openAIWSIngressTurnAbortReasonWriteUpstream, + wantExpected: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err) + require.Equal(t, tt.wantReason, reason) + require.Equal(t, tt.wantExpected, expected) + }) + } +} + +func TestClassifyOpenAIWSIngressTurnAbortReason_ClientDisconnectedDrainTimeout(t *testing.T) { + t.Parallel() + + err := wrapOpenAIWSIngressTurnError( + "client_disconnected_drain_timeout", + openAIWSIngressClientDisconnectedDrainTimeoutError(2*time.Second), + true, + ) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonContextCanceled, reason) + require.True(t, expected) + require.Equal(t, openAIWSIngressTurnAbortDispositionCloseGracefully, openAIWSIngressTurnAbortDispositionForReason(reason)) +} + +func TestOpenAIWSIngressPumpClosedTurnError_ClientDisconnected(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_partial", + Usage: OpenAIUsage{ + 
InputTokens: 12, + }, + } + err := openAIWSIngressPumpClosedTurnError(true, true, partial) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + + var turnErr *openAIWSIngressTurnError + require.ErrorAs(t, err, &turnErr) + require.Equal(t, "client_disconnected_drain_timeout", turnErr.stage) + require.True(t, turnErr.wroteDownstream) + require.NotNil(t, turnErr.partialResult) + require.Equal(t, partial.RequestID, turnErr.partialResult.RequestID) +} + +func TestOpenAIWSIngressPumpClosedTurnError_ReadUpstream(t *testing.T) { + t.Parallel() + + err := openAIWSIngressPumpClosedTurnError(false, false, nil) + require.Error(t, err) + + var turnErr *openAIWSIngressTurnError + require.ErrorAs(t, err, &turnErr) + require.Equal(t, "read_upstream", turnErr.stage) + require.False(t, turnErr.wroteDownstream) + require.Nil(t, turnErr.partialResult) + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonReadUpstream, reason) + require.False(t, expected) +} + +func TestOpenAIWSIngressPumpClosedTurnError_ClonesPartialResult(t *testing.T) { + t.Parallel() + + partial := &OpenAIForwardResult{ + RequestID: "resp_original", + PendingFunctionCallIDs: []string{"call_a"}, + } + err := openAIWSIngressPumpClosedTurnError(true, true, partial) + require.Error(t, err) + + partial.RequestID = "resp_mutated" + partial.PendingFunctionCallIDs[0] = "call_b" + + var turnErr *openAIWSIngressTurnError + require.ErrorAs(t, err, &turnErr) + require.NotNil(t, turnErr.partialResult) + require.Equal(t, "resp_original", turnErr.partialResult.RequestID) + require.Equal(t, []string{"call_a"}, turnErr.partialResult.PendingFunctionCallIDs) +} + +func TestOpenAIWSIngressClientDisconnectedDrainTimeoutError_DefaultTimeout(t *testing.T) { + t.Parallel() + + err := openAIWSIngressClientDisconnectedDrainTimeoutError(0) + require.Error(t, err) + require.Contains(t, err.Error(), openAIWSIngressClientDisconnectDrainTimeout.String()) + 
require.ErrorIs(t, err, context.Canceled) +} + +func TestOpenAIWSIngressResolveDrainReadTimeout(t *testing.T) { + t.Parallel() + + now := time.Now() + tests := []struct { + name string + base time.Duration + deadline time.Time + want time.Duration + wantExpire bool + }{ + { + name: "no_deadline_uses_base", + base: 15 * time.Second, + deadline: time.Time{}, + want: 15 * time.Second, + wantExpire: false, + }, + { + name: "remaining_shorter_than_base", + base: 10 * time.Second, + deadline: now.Add(3 * time.Second), + want: 3 * time.Second, + wantExpire: false, + }, + { + name: "base_shorter_than_remaining", + base: 2 * time.Second, + deadline: now.Add(8 * time.Second), + want: 2 * time.Second, + wantExpire: false, + }, + { + name: "base_zero_uses_remaining", + base: 0, + deadline: now.Add(5 * time.Second), + want: 5 * time.Second, + wantExpire: false, + }, + { + name: "expired_deadline", + base: 10 * time.Second, + deadline: now.Add(-time.Millisecond), + want: 0, + wantExpire: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, expired := openAIWSIngressResolveDrainReadTimeout(tt.base, tt.deadline, now) + require.Equal(t, tt.want, got) + require.Equal(t, tt.wantExpire, expired) + }) + } +} + +func TestOpenAIWSIngressTurnAbortDispositionForReason(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + in openAIWSIngressTurnAbortReason + want openAIWSIngressTurnAbortDisposition + }{ + { + name: "continue turn on previous response mismatch", + in: openAIWSIngressTurnAbortReasonPreviousResponse, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + { + name: "continue turn on tool output mismatch", + in: openAIWSIngressTurnAbortReasonToolOutput, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + { + name: "continue turn on upstream error event", + in: openAIWSIngressTurnAbortReasonUpstreamError, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + { + name: 
"close gracefully on context canceled", + in: openAIWSIngressTurnAbortReasonContextCanceled, + want: openAIWSIngressTurnAbortDispositionCloseGracefully, + }, + { + name: "close gracefully on client closed", + in: openAIWSIngressTurnAbortReasonClientClosed, + want: openAIWSIngressTurnAbortDispositionCloseGracefully, + }, + { + name: "default fail request on unknown reason", + in: openAIWSIngressTurnAbortReasonUnknown, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "default fail request on write upstream reason", + in: openAIWSIngressTurnAbortReasonWriteUpstream, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "default fail request on read upstream reason", + in: openAIWSIngressTurnAbortReasonReadUpstream, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "default fail request on write client reason", + in: openAIWSIngressTurnAbortReasonWriteClient, + want: openAIWSIngressTurnAbortDispositionFailRequest, + }, + { + name: "continue turn on upstream restart", + in: openAIWSIngressTurnAbortReasonUpstreamRestart, + want: openAIWSIngressTurnAbortDispositionContinueTurn, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, openAIWSIngressTurnAbortDispositionForReason(tt.in)) + }) + } +} + func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) { t.Parallel() @@ -254,12 +676,12 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { want: false, }, { - name: "skip_on_first_turn", + name: "infer_on_first_turn_when_expected_previous_exists", storeDisabled: true, turn: 1, hasFunctionCallOutput: true, expectedPrevious: "resp_1", - want: false, + want: true, }, { name: "skip_without_function_call_output", @@ -312,6 +734,145 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { } } +func TestShouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(t 
*testing.T) { + t.Parallel() + + tests := []struct { + name string + storeDisabled bool + hasFunctionCallOutput bool + previousResponseID string + hasToolOutputContext bool + want bool + }{ + { + name: "reject_when_store_disabled_and_missing_prev_without_context", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: "", + hasToolOutputContext: false, + want: true, + }, + { + name: "skip_when_store_enabled", + storeDisabled: false, + hasFunctionCallOutput: true, + previousResponseID: "", + hasToolOutputContext: false, + want: false, + }, + { + name: "skip_when_previous_response_id_exists", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: "resp_1", + hasToolOutputContext: false, + want: false, + }, + { + name: "skip_when_has_tool_output_context", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: "", + hasToolOutputContext: true, + want: false, + }, + { + name: "skip_when_no_function_call_output", + storeDisabled: true, + hasFunctionCallOutput: false, + previousResponseID: "", + hasToolOutputContext: false, + want: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + tt.storeDisabled, + tt.hasFunctionCallOutput, + tt.previousResponseID, + tt.hasToolOutputContext, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestOpenAIWSHasToolOutputContextInPayload(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload []byte + expectedCallIDs []string + wantHasToolCall bool + wantHasItemReference bool + }{ + { + name: "empty_payload", + payload: nil, + expectedCallIDs: []string{"call_1"}, + wantHasToolCall: false, + wantHasItemReference: false, + }, + { + name: "has_tool_call_context", + payload: []byte(`{"input":[{"type":"tool_call","call_id":"call_1"},{"type":"function_call_output","call_id":"call_1"}]}`), + expectedCallIDs: 
[]string{"call_1"}, + wantHasToolCall: true, + wantHasItemReference: false, + }, + { + name: "has_function_call_context", + payload: []byte(`{"input":[{"type":"function_call","call_id":"call_1"},{"type":"function_call_output","call_id":"call_1"}]}`), + expectedCallIDs: []string{"call_1"}, + wantHasToolCall: true, + wantHasItemReference: false, + }, + { + name: "tool_call_without_call_id_is_not_context", + payload: []byte(`{"input":[{"type":"tool_call"},{"type":"function_call_output","call_id":"call_1"}]}`), + expectedCallIDs: []string{"call_1"}, + wantHasToolCall: false, + wantHasItemReference: false, + }, + { + name: "has_item_reference_for_all_function_call_outputs", + payload: []byte(`{"input":[{"type":"item_reference","id":"call_1"},{"type":"item_reference","id":"call_2"},{"type":"function_call_output","call_id":"call_1"},{"type":"function_call_output","call_id":"call_2"}]}`), + expectedCallIDs: []string{"call_1", "call_2"}, + wantHasToolCall: false, + wantHasItemReference: true, + }, + { + name: "missing_item_reference_for_some_call_ids", + payload: []byte(`{"input":[{"type":"item_reference","id":"call_1"},{"type":"function_call_output","call_id":"call_1"},{"type":"function_call_output","call_id":"call_2"}]}`), + expectedCallIDs: []string{"call_1", "call_2"}, + wantHasToolCall: false, + wantHasItemReference: false, + }, + { + name: "ignores_non_array_input", + payload: []byte(`{"input":"bad"}`), + expectedCallIDs: []string{"call_1"}, + wantHasToolCall: false, + wantHasItemReference: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.wantHasToolCall, openAIWSHasToolCallContextInPayload(tt.payload)) + require.Equal(t, tt.wantHasItemReference, openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(tt.payload, tt.expectedCallIDs)) + }) + } +} + func TestOpenAIWSInputIsPrefixExtended(t *testing.T) { t.Parallel() @@ -529,7 +1090,7 @@ func TestShouldKeepIngressPreviousResponseID(t 
*testing.T) { }`) t.Run("strict_incremental_keep", func(t *testing.T) { - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false, nil, nil) require.NoError(t, err) require.True(t, keep) require.Equal(t, "strict_incremental_ok", reason) @@ -537,28 +1098,28 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) { t.Run("missing_previous_response_id", func(t *testing.T) { payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil) require.NoError(t, err) require.False(t, keep) require.Equal(t, "missing_previous_response_id", reason) }) t.Run("missing_last_turn_response_id", func(t *testing.T) { - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false, nil, nil) require.NoError(t, err) require.False(t, keep) require.Equal(t, "missing_last_turn_response_id", reason) }) t.Run("previous_response_id_mismatch", func(t *testing.T) { - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false, nil, nil) require.NoError(t, err) require.False(t, keep) require.Equal(t, "previous_response_id_mismatch", reason) }) t.Run("missing_previous_turn_payload", func(t *testing.T) { - keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false) + keep, reason, err := 
shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false, nil, nil) require.NoError(t, err) require.False(t, keep) require.Equal(t, "missing_previous_turn_payload", reason) @@ -573,7 +1134,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) { "previous_response_id":"resp_turn_1", "input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}] }`) - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil) require.NoError(t, err) require.False(t, keep) require.Equal(t, "non_input_changed", reason) @@ -588,7 +1149,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) { "previous_response_id":"resp_turn_1", "input":[{"type":"input_text","text":"different"}] }`) - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil) require.NoError(t, err) require.True(t, keep) require.Equal(t, "strict_incremental_ok", reason) @@ -602,21 +1163,63 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) { "previous_response_id":"resp_external", "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] }`) - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true, nil, []string{"call_1"}) require.NoError(t, err) require.True(t, keep) require.Equal(t, "has_function_call_output", reason) }) + t.Run("function_call_output_pending_call_id_match_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_turn_1", + 
"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + payload, + "resp_turn_1", + true, + []string{"call_1"}, + []string{"call_1"}, + ) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "function_call_output_call_id_match", reason) + }) + + t.Run("function_call_output_pending_call_id_mismatch_drops_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_turn_1", + "input":[{"type":"function_call_output","call_id":"call_other","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + payload, + "resp_turn_1", + true, + []string{"call_1"}, + []string{"call_other"}, + ) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "function_call_output_call_id_mismatch", reason) + }) + t.Run("non_input_compare_error", func(t *testing.T) { - keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false) + keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false, nil, nil) require.Error(t, err) require.False(t, keep) require.Equal(t, "non_input_compare_error", reason) }) t.Run("current_payload_compare_error", func(t *testing.T) { - keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false, nil, nil) require.Error(t, err) require.False(t, keep) require.Equal(t, "non_input_compare_error", reason) @@ -670,6 +1273,171 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { require.Equal(t, "hello", gjson.GetBytes(items[0], 
"text").String()) require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) }) + + t.Run("replay_input_limited_by_bytes_keeps_newest_items", func(t *testing.T) { + makeItem := func(text string) json.RawMessage { + raw, err := json.Marshal(map[string]any{ + "type": "input_text", + "text": text, + }) + require.NoError(t, err) + return json.RawMessage(raw) + } + largeA := strings.Repeat("a", openAIWSIngressReplayInputMaxBytes/2) + largeB := strings.Repeat("b", openAIWSIngressReplayInputMaxBytes/2) + largeC := strings.Repeat("c", openAIWSIngressReplayInputMaxBytes/2) + previousLarge := []json.RawMessage{ + makeItem(largeA), + makeItem(largeB), + } + currentPayload, err := json.Marshal(map[string]any{ + "previous_response_id": "resp_1", + "input": []map[string]any{ + {"type": "input_text", "text": largeC}, + }, + }) + require.NoError(t, err) + + items, exists, err := buildOpenAIWSReplayInputSequence( + previousLarge, + true, + currentPayload, + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.GreaterOrEqual(t, len(items), 1) + require.Equal(t, largeC, gjson.GetBytes(items[len(items)-1], "text").String(), "latest item should always be preserved") + require.Less(t, len(items), 3, "oversized replay input should be truncated") + }) + + t.Run("replay_input_limited_by_bytes_still_keeps_single_oversized_latest_item", func(t *testing.T) { + tooLargeText := strings.Repeat("z", openAIWSIngressReplayInputMaxBytes+1024) + currentPayload, err := json.Marshal(map[string]any{ + "input": []map[string]any{ + {"type": "input_text", "text": tooLargeText}, + }, + }) + require.NoError(t, err) + + items, exists, err := buildOpenAIWSReplayInputSequence( + nil, + false, + currentPayload, + false, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, tooLargeText, gjson.GetBytes(items[0], "text").String()) + }) +} + +func TestOpenAIWSInputAppearsEditedFromPreviousFullInput(t *testing.T) { + t.Parallel() + + 
makeItems := func(values ...string) []json.RawMessage { + items := make([]json.RawMessage, 0, len(values)) + for _, v := range values { + raw, err := json.Marshal(map[string]any{ + "type": "input_text", + "text": v, + }) + require.NoError(t, err) + items = append(items, json.RawMessage(raw)) + } + return items + } + + previous := makeItems("hello", "world") + + t.Run("skip_when_no_previous_response_id", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`), + false, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_previous_full_input_missing", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + nil, + false, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("error_when_current_payload_invalid", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[}`), + true, + ) + require.Error(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_current_input_missing", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1"}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_previous_len_lt_2", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + makeItems("hello"), + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_current_shorter_than_previous", func(t 
*testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("skip_when_current_has_previous_prefix", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"},{"type":"input_text","text":"new"}]}`), + true, + ) + require.NoError(t, err) + require.False(t, edited) + }) + + t.Run("detect_when_current_is_full_snapshot_edit", func(t *testing.T) { + edited, err := openAIWSInputAppearsEditedFromPreviousFullInput( + previous, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, edited) + }) } func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { @@ -712,3 +1480,147 @@ func TestCloneOpenAIWSRawMessages(t *testing.T) { require.Len(t, cloned, 0) }) } + +// --------------------------------------------------------------------------- +// TestInjectPreviousResponseIDForFunctionCallOutput +// 端到端测试:当客户端发送 function_call_output 但未携带 previous_response_id 时, +// Gateway 应主动注入 lastTurnResponseID,避免上游返回 tool_output_not_found 错误。 +// --------------------------------------------------------------------------- + +func TestInjectPreviousResponseIDForFunctionCallOutput(t *testing.T) { + t.Parallel() + + // 辅助函数:模拟 forwarder 中的注入逻辑 + // 返回 (注入后的 payload, 注入后的 previousResponseID, 是否执行了注入) + simulateInject := func( + storeDisabled bool, + turn int, + payload []byte, + expectedPrev string, + ) ([]byte, string, bool) { + currentPreviousResponseID := "" + prev := gjson.GetBytes(payload, "previous_response_id") + if prev.Exists() { + currentPreviousResponseID = 
strings.TrimSpace(prev.String()) + } + hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + + if shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev, + ) { + injected, err := setPreviousResponseIDToRawPayload(payload, expectedPrev) + if err != nil { + return payload, currentPreviousResponseID, false + } + return injected, expectedPrev, true + } + return payload, currentPreviousResponseID, false + } + + t.Run("inject_when_function_call_output_without_prev_id", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"function_call_output","call_id":"call_abc123","output":"result"}]}`) + updated, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn") + + require.True(t, injected, "应该执行注入") + require.Equal(t, "resp_last_turn", prevID) + require.Equal(t, "resp_last_turn", gjson.GetBytes(updated, "previous_response_id").String()) + // 验证原始 input 保持不变 + require.Equal(t, "call_abc123", gjson.GetBytes(updated, `input.0.call_id`).String()) + require.Equal(t, "function_call_output", gjson.GetBytes(updated, `input.0.type`).String()) + require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String()) + }) + + t.Run("skip_when_prev_id_already_present", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","previous_response_id":"resp_client","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn") + + require.False(t, injected, "客户端已携带 previous_response_id,不应注入") + require.Equal(t, "resp_client", prevID) + }) + + t.Run("skip_when_store_enabled", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := 
simulateInject(false, 2, payload, "resp_last_turn") + + require.False(t, injected, "store 未禁用时不应注入") + }) + + t.Run("skip_when_no_function_call_output", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`) + _, _, injected := simulateInject(true, 2, payload, "resp_last_turn") + + require.False(t, injected, "没有 function_call_output 时不应注入") + }) + + t.Run("skip_when_expected_prev_empty", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 2, payload, "") + + require.False(t, injected, "没有 expectedPrev 时不应注入") + }) + + t.Run("inject_preserves_multiple_function_call_outputs", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"a"},{"type":"function_call_output","call_id":"call_2","output":"b"}]}`) + updated, prevID, injected := simulateInject(true, 5, payload, "resp_multi") + + require.True(t, injected) + require.Equal(t, "resp_multi", prevID) + require.Equal(t, "resp_multi", gjson.GetBytes(updated, "previous_response_id").String()) + outputs := gjson.GetBytes(updated, `input.#(type=="function_call_output")#.call_id`).Array() + require.Len(t, outputs, 2) + require.Equal(t, "call_1", outputs[0].String()) + require.Equal(t, "call_2", outputs[1].String()) + }) + + t.Run("inject_on_first_turn_with_expected_prev", func(t *testing.T) { + t.Parallel() + // turn=1 但有 expectedPrev(可能来自 session state store 恢复),应注入 + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 1, payload, "resp_restored") + + require.True(t, injected, "turn=1 且有 expectedPrev 时应注入") + }) + + t.Run("inject_updates_payload_bytes_correctly", func(t *testing.T) { + 
t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + updated, _, injected := simulateInject(true, 3, payload, "resp_check_size") + + require.True(t, injected) + // 注入后 payload 长度应增加(包含了新的 previous_response_id 字段) + require.Greater(t, len(updated), len(payload)) + // 验证 JSON 合法性 + require.True(t, json.Valid(updated), "注入后的 payload 应为合法 JSON") + }) + + t.Run("skip_when_turn_zero", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 0, payload, "resp_1") + + require.False(t, injected, "turn=0 时不应注入") + }) + + t.Run("inject_with_whitespace_expected_prev", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + // shouldInfer 内部会 trim,所以带空格的 expectedPrev 仍然有效 + _, _, injected := simulateInject(true, 2, payload, " resp_trimmed ") + + require.True(t, injected, "trim 后非空的 expectedPrev 应触发注入") + }) + + t.Run("skip_when_prev_id_is_whitespace_only", func(t *testing.T) { + t.Parallel() + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + _, _, injected := simulateInject(true, 2, payload, " ") + + require.False(t, injected, "纯空白的 expectedPrev 不应触发注入") + }) +} diff --git a/backend/internal/service/openai_ws_forwarder_panic_test.go b/backend/internal/service/openai_ws_forwarder_panic_test.go new file mode 100644 index 000000000..0ceacb187 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_panic_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + 
"github.com/stretchr/testify/require" +) + +type openAIWSPanicResolver struct{} + +func (openAIWSPanicResolver) Resolve(account *Account) OpenAIWSProtocolDecision { + panic("resolver panic") +} + +type openAIWSPanicStateStore struct { + OpenAIWSStateStore +} + +func (openAIWSPanicStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) { + panic("state_store panic") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PanicRecovered(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSResolver = openAIWSPanicResolver{} + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, + }) + + serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusInternalError, closeErr.StatusCode()) + require.Equal(t, "internal websocket proxy panic", closeErr.Reason()) +} + +func TestOpenAIGatewayService_ForwardOpenAIWSV2_PanicRecovered(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + openaiWSStateStore: openAIWSPanicStateStore{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + cache: &stubGatewayCache{}, + } + + account := &Account{ + ID: 445, + Name: "openai-forwarder-panic", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + 
Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + } + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + req.Header.Set("session_id", "sess-panic-check") + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + _, err := svc.forwardOpenAIWSV2( + context.Background(), + ginCtx, + account, + map[string]any{ + "model": "gpt-5.1", + "input": []any{}, + }, + "sk-test", + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + true, + true, + "gpt-5.1", + "gpt-5.1", + time.Now(), + 1, + "", + ) + require.Error(t, err) + require.ErrorContains(t, err, "panic recovered") +} diff --git a/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go new file mode 100644 index 000000000..ed8539f28 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go @@ -0,0 +1,818 @@ +package service + +import ( + "encoding/json" + "errors" + "strings" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// --------------------------------------------------------------------------- +// 改动 1 测试:预防性检测 — store_disabled + function_call_output + 无 previous_response_id +// 在 sendAndRelay 中提前返回可恢复错误,避免上游先写数据再报错导致无法恢复 +// --------------------------------------------------------------------------- + +// TestProactiveToolOutputNotFound_ErrorShape 验证预防性检测生成的错误格式正确: +// stage = tool_output_not_found、wroteDownstream = false、可被恢复函数识别。 +func TestProactiveToolOutputNotFound_ErrorShape(t *testing.T) { + t.Parallel() + + err := wrapOpenAIWSIngressTurnErrorWithPartial( + openAIWSIngressStageToolOutputNotFound, + errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"), + false, + nil, + ) 
+ + // 1. 应被识别为 tool_output_not_found + require.True(t, isOpenAIWSIngressToolOutputNotFound(err), + "预防性检测错误应被 isOpenAIWSIngressToolOutputNotFound 识别") + + // 2. wroteDownstream 应为 false(尚未发送任何数据) + require.False(t, openAIWSIngressTurnWroteDownstream(err), + "预防性检测在发送前触发,wroteDownstream 必须为 false") + + // 3. 应被 classifyOpenAIWSIngressTurnAbortReason 归类为 ToolOutput + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason) + require.True(t, expected, "tool_output_not_found 是预期错误") + + // 4. 处置方式应为 ContinueTurn(允许恢复重试) + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition, + "tool_output_not_found 应为 ContinueTurn 处置,允许恢复重试") + + // 5. 部分结果应为 nil(因为尚未产生任何上游响应) + // partialResult=nil 时 OpenAIWSIngressTurnPartialResult 返回 (nil, false) + partial, ok := OpenAIWSIngressTurnPartialResult(err) + require.False(t, ok, "预防性检测传入 partialResult=nil,ok 应为 false") + require.Nil(t, partial, "预防性检测不应有部分结果") +} + +// TestProactiveToolOutputNotFound_NotPreviousResponseNotFound +// 验证预防性检测生成的错误不会被误识别为 previous_response_not_found。 +func TestProactiveToolOutputNotFound_NotPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + err := wrapOpenAIWSIngressTurnErrorWithPartial( + openAIWSIngressStageToolOutputNotFound, + errors.New("proactive tool_output_not_found"), + false, + nil, + ) + + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err), + "tool_output_not_found 不应被误识别为 previous_response_not_found") +} + +// TestProactiveDetection_ConditionMatrix 验证预防性检测的三个触发条件的所有组合。 +// 只有 store_disabled + function_call_output + 无 previous_response_id + 无可关联上下文 才触发。 +func TestProactiveDetection_ConditionMatrix(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storeDisabled bool + hasFunctionCallOutput bool + previousResponseID string + hasToolOutputContext bool + shouldTrigger bool + }{ + { + name: 
"all_conditions_met_should_trigger", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: "", + hasToolOutputContext: false, + shouldTrigger: true, + }, + { + name: "whitespace_only_previous_response_id_should_trigger", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: " ", + hasToolOutputContext: false, + shouldTrigger: true, + }, + { + name: "store_enabled_should_not_trigger", + storeDisabled: false, + hasFunctionCallOutput: true, + previousResponseID: "", + hasToolOutputContext: false, + shouldTrigger: false, + }, + { + name: "no_function_call_output_should_not_trigger", + storeDisabled: true, + hasFunctionCallOutput: false, + previousResponseID: "", + hasToolOutputContext: false, + shouldTrigger: false, + }, + { + name: "has_previous_response_id_should_not_trigger", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: "resp_abc", + hasToolOutputContext: false, + shouldTrigger: false, + }, + { + name: "all_false_should_not_trigger", + storeDisabled: false, + hasFunctionCallOutput: false, + previousResponseID: "resp_abc", + hasToolOutputContext: false, + shouldTrigger: false, + }, + { + name: "store_disabled_no_fco_has_prev_should_not_trigger", + storeDisabled: true, + hasFunctionCallOutput: false, + previousResponseID: "resp_abc", + hasToolOutputContext: false, + shouldTrigger: false, + }, + { + name: "store_enabled_has_fco_has_prev_should_not_trigger", + storeDisabled: false, + hasFunctionCallOutput: true, + previousResponseID: "resp_abc", + hasToolOutputContext: false, + shouldTrigger: false, + }, + { + name: "has_tool_output_context_should_not_trigger", + storeDisabled: true, + hasFunctionCallOutput: true, + previousResponseID: "", + hasToolOutputContext: true, + shouldTrigger: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + // 模拟 sendAndRelay 中的检测逻辑 + triggered := 
shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + tt.storeDisabled, + tt.hasFunctionCallOutput, + tt.previousResponseID, + tt.hasToolOutputContext, + ) + require.Equal(t, tt.shouldTrigger, triggered) + }) + } +} + +// TestProactiveDetection_PayloadExtraction 验证从真实 payload 中提取的条件参数 +// 与预防性检测逻辑的配合。确保 payload 解析结果能正确触发或跳过检测。 +func TestProactiveDetection_PayloadExtraction(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload string + wantHasFCO bool + wantPrevID string + wantHasContext bool + shouldTrigger bool // 假设 storeDisabled=true + }{ + { + name: "fco_without_previous_response_id", + payload: `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`, + wantHasFCO: true, + wantPrevID: "", + wantHasContext: false, + shouldTrigger: true, + }, + { + name: "fco_with_previous_response_id", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`, + wantHasFCO: true, + wantPrevID: "resp_1", + wantHasContext: false, + shouldTrigger: false, + }, + { + name: "no_fco_without_previous_response_id", + payload: `{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`, + wantHasFCO: false, + wantPrevID: "", + wantHasContext: false, + shouldTrigger: false, + }, + { + name: "multiple_fco_without_previous_response_id", + payload: `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`, + wantHasFCO: true, + wantPrevID: "", + wantHasContext: false, + shouldTrigger: true, + }, + { + name: "fco_with_tool_call_context", + payload: `{"type":"response.create","input":[{"type":"tool_call","call_id":"call_1"},{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`, + wantHasFCO: true, + wantPrevID: "", + wantHasContext: true, + shouldTrigger: false, + }, + { + name: 
"fco_with_item_reference_context", + payload: `{"type":"response.create","input":[{"type":"item_reference","id":"call_1"},{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`, + wantHasFCO: true, + wantPrevID: "", + wantHasContext: true, + shouldTrigger: false, + }, + { + name: "empty_input_without_previous_response_id", + payload: `{"type":"response.create","input":[]}`, + wantHasFCO: false, + wantPrevID: "", + wantHasContext: false, + shouldTrigger: false, + }, + { + name: "no_input_field", + payload: `{"type":"response.create","model":"gpt-5.1"}`, + wantHasFCO: false, + wantPrevID: "", + wantHasContext: false, + shouldTrigger: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + payload := []byte(tt.payload) + + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload) + hasFCO := len(callIDs) > 0 + prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + hasContext := openAIWSHasToolCallContextInPayload(payload) || + openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload, callIDs) + + require.Equal(t, tt.wantHasFCO, hasFCO, "hasFunctionCallOutput 不匹配") + require.Equal(t, tt.wantPrevID, prevID, "previousResponseID 不匹配") + require.Equal(t, tt.wantHasContext, hasContext, "tool output context 检测结果不匹配") + + // 模拟 storeDisabled=true 时的检测 + triggered := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + true, + hasFCO, + prevID, + hasContext, + ) + require.Equal(t, tt.shouldTrigger, triggered, "检测触发结果不匹配") + }) + } +} + +// TestProactiveDetection_RecoveryChainIntegration 验证预防性检测错误能被 +// recoverIngressPrevResponseNotFound 的恢复逻辑识别并处理。 +// 这是两个改动之间的集成测试。 +func TestProactiveDetection_RecoveryChainIntegration(t *testing.T) { + t.Parallel() + + // 模拟预防性检测返回的错误 + proactiveErr := wrapOpenAIWSIngressTurnErrorWithPartial( + openAIWSIngressStageToolOutputNotFound, + errors.New("proactive tool_output_not_found: function_call_output without 
previous_response_id in store_disabled mode"), + false, + nil, + ) + + // 1. 恢复函数入口检测:isToolOutputMissing 应为 true + require.True(t, isOpenAIWSIngressToolOutputNotFound(proactiveErr), + "预防性检测错误应通过 isToolOutputMissing 检查") + + // 2. isPrevNotFound 应为 false(确保不走 previous_response_not_found 分支) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(proactiveErr), + "预防性检测错误不应走 previous_response_not_found 分支") + + // 3. wroteDownstream=false 确保可以安全重试 + require.False(t, openAIWSIngressTurnWroteDownstream(proactiveErr), + "预防性检测在发送前触发,wroteDownstream 必须为 false") +} + +// --------------------------------------------------------------------------- +// 改动 2 测试:恢复逻辑修复 — previous_response_id 已缺失时跳过 drop 步骤 +// --------------------------------------------------------------------------- + +// TestToolOutputRecovery_PreviousResponseIDAlreadyEmpty 验证核心修复: +// 当 payload 中本就不存在 previous_response_id 时,跳过 drop 步骤, +// 直接进入 setOpenAIWSPayloadInputSequence。 +func TestToolOutputRecovery_PreviousResponseIDAlreadyEmpty(t *testing.T) { + t.Parallel() + + // 场景:预防性检测触发后,payload 中不存在 previous_response_id + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) + require.Empty(t, currentPreviousResponseID, "payload 中不应存在 previous_response_id") + + // 修复后的逻辑:currentPreviousResponseID 为空时跳过 drop + // 直接使用原始 payload 作为 updatedPayload + updatedPayload := payload + if currentPreviousResponseID != "" { + // 此分支不应执行 + t.Fatal("不应进入 drop 分支") + } + + // 验证跳过 drop 后,payload 仍然有效且可以继续处理 + require.True(t, gjson.ValidBytes(updatedPayload), "payload 应为有效 JSON") + require.True(t, gjson.GetBytes(updatedPayload, `input.#(type=="function_call_output")`).Exists(), + "function_call_output 应保持不变") + + // 模拟 setOpenAIWSPayloadInputSequence:使用 replay input 替换 + replayInput := []json.RawMessage{ + 
json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`), + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + replayInput, + true, + ) + require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功") + require.True(t, gjson.ValidBytes(updatedWithInput), "更新后的 payload 应为有效 JSON") + require.True(t, gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output")`).Exists(), + "replay input 中的 function_call_output 应存在") +} + +// TestToolOutputRecovery_PreviousResponseIDExists 验证修复不影响原有逻辑: +// 当 payload 中存在 previous_response_id 时,仍然执行 drop。 +func TestToolOutputRecovery_PreviousResponseIDExists(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) + require.NotEmpty(t, currentPreviousResponseID, "payload 中应存在 previous_response_id") + + // 修复后的逻辑:currentPreviousResponseID 不为空时执行 drop + updatedPayload := payload + if currentPreviousResponseID != "" { + dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, dropErr, "drop 操作不应出错") + require.True(t, removed, "previous_response_id 应被成功移除") + updatedPayload = dropped + } + + // 验证 drop 后 previous_response_id 已移除 + require.False(t, gjson.GetBytes(updatedPayload, "previous_response_id").Exists(), + "previous_response_id 应被移除") + + // 验证 function_call_output 仍然存在 + require.True(t, gjson.GetBytes(updatedPayload, `input.#(type=="function_call_output")`).Exists(), + "function_call_output 应保持不变") + + // 模拟 setOpenAIWSPayloadInputSequence + replayInput := []json.RawMessage{ + json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`), + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + 
updatedPayload, + replayInput, + true, + ) + require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功") + require.True(t, gjson.ValidBytes(updatedWithInput), "更新后的 payload 应为有效 JSON") +} + +// TestToolOutputRecovery_DropError 验证当 drop 操作出错时返回 false。 +func TestToolOutputRecovery_DropError(t *testing.T) { + t.Parallel() + + // 使用非法 JSON 模拟 drop 错误 + currentPreviousResponseID := "resp_stale" + + if currentPreviousResponseID != "" { + // 使用含有 previous_response_id 但格式异常的 payload 触发 drop 错误 + badPayload := []byte(`not-valid-json`) + _, _, dropErr := dropPreviousResponseIDFromRawPayload(badPayload) + // 无论 drop 是否报错,验证错误分支的行为 + if dropErr != nil { + // drop 出错时应返回 false(跳过恢复) + t.Log("drop 出错场景:恢复应返回 false") + } + } +} + +// TestToolOutputRecovery_DropNotRemoved 验证当 drop 操作返回 removed=false 时的行为。 +// 修复后:如果 currentPreviousResponseID 不为空但 drop 未能移除(removed=false), +// updatedPayload 仍使用原始 payload(不更新),继续执行后续流程。 +func TestToolOutputRecovery_DropNotRemoved(t *testing.T) { + t.Parallel() + + // 构造一个含有 previous_response_id 的 payload + payload := []byte(`{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) + require.NotEmpty(t, currentPreviousResponseID) + + // 使用自定义 delete 函数模拟 removed=false 的场景 + updatedPayload := payload + if currentPreviousResponseID != "" { + // 正常 drop 应该成功 + dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, dropErr) + if removed { + updatedPayload = dropped + } + // 无论 removed 与否,后续逻辑都应继续执行(不再提前 return false) + } + + // 验证可以继续执行 setOpenAIWSPayloadInputSequence + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + nil, + false, // fullInputExists=false 时直接返回原 payload + ) + require.NoError(t, setInputErr) + require.Equal(t, string(updatedPayload), string(updatedWithInput)) 
+} + +// TestToolOutputRecovery_WithDeleteFn_Removed_False 使用 WithDeleteFn 接口模拟 +// drop 函数返回 removed=false 的边界情况。 +func TestToolOutputRecovery_WithDeleteFn_Removed_False(t *testing.T) { + t.Parallel() + + payload := []byte(`{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + currentPreviousResponseID := "resp_1" + + // 使用 noop delete 函数模拟 removed=false + noopDelete := func(data []byte, _ string) ([]byte, error) { + return data, nil // 不修改 payload,sjson 会返回原样 + } + + updatedPayload := payload + if currentPreviousResponseID != "" { + dropped, removed, dropErr := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, noopDelete) + require.NoError(t, dropErr) + // noop delete 不移除字段,但 previous_response_id 仍存在 + // removed 取决于 payload 比较 + if removed { + updatedPayload = dropped + } + } + + // 无论 removed 与否,都应继续(不提前 return false) + require.True(t, gjson.ValidBytes(updatedPayload)) +} + +// --------------------------------------------------------------------------- +// 端到端场景测试:预防性检测 → 恢复逻辑 完整链路 +// --------------------------------------------------------------------------- + +// TestEndToEnd_ProactiveDetection_RecoveryWithEmptyPreviousResponseID +// 完整模拟:store_disabled 模式下,客户端发送 function_call_output 但 +// 未携带 previous_response_id → 预防性检测触发 → 恢复逻辑跳过 drop → 重放。 +func TestEndToEnd_ProactiveDetection_RecoveryWithEmptyPreviousResponseID(t *testing.T) { + t.Parallel() + + // Step 1: 构造客户端 payload(无 previous_response_id) + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[ + {"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"} + ] + }`) + + // Step 2: 提取条件参数(模拟 sendAndRelay 中的变量赋值) + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload) + turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) 
> 0 + turnStoreDisabled := true // 模拟 store_disabled 模式 + + require.Empty(t, turnPreviousResponseID) + require.True(t, turnHasFunctionCallOutput) + require.Equal(t, []string{"call_abc"}, turnFunctionCallOutputCallIDs) + + // Step 3: 预防性检测触发 + shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + turnStoreDisabled, + turnHasFunctionCallOutput, + turnPreviousResponseID, + false, + ) + require.True(t, shouldTrigger, "预防性检测应触发") + + // Step 4: 构造预防性检测错误 + proactiveErr := wrapOpenAIWSIngressTurnErrorWithPartial( + openAIWSIngressStageToolOutputNotFound, + errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"), + false, + nil, + ) + + // Step 5: 验证恢复入口条件 + isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(proactiveErr) + require.True(t, isToolOutputMissing, "恢复入口应识别 tool_output_not_found") + + // Step 6: 模拟恢复逻辑(改动 2 的核心路径) + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) + require.Empty(t, currentPreviousResponseID, "payload 中无 previous_response_id") + + // 改动 2:跳过 drop 步骤 + updatedPayload := payload + if currentPreviousResponseID != "" { + t.Fatal("不应进入 drop 分支") + } + + // Step 7: 执行 setOpenAIWSPayloadInputSequence + replayInput := []json.RawMessage{ + json.RawMessage(`{"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}`), + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + replayInput, + true, + ) + require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功") + require.True(t, gjson.ValidBytes(updatedWithInput), "结果应为有效 JSON") + + // Step 8: 验证最终 payload + require.False(t, gjson.GetBytes(updatedWithInput, "previous_response_id").Exists(), + "最终 payload 不应含 previous_response_id") + require.True(t, gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output")`).Exists(), + "最终 payload 应含 function_call_output") + 
require.Equal(t, "call_abc", + gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output").call_id`).String(), + "call_id 应保持不变") +} + +// TestEndToEnd_ProactiveDetection_StoreEnabledBypasses 对照测试: +// store 未禁用时,即使 function_call_output 无 previous_response_id,也不触发预防性检测。 +func TestEndToEnd_ProactiveDetection_StoreEnabledBypasses(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0 + turnStoreDisabled := false // store 未禁用 + + shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + turnStoreDisabled, + turnHasFunctionCallOutput, + turnPreviousResponseID, + false, + ) + require.False(t, shouldTrigger, "store 未禁用时不应触发预防性检测") +} + +// TestEndToEnd_ProactiveDetection_WithPreviousResponseIDBypasses 对照测试: +// store_disabled 但 payload 含有 previous_response_id 时,不触发预防性检测。 +func TestEndToEnd_ProactiveDetection_WithPreviousResponseIDBypasses(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "previous_response_id":"resp_valid", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0 + turnStoreDisabled := true + + shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + turnStoreDisabled, + turnHasFunctionCallOutput, + turnPreviousResponseID, + false, + ) + require.False(t, shouldTrigger, "有 previous_response_id 时不应触发预防性检测") +} + +// TestEndToEnd_NormalTextInput_NeverTriggersProactive 对照测试: +// 普通文本输入(无 function_call_output)在任何模式下都不触发预防性检测。 +func 
TestEndToEnd_NormalTextInput_NeverTriggersProactive(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"input_text","text":"hello world"}] + }`) + + turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0 + turnStoreDisabled := true + + shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + turnStoreDisabled, + turnHasFunctionCallOutput, + turnPreviousResponseID, + false, + ) + require.False(t, shouldTrigger, "普通文本输入不应触发预防性检测") +} + +// --------------------------------------------------------------------------- +// 改动 2 与原有 drop 路径的兼容性测试 +// --------------------------------------------------------------------------- + +// TestToolOutputRecovery_OriginalFlow_WithPreviousResponseID_StillWorks +// 验证改动 2 不破坏原有的 tool_output_not_found 恢复流程: +// 当 previous_response_id 存在时,仍然执行 drop + setInputSequence。 +func TestToolOutputRecovery_OriginalFlow_WithPreviousResponseID_StillWorks(t *testing.T) { + t.Parallel() + + // 原有场景:用户按 ESC 取消 function_call 后重新发送消息, + // payload 含有过时的 previous_response_id + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[ + {"type":"function_call_output","call_id":"call_canceled","output":"canceled"}, + {"type":"input_text","text":"new message"} + ] + }`) + + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) + require.Equal(t, "resp_stale", currentPreviousResponseID) + + // 改动后的逻辑 + updatedPayload := payload + if currentPreviousResponseID != "" { + dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, dropErr) + require.True(t, removed) + updatedPayload = dropped + } + + // 验证 drop 成功 + require.False(t, gjson.GetBytes(updatedPayload, 
"previous_response_id").Exists()) + + // 验证 input 保持不变 + inputCount := gjson.GetBytes(updatedPayload, "input.#").Int() + require.Equal(t, int64(2), inputCount, "input 数组应保持不变") + + // setOpenAIWSPayloadInputSequence 使用 replay input + replayInput := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"new message"}`), + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + replayInput, + true, + ) + require.NoError(t, setInputErr) + require.True(t, gjson.ValidBytes(updatedWithInput)) + require.Equal(t, int64(1), gjson.GetBytes(updatedWithInput, "input.#").Int(), + "replay input 应替换原始 input") +} + +// TestToolOutputRecovery_SetInputSequence_NoReplayInput 验证当没有 replay input 时, +// setOpenAIWSPayloadInputSequence 直接返回原 payload。 +func TestToolOutputRecovery_SetInputSequence_NoReplayInput(t *testing.T) { + t.Parallel() + + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + + // fullInputExists=false 时直接返回原 payload + result, err := setOpenAIWSPayloadInputSequence(payload, nil, false) + require.NoError(t, err) + require.Equal(t, string(payload), string(result)) +} + +// TestToolOutputRecovery_SetInputSequence_EmptyReplayInput 验证 replay input 为空数组时 +// 仍然能正确设置(替换为空数组)。 +func TestToolOutputRecovery_SetInputSequence_EmptyReplayInput(t *testing.T) { + t.Parallel() + + payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + + result, err := setOpenAIWSPayloadInputSequence(payload, []json.RawMessage{}, true) + require.NoError(t, err) + require.True(t, gjson.ValidBytes(result)) + require.Equal(t, int64(0), gjson.GetBytes(result, "input.#").Int(), + "空 replay input 应替换为空数组") +} + +// --------------------------------------------------------------------------- +// 边界条件与防御性测试 +// --------------------------------------------------------------------------- + +// 
TestProactiveDetection_EmptyPayload 验证空 payload 的行为。 +func TestProactiveDetection_EmptyPayload(t *testing.T) { + t.Parallel() + + emptyPayloads := [][]byte{ + nil, + {}, + []byte(""), + } + + for i, payload := range emptyPayloads { + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload) + require.Empty(t, callIDs, "空 payload[%d] 不应提取到 call_id", i) + + prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + require.Empty(t, prevID, "空 payload[%d] 不应有 previous_response_id", i) + } +} + +// TestProactiveDetection_MalformedJSON 验证非法 JSON 不会导致 panic。 +func TestProactiveDetection_MalformedJSON(t *testing.T) { + t.Parallel() + + badPayloads := [][]byte{ + []byte(`{invalid json`), + []byte(`{"type":"response.create","input":"not_array"}`), + []byte(`{"type":"response.create","input":123}`), + } + + for i, payload := range badPayloads { + // 不应 panic + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload) + _ = callIDs + prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + _ = prevID + + // 模拟检测逻辑不应 panic + hasFCO := len(callIDs) > 0 + hasContext := openAIWSHasToolCallContextInPayload(payload) || + openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload, callIDs) + triggered := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + true, + hasFCO, + prevID, + hasContext, + ) + _ = triggered + t.Logf("badPayload[%d]: hasFCO=%v, triggered=%v", i, hasFCO, triggered) + } +} + +// TestToolOutputRecovery_OldCodeWouldFail_Regression 回归测试: +// 验证旧代码在 previous_response_id 为空时会因 removed=false 而失败。 +// 新代码跳过 drop 步骤后应成功。 +func TestToolOutputRecovery_OldCodeWouldFail_Regression(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + // 旧代码行为:直接调用 dropPreviousResponseIDFromRawPayload + _, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + 
require.NoError(t, dropErr) + require.False(t, removed, "payload 中无 previous_response_id 时 drop 返回 removed=false") + + // 旧代码:!removed → return false(恢复失败) + oldCodeResult := dropErr == nil && removed // 旧条件:dropErr != nil || !removed + require.False(t, oldCodeResult, "旧代码在此场景会失败(return false)") + + // 新代码行为:先检查 currentPreviousResponseID,为空时跳过 drop + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) + updatedPayload := payload + newCodeSkippedDrop := false + if currentPreviousResponseID != "" { + dropped, removedNew, dropErrNew := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, dropErrNew) + if removedNew { + updatedPayload = dropped + } + } else { + newCodeSkippedDrop = true + } + + require.True(t, newCodeSkippedDrop, "新代码应跳过 drop 步骤") + require.Equal(t, string(payload), string(updatedPayload), "payload 应保持不变") + + // 新代码继续执行 setOpenAIWSPayloadInputSequence + replayInput := []json.RawMessage{ + json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`), + } + result, setErr := setOpenAIWSPayloadInputSequence(updatedPayload, replayInput, true) + require.NoError(t, setErr, "新代码应能继续执行 setInputSequence") + require.True(t, gjson.ValidBytes(result)) +} diff --git a/backend/internal/service/openai_ws_forwarder_recovery_test.go b/backend/internal/service/openai_ws_forwarder_recovery_test.go new file mode 100644 index 000000000..86abd88f0 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_recovery_test.go @@ -0,0 +1,692 @@ +package service + +import ( + "encoding/json" + "errors" + "strconv" + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +// --------------------------------------------------------------------------- +// openAIWSIngressTurnWroteDownstream 辅助函数测试 +// --------------------------------------------------------------------------- + +func TestOpenAIWSIngressTurnWroteDownstream(t *testing.T) { + t.Parallel() 
+ + tests := []struct { + name string + err error + want bool + }{ + { + name: "nil_error_returns_false", + err: nil, + want: false, + }, + { + name: "plain_error_returns_false", + err: errors.New("some random error"), + want: false, + }, + { + name: "turn_error_wrote_downstream_false", + err: wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ), + want: false, + }, + { + name: "turn_error_wrote_downstream_true", + err: wrapOpenAIWSIngressTurnError( + "upstream_error_event", + errors.New("upstream error"), + true, + ), + want: true, + }, + { + name: "turn_error_with_partial_result_wrote_downstream_true", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "read_upstream", + errors.New("connection reset"), + true, + &OpenAIForwardResult{RequestID: "resp_partial"}, + ), + want: true, + }, + { + name: "turn_error_with_partial_result_wrote_downstream_false", + err: wrapOpenAIWSIngressTurnErrorWithPartial( + "read_upstream", + errors.New("connection reset"), + false, + &OpenAIForwardResult{RequestID: "resp_partial"}, + ), + want: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, openAIWSIngressTurnWroteDownstream(tt.err)) + }) + } +} + +// --------------------------------------------------------------------------- +// previous_response_not_found 错误与 ContinueTurn 处置测试 +// --------------------------------------------------------------------------- + +func TestPreviousResponseNotFound_ClassifiesAsContinueTurn(t *testing.T) { + t.Parallel() + + // previous_response_not_found(wroteDownstream=false)应被归类为 ContinueTurn + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonPreviousResponse, reason) + 
require.True(t, expected) + + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) +} + +func TestToolOutputNotFound_ClassifiesAsContinueTurn(t *testing.T) { + t.Parallel() + + // tool_output_not_found(wroteDownstream=false)应被归类为 ContinueTurn + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStageToolOutputNotFound, + errors.New("no tool call found for function call output"), + false, + ) + + reason, expected := classifyOpenAIWSIngressTurnAbortReason(err) + require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason) + require.True(t, expected) + + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) +} + +// --------------------------------------------------------------------------- +// function_call_output 与 previous_response_id 语义绑定测试 +// 验证核心修复:带 function_call_output 时不能 drop previous_response_id +// --------------------------------------------------------------------------- + +func TestFunctionCallOutputPayload_HasFunctionCallOutput(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + payload string + want bool + }{ + { + name: "payload_with_function_call_output", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`, + want: true, + }, + { + name: "payload_with_mixed_input_including_function_call_output", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`, + want: true, + }, + { + name: "payload_without_function_call_output", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"}]}`, + want: false, + }, + { + name: "payload_with_empty_input", + payload: 
`{"type":"response.create","previous_response_id":"resp_1","input":[]}`, + want: false, + }, + { + name: "payload_without_input", + payload: `{"type":"response.create","model":"gpt-5.1"}`, + want: false, + }, + { + name: "multiple_function_call_outputs", + payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`, + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := gjson.GetBytes([]byte(tt.payload), `input.#(type=="function_call_output")`).Exists() + require.Equal(t, tt.want, got) + }) + } +} + +func TestDropPreviousResponseID_BreaksFunctionCallOutput(t *testing.T) { + t.Parallel() + + // 核心回归测试:验证 drop previous_response_id 后 function_call_output 会变成孤立引用 + // + // 场景:客户端发送 {previous_response_id: "resp_1", input: [{type: "function_call_output", call_id: "call_abc"}]} + // 如果 drop 了 previous_response_id,上游会创建全新上下文,找不到 call_abc 对应的 tool call + // 结果:上游报 "No tool call found for function call output with call_id call_abc" + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale_or_lost", + "input":[ + {"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}, + {"type":"function_call_output","call_id":"call_def","output":"{\"result\":\"done\"}"} + ] + }`) + + // 1. 验证原始 payload 有 previous_response_id 和 function_call_output + require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists()) + require.True(t, gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()) + + // 2. drop previous_response_id + dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + + // 3. 
验证 drop 后的状态:previous_response_id 被移除但 function_call_output 仍然存在 + require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists(), + "previous_response_id 应该被移除") + require.True(t, gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists(), + "function_call_output 仍然存在,但此时它引用的 call_id 没有了上下文 — 这就是 bug 的根因") + + // 4. 验证 call_id 仍然在 payload 中(说明 drop 不会清理 function_call_output) + callIDs := make([]string, 0) + gjson.GetBytes(dropped, "input").ForEach(func(_, item gjson.Result) bool { + if item.Get("type").String() == "function_call_output" { + callIDs = append(callIDs, item.Get("call_id").String()) + } + return true + }) + require.ElementsMatch(t, []string{"call_abc", "call_def"}, callIDs, + "function_call_output 的 call_id 未被清除,但上游已无法匹配") +} + +func TestRecoveryStrategy_FunctionCallOutput_ShouldNotDrop(t *testing.T) { + t.Parallel() + + // 此测试验证修复的核心逻辑: + // 当 hasFunctionCallOutput=true 且 set/align 策略均失败时, + // 正确行为是放弃恢复(return false),而非 drop previous_response_id + // + // 因为:function_call_output 语义绑定 previous_response_id + // drop previous_response_id 但保留 function_call_output → 上游报错 + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_lost", + "input":[{"type":"function_call_output","call_id":"call_JDKR","output":"ok"}] + }`) + + hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFunctionCallOutput, "payload 必须包含 function_call_output") + + // 模拟 set 策略失败(currentPreviousResponseID 不为空,不满足 set 条件) + currentPreviousResponseID := "resp_lost" + expectedPrev := "resp_expected" + require.NotEmpty(t, currentPreviousResponseID, "set 策略需要 currentPreviousResponseID 为空") + + // 模拟 align 策略失败 + _, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev) + if alignErr == nil && aligned { + // align 成功了,更新 payload 中的 previous_response_id + t.Log("align 策略成功,此场景不触发 abort 路径") + } + // 注意:align 通常会成功(替换 resp_lost → 
resp_expected)。 + // 但在真实场景中,如果 align 后的 previous_response_id 仍然在上游不存在, + // 上游会再次返回 previous_response_not_found,此时二次进入恢复函数, + // 但 turnPrevRecoveryTried=true 会阻止二次恢复,直接走 abort。 + + // 验证关键断言:即使 drop 技术上可行,也不应该执行 + // 因为这会导致 "No tool call found for function call output" 错误 + droppedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, dropErr) + require.True(t, removed, "drop 操作本身可以成功") + + // 但 drop 后的 payload 仍有 function_call_output —— 这就是为什么不能 drop + hasFCOAfterDrop := gjson.GetBytes(droppedPayload, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFCOAfterDrop, + "drop previous_response_id 不会移除 function_call_output,"+ + "导致上游报 'No tool call found for function call output'") +} + +// --------------------------------------------------------------------------- +// ContinueTurn abort 路径错误通知测试 +// --------------------------------------------------------------------------- + +func TestContinueTurnAbort_ErrorEventFormat(t *testing.T) { + t.Parallel() + + // 验证 ContinueTurn abort 时生成的 error 事件格式正确 + abortReason := openAIWSIngressTurnAbortReasonPreviousResponse + abortMessage := "turn failed: " + string(abortReason) + + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + + // 验证 JSON 格式有效 + var parsed map[string]any + err := json.Unmarshal(errorEvent, &parsed) + require.NoError(t, err, "error 事件应为有效 JSON") + + // 验证事件结构 + require.Equal(t, "error", parsed["type"]) + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "server_error", errorObj["type"]) + require.Equal(t, string(abortReason), errorObj["code"]) + require.Contains(t, errorObj["message"], string(abortReason)) +} + +func TestContinueTurnAbort_ErrorEventWithSpecialChars(t *testing.T) { + t.Parallel() + + // 验证包含特殊字符的错误消息不会破坏 JSON 格式 + specialMessages := []string{ + `No tool call found for function call 
output with call_id call_JDKR0SzNTARIsGb0L3hofFWd.`, + `error with "quotes" and \backslash`, + "error with\nnewline", + `error with & entities`, + "", // 空消息 + } + + for i, msg := range specialMessages { + msg := msg + t.Run("special_message_"+strconv.Itoa(i), func(t *testing.T) { + t.Parallel() + abortReason := openAIWSIngressTurnAbortReasonToolOutput + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(abortReason) + `","message":` + strconv.Quote(msg) + `}}`) + + var parsed map[string]any + err := json.Unmarshal(errorEvent, &parsed) + require.NoError(t, err, "error event with special chars should be valid JSON: %q", msg) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, msg, errorObj["message"]) + }) + } +} + +func TestContinueTurnAbort_WroteDownstreamDeterminesNotification(t *testing.T) { + t.Parallel() + + // 验证 wroteDownstream 标志如何影响错误通知策略 + tests := []struct { + name string + wroteDownstream bool + shouldSendErrorToClient bool + }{ + { + name: "not_wrote_downstream_should_send_error", + wroteDownstream: false, + shouldSendErrorToClient: true, + }, + { + name: "wrote_downstream_should_not_send_error", + wroteDownstream: true, + shouldSendErrorToClient: false, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + tt.wroteDownstream, + ) + wroteDownstream := openAIWSIngressTurnWroteDownstream(err) + require.Equal(t, tt.wroteDownstream, wroteDownstream) + + // 只有当 wroteDownstream=false 时才需要补发 error 事件 + shouldNotify := !wroteDownstream + require.Equal(t, tt.shouldSendErrorToClient, shouldNotify) + }) + } +} + +// --------------------------------------------------------------------------- +// previous_response_id 恢复策略:set / align / abort 完整流程测试 +// 
--------------------------------------------------------------------------- + +func TestRecoveryStrategy_SetPreviousResponseID(t *testing.T) { + t.Parallel() + + // 场景:客户端未发送 previous_response_id,但 session 中有记录 + // 此时应该通过 set 策略注入 previous_response_id + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + expectedPrev := "resp_expected" + + // set 策略:当 currentPreviousResponseID 为空时,注入 expectedPrev + updated, err := setPreviousResponseIDToRawPayload(payload, expectedPrev) + require.NoError(t, err) + require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String()) + + // function_call_output 保持不变 + require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists()) + require.Equal(t, "call_1", gjson.GetBytes(updated, `input.#(type=="function_call_output").call_id`).String()) +} + +func TestRecoveryStrategy_AlignPreviousResponseID(t *testing.T) { + t.Parallel() + + // 场景:客户端发送了过时的 previous_response_id,需要 align 到最新 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + expectedPrev := "resp_latest" + + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, expectedPrev) + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String()) + + // function_call_output 保持不变 + require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists()) +} + +func TestRecoveryStrategy_AlignFailsWhenNoExpectedPrev(t *testing.T) { + t.Parallel() + + // 场景:没有预期的 previous_response_id,align 无法执行 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + 
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "") + require.NoError(t, err) + require.False(t, changed, "align 应该在 expectedPrev 为空时不执行") + require.Equal(t, string(payload), string(updated)) +} + +// --------------------------------------------------------------------------- +// isOpenAIWSIngressPreviousResponseNotFound 边界条件测试 +// --------------------------------------------------------------------------- + +func TestIsOpenAIWSIngressPreviousResponseNotFound_WroteDownstreamBlocks(t *testing.T) { + t.Parallel() + + // wroteDownstream=true 时,即使 stage 是 previous_response_not_found, + // 也不应被识别为可恢复的 previous_response_not_found + // (因为已经向客户端写入了数据,无法安全重试) + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + true, // wroteDownstream = true + ) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err), + "wroteDownstream=true 时不应识别为可恢复的 previous_response_not_found") +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound_DifferentStageReturns_False(t *testing.T) { + t.Parallel() + + stages := []string{ + "read_upstream", + "write_upstream", + "upstream_error_event", + openAIWSIngressStageToolOutputNotFound, + "unknown", + "", + } + + for _, stage := range stages { + stage := stage + t.Run("stage_"+stage, func(t *testing.T) { + t.Parallel() + err := wrapOpenAIWSIngressTurnError(stage, errors.New("some error"), false) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err), + "stage=%q 不应被识别为 previous_response_not_found", stage) + }) + } +} + +// --------------------------------------------------------------------------- +// 端到端场景测试:function_call_output 恢复链路 +// --------------------------------------------------------------------------- + +func TestEndToEnd_FunctionCallOutputRecoveryChain(t *testing.T) { + t.Parallel() + + // 完整场景: + // 1. 客户端发送带 function_call_output 的请求 + // 2. 上游返回 previous_response_not_found + // 3. 恢复策略尝试 set/align + // 4. 
如果都失败,应该 abort(而非 drop previous_response_id) + // 5. 客户端收到 error 事件 + // 6. 客户端重置并发送完整请求 + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_lost", + "input":[ + {"type":"function_call_output","call_id":"call_JDKR0SzNTARIsGb0L3hofFWd","output":"{\"ok\":true}"} + ] + }`) + + // Step 1: 检测 function_call_output + hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFCO, "payload 包含 function_call_output") + + // Step 2: 模拟 previous_response_not_found 错误 + turnErr := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ) + require.True(t, isOpenAIWSIngressPreviousResponseNotFound(turnErr)) + + // Step 3: 验证 ContinueTurn 处置 + reason, _ := classifyOpenAIWSIngressTurnAbortReason(turnErr) + disposition := openAIWSIngressTurnAbortDispositionForReason(reason) + require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition) + + // Step 4: set 策略 — 失败(currentPreviousResponseID 不为空) + currentPrevID := gjson.GetBytes(payload, "previous_response_id").String() + require.NotEmpty(t, currentPrevID, "set 策略前提条件不满足(需要 currentPreviousResponseID 为空)") + + // Step 5: align 策略 — 假设 expectedPrev 为空(session 中无记录) + expectedPrev := "" + _, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev) + require.NoError(t, alignErr) + require.False(t, aligned, "expectedPrev 为空时 align 应失败") + + // Step 6: 此时应该 abort(return false)而非 drop + // 验证:如果错误地执行 drop,会导致 function_call_output 成为孤立引用 + dropped, removed, _ := dropPreviousResponseIDFromRawPayload(payload) + if removed { + hasFCOAfterDrop := gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists() + require.True(t, hasFCOAfterDrop, + "drop 后 function_call_output 仍存在,上游会报 'No tool call found'") + } + + // Step 7: 正确行为——abort 后生成 error 事件通知客户端 + wroteDownstream := openAIWSIngressTurnWroteDownstream(turnErr) + 
require.False(t, wroteDownstream, "abort 前未向客户端写入数据") + + abortMessage := "turn failed: " + string(reason) + errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + + string(reason) + `","message":` + strconv.Quote(abortMessage) + `}}`) + + var parsed map[string]any + require.NoError(t, json.Unmarshal(errorEvent, &parsed)) + require.Equal(t, "error", parsed["type"]) +} + +func TestEndToEnd_NonFunctionCallOutput_CanDrop(t *testing.T) { + t.Parallel() + + // 对照场景:没有 function_call_output 的 payload 可以安全 drop previous_response_id + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_old", + "input":[{"type":"input_text","text":"hello"}] + }`) + + hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + require.False(t, hasFCO, "此 payload 不包含 function_call_output") + + // drop 是安全的 + dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists()) + + // input 仍然有效(input_text 不依赖 previous_response_id) + require.Equal(t, "hello", gjson.GetBytes(dropped, "input.0.text").String()) +} + +// --------------------------------------------------------------------------- +// shouldKeepIngressPreviousResponseID 与 function_call_output 的交互测试 +// --------------------------------------------------------------------------- + +func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMatch(t *testing.T) { + t.Parallel() + + // 当 function_call_output 的 call_id 与 pending call_id 匹配时,应保留 previous_response_id + previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + currentPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_1", + "input":[{"type":"function_call_output","call_id":"call_match","output":"ok"}] + }`) + + keep, reason, err := 
shouldKeepIngressPreviousResponseID( + previousPayload, + currentPayload, + "resp_1", + true, // hasFunctionCallOutput + []string{"call_match"}, // pendingCallIDs + []string{"call_match"}, // requestCallIDs + ) + require.NoError(t, err) + require.True(t, keep, "call_id 匹配时应保留 previous_response_id") + require.Equal(t, "function_call_output_call_id_match", reason) +} + +func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMismatch(t *testing.T) { + t.Parallel() + + // 当 function_call_output 的 call_id 与 pending call_id 不匹配时, + // 应放弃 previous_response_id + previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + currentPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_1", + "input":[{"type":"function_call_output","call_id":"call_wrong","output":"ok"}] + }`) + + keep, reason, err := shouldKeepIngressPreviousResponseID( + previousPayload, + currentPayload, + "resp_1", + true, // hasFunctionCallOutput + []string{"call_real"}, // pendingCallIDs + []string{"call_wrong"}, // requestCallIDs + ) + require.NoError(t, err) + require.False(t, keep, "call_id 不匹配时应放弃 previous_response_id") + require.Equal(t, "function_call_output_call_id_mismatch", reason) +} + +// --------------------------------------------------------------------------- +// isOpenAIWSIngressTurnRetryable 与 function_call_output 场景的交互 +// --------------------------------------------------------------------------- + +func TestIsOpenAIWSIngressTurnRetryable_PreviousResponseNotFound(t *testing.T) { + t.Parallel() + + // previous_response_not_found 不应被标记为 retryable(因为有专门的恢复路径) + err := wrapOpenAIWSIngressTurnError( + openAIWSIngressStagePreviousResponseNotFound, + errors.New("previous response not found"), + false, + ) + require.False(t, isOpenAIWSIngressTurnRetryable(err), + "previous_response_not_found 有专门的恢复逻辑,不走通用重试") +} + +func TestIsOpenAIWSIngressTurnRetryable_WroteDownstreamBlocksRetry(t *testing.T) { + 
t.Parallel() + + // wroteDownstream=true 时,任何 stage 都不应 retryable + err := wrapOpenAIWSIngressTurnError( + "write_upstream", + errors.New("write failed"), + true, // wroteDownstream + ) + require.False(t, isOpenAIWSIngressTurnRetryable(err), + "wroteDownstream=true 时不应重试") +} + +// --------------------------------------------------------------------------- +// normalizeOpenAIWSIngressPayloadBeforeSend 与恢复的集成测试 +// --------------------------------------------------------------------------- + +func TestNormalizePayload_FunctionCallOutputPassthrough(t *testing.T) { + t.Parallel() + + // 透传模式:normalizer 不再注入 previous_response_id,原样传递 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 1, + turn: 2, + connID: "conn_test", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "resp_expected", + pendingExpectedCallIDs: []string{"call_1"}, + }) + + // 透传模式:previous_response_id 保持客户端原值(空),由下游 recovery 处理 + require.Empty(t, out.currentPreviousResponseID, + "透传模式不应注入 previous_response_id") + require.True(t, out.hasFunctionCallOutputCallID) + require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs) +} diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go deleted file mode 100644 index 592801f66..000000000 --- a/backend/internal/service/openai_ws_forwarder_success_test.go +++ /dev/null @@ -1,1306 +0,0 @@ -package service - -import ( - "context" - "encoding/json" - "errors" - "io" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/gin-gonic/gin" - "github.com/gorilla/websocket" - 
"github.com/stretchr/testify/require" - "github.com/tidwall/gjson" -) - -func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) { - gin.SetMode(gin.TestMode) - - type receivedPayload struct { - Type string - PreviousResponseID string - StreamExists bool - Stream bool - } - receivedCh := make(chan receivedPayload, 1) - - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Errorf("upgrade websocket failed: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - var request map[string]any - if err := conn.ReadJSON(&request); err != nil { - t.Errorf("read ws request failed: %v", err) - return - } - requestJSON := requestToJSONString(request) - receivedCh <- receivedPayload{ - Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()), - PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()), - StreamExists: gjson.Get(requestJSON, "stream").Exists(), - Stream: gjson.Get(requestJSON, "stream").Bool(), - } - - if err := conn.WriteJSON(map[string]any{ - "type": "response.created", - "response": map[string]any{ - "id": "resp_new_1", - "model": "gpt-5.1", - }, - }); err != nil { - t.Errorf("write response.created failed: %v", err) - return - } - if err := conn.WriteJSON(map[string]any{ - "type": "response.completed", - "response": map[string]any{ - "id": "resp_new_1", - "model": "gpt-5.1", - "usage": map[string]any{ - "input_tokens": 12, - "output_tokens": 7, - "input_tokens_details": map[string]any{ - "cached_tokens": 3, - }, - }, - }, - }); err != nil { - t.Errorf("write response.completed failed: %v", err) - return - } - })) - defer wsServer.Close() - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) 
- c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") - groupID := int64(1001) - c.Set("api_key", &APIKey{GroupID: &groupID}) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 - cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 - - upstream := &httpUpstreamRecorder{ - resp: &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), - }, - } - - cache := &stubGatewayCache{} - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: upstream, - cache: cache, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - - account := &Account{ - ID: 9, - Name: "openai-ws", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 2, - Credentials: map[string]any{ - "api_key": "sk-test", - "base_url": wsServer.URL, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, 12, result.Usage.InputTokens) - require.Equal(t, 7, result.Usage.OutputTokens) - require.Equal(t, 3, result.Usage.CacheReadInputTokens) - require.Equal(t, "resp_new_1", 
result.RequestID) - require.True(t, result.OpenAIWSMode) - require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游") - - received := <-receivedCh - require.Equal(t, "response.create", received.Type) - require.Equal(t, "resp_prev_1", received.PreviousResponseID) - require.True(t, received.StreamExists, "WS 请求应携带 stream 字段") - require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义") - - store := svc.getOpenAIWSStateStore() - mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1") - require.NoError(t, getErr) - require.Equal(t, account.ID, mappedAccountID) - connID, ok := store.GetResponseConn("resp_new_1") - require.True(t, ok) - require.NotEmpty(t, connID) - - responseBody := rec.Body.Bytes() - require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) -} - -func requestToJSONString(payload map[string]any) string { - if len(payload) == 0 { - return "{}" - } - b, err := json.Marshal(payload) - if err != nil { - return "{}" - } - return string(b) -} - -func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) { - require.NotPanics(t, func() { - logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil) - }) - require.NotPanics(t, func() { - logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed")) - }) -} - -func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) { - gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") - groupID := int64(3001) - c.Set("api_key", &APIKey{GroupID: &groupID}) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled 
= true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 1301, - Name: "openai-rewrite", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - "model_mapping": map[string]any{ - "custom-original-model": "gpt-5.1", - }, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "resp_model_tool_1", result.RequestID) - require.Equal(t, "custom-original-model", 
gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型") - require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范") -} - -func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) { - payload := map[string]any{ - "type": nil, - "model": 123, - "prompt_cache_key": " cache-key ", - "previous_response_id": []byte(" resp_1 "), - } - - require.Equal(t, "", openAIWSPayloadString(payload, "type")) - require.Equal(t, "", openAIWSPayloadString(payload, "model")) - require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key")) - require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id")) -} - -func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { - gin.SetMode(gin.TestMode) - - var upgradeCount atomic.Int64 - var sequence atomic.Int64 - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgradeCount.Add(1) - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Errorf("upgrade websocket failed: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - for { - var request map[string]any - if err := conn.ReadJSON(&request); err != nil { - return - } - idx := sequence.Add(1) - responseID := "resp_reuse_" + strconv.FormatInt(idx, 10) - if err := conn.WriteJSON(map[string]any{ - "type": "response.created", - "response": map[string]any{ - "id": responseID, - "model": "gpt-5.1", - }, - }); err != nil { - return - } - if err := conn.WriteJSON(map[string]any{ - "type": "response.completed", - "response": map[string]any{ - "id": responseID, - "model": "gpt-5.1", - "usage": map[string]any{ - "input_tokens": 2, - "output_tokens": 1, - }, - }, - }); err != nil { - return - } - } - })) - defer wsServer.Close() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false 
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - account := &Account{ - ID: 19, - Name: "openai-ws", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 2, - Credentials: map[string]any{ - "api_key": "sk-test", - "base_url": wsServer.URL, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - for i := 0; i < 2; i++ { - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") - groupID := int64(2001) - c.Set("api_key", &APIKey{GroupID: &groupID}) - - body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) - } - - require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") - metrics := svc.SnapshotOpenAIWSPoolMetrics() - require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) - require.GreaterOrEqual(t, 
metrics.ConnPickTotal, int64(1)) -} - -func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) { - gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") - c.Request.Header.Set("session_id", "sess-oauth-1") - c.Request.Header.Set("conversation_id", "conv-oauth-1") - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.AllowStoreRecovery = false - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - account := &Account{ - ID: 29, - Name: "openai-oauth", - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "oauth-token-1", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`) - result, 
err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "resp_oauth_1", result.RequestID) - - require.NotNil(t, captureConn.lastWrite) - requestJSON := requestToJSONString(captureConn.lastWrite) - require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段") - require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false") - require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") - require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") - require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) - require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) - require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) -} - -func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { - gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - 
pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - account := &Account{ - ID: 31, - Name: "openai-oauth", - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "access_token": "oauth-token-1", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "resp_prompt_cache_key", result.RequestID) - - require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) - require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) - require.NotNil(t, captureConn.lastWrite) - require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) -} - -func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) { - gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsockets = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false - - upstream := &httpUpstreamRecorder{ - resp: &http.Response{ - StatusCode: http.StatusOK, - Header: 
http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), - }, - } - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: upstream, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - - account := &Account{ - ID: 39, - Name: "openai-ws-v1", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - "base_url": "https://api.openai.com/v1/responses", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.Error(t, err) - require.Nil(t, result) - require.Contains(t, err.Error(), "ws v1") - require.Equal(t, http.StatusBadRequest, rec.Code) - require.Contains(t, rec.Body.String(), "WSv1") - require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") -} - -func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) { - gin.SetMode(gin.TestMode) - - var connIndex atomic.Int64 - headersCh := make(chan http.Header, 4) - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := connIndex.Add(1) - headersCh <- cloneHeader(r.Header) - - respHeader := http.Header{} - if idx == 1 { - respHeader.Set("x-codex-turn-state", "turn_state_first") - } - conn, err := upgrader.Upgrade(w, r, respHeader) - if err != nil { - t.Errorf("upgrade websocket failed: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - var request map[string]any - if err := conn.ReadJSON(&request); 
err != nil { - t.Errorf("read ws request failed: %v", err) - return - } - responseID := "resp_turn_" + strconv.FormatInt(idx, 10) - if err := conn.WriteJSON(map[string]any{ - "type": "response.completed", - "response": map[string]any{ - "id": responseID, - "model": "gpt-5.1", - "usage": map[string]any{ - "input_tokens": 2, - "output_tokens": 1, - }, - }, - }); err != nil { - t.Errorf("write response.completed failed: %v", err) - return - } - })) - defer wsServer.Close() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - - account := &Account{ - ID: 49, - Name: "openai-turn-state", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - "base_url": wsServer.URL, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) - rec1 := httptest.NewRecorder() - c1, _ := gin.CreateTestContext(rec1) - c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c1.Request.Header.Set("session_id", "session_turn_state") - c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1") - result1, err := svc.Forward(context.Background(), c1, account, reqBody) - require.NoError(t, err) - require.NotNil(t, result1) - - 
sessionHash := svc.GenerateSessionHash(c1, reqBody) - store := svc.getOpenAIWSStateStore() - turnState, ok := store.GetSessionTurnState(0, sessionHash) - require.True(t, ok) - require.Equal(t, "turn_state_first", turnState) - - // 主动淘汰连接,模拟下一次请求发生重连。 - connID, hasConn := store.GetResponseConn(result1.RequestID) - require.True(t, hasConn) - svc.getOpenAIWSConnPool().evictConn(account.ID, connID) - - rec2 := httptest.NewRecorder() - c2, _ := gin.CreateTestContext(rec2) - c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c2.Request.Header.Set("session_id", "session_turn_state") - c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2") - result2, err := svc.Forward(context.Background(), c2, account, reqBody) - require.NoError(t, err) - require.NotNil(t, result2) - - firstHandshakeHeaders := <-headersCh - secondHandshakeHeaders := <-headersCh - require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata")) - require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata")) - require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State")) -} - -func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) { - gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("session_id", "session-prewarm") - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - - captureConn := &openAIWSCaptureConn{ - 
events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 59, - Name: "openai-prewarm", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "resp_main_1", result.RequestID) - - require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求") - firstWrite := requestToJSONString(captureConn.writes[0]) - secondWrite := requestToJSONString(captureConn.writes[1]) - require.True(t, gjson.Get(firstWrite, "generate").Exists()) - require.False(t, gjson.Get(firstWrite, "generate").Bool()) - require.False(t, gjson.Get(secondWrite, "generate").Exists()) -} - -func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - svc := &OpenAIGatewayService{ - cfg: cfg, - toolCorrector: 
NewCodexToolCorrector(), - } - account := &Account{ - ID: 601, - Name: "openai-prewarm-timeout", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - } - conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{ - readDelay: 200 * time.Millisecond, - }, nil) - lease := &openAIWSConnLease{ - accountID: account.ID, - conn: conn, - } - payload := map[string]any{ - "type": "response.create", - "model": "gpt-5.1", - } - - ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) - defer cancel() - start := time.Now() - err := svc.performOpenAIWSGeneratePrewarm( - ctx, - lease, - OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, - payload, - "", - map[string]any{"model": "gpt-5.1"}, - account, - nil, - 0, - ) - elapsed := time.Since(start) - require.Error(t, err) - require.Contains(t, err.Error(), "prewarm_read_event") - require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout") -} - -func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) { - gin.SetMode(gin.TestMode) - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - - captureConn := &openAIWSCaptureConn{ - events: [][]byte{ - []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: 
captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 69, - Name: "openai-turn-metadata", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) - - rec1 := httptest.NewRecorder() - c1, _ := gin.CreateTestContext(rec1) - c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c1.Request.Header.Set("session_id", "session-metadata-reuse") - c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1") - result1, err := svc.Forward(context.Background(), c1, account, body) - require.NoError(t, err) - require.NotNil(t, result1) - require.Equal(t, "resp_meta_1", result1.RequestID) - - rec2 := httptest.NewRecorder() - c2, _ := gin.CreateTestContext(rec2) - c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c2.Request.Header.Set("session_id", "session-metadata-reuse") - c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2") - result2, err := svc.Forward(context.Background(), c2, account, body) - require.NoError(t, err) - require.NotNil(t, result2) - require.Equal(t, "resp_meta_2", result2.RequestID) - - require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") - require.Len(t, captureConn.writes, 2) - - firstWrite := requestToJSONString(captureConn.writes[0]) - secondWrite := requestToJSONString(captureConn.writes[1]) - require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, 
"client_metadata.x-codex-turn-metadata").String()) - require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String()) -} - -func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) { - gin.SetMode(gin.TestMode) - - var upgradeCount atomic.Int64 - var sequence atomic.Int64 - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} - wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - upgradeCount.Add(1) - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Errorf("upgrade websocket failed: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - for { - var request map[string]any - if err := conn.ReadJSON(&request); err != nil { - return - } - responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10) - if err := conn.WriteJSON(map[string]any{ - "type": "response.completed", - "response": map[string]any{ - "id": responseID, - "model": "gpt-5.1", - "usage": map[string]any{ - "input_tokens": 1, - "output_tokens": 1, - }, - }, - }); err != nil { - return - } - } - })) - defer wsServer.Close() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 - cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - - account := &Account{ - ID: 79, - Name: "openai-store-false", - Platform: 
PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 2, - Credentials: map[string]any{ - "api_key": "sk-test", - "base_url": wsServer.URL, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - - rec1 := httptest.NewRecorder() - c1, _ := gin.CreateTestContext(rec1) - c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c1.Request.Header.Set("session_id", "session_store_false_a") - result1, err := svc.Forward(context.Background(), c1, account, body) - require.NoError(t, err) - require.NotNil(t, result1) - require.Equal(t, int64(1), upgradeCount.Load()) - - rec2 := httptest.NewRecorder() - c2, _ := gin.CreateTestContext(rec2) - c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c2.Request.Header.Set("session_id", "session_store_false_a") - result2, err := svc.Forward(context.Background(), c2, account, body) - require.NoError(t, err) - require.NotNil(t, result2) - require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接") - - rec3 := httptest.NewRecorder() - c3, _ := gin.CreateTestContext(rec3) - c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c3.Request.Header.Set("session_id", "session_store_false_b") - result3, err := svc.Forward(context.Background(), c3, account, body) - require.NoError(t, err) - require.NotNil(t, result3) - require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖") -} - -func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) { - gin.SetMode(gin.TestMode) - - var upgradeCount atomic.Int64 - var sequence atomic.Int64 - upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} - wsServer := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { - upgradeCount.Add(1) - conn, err := upgrader.Upgrade(w, r, nil) - if err != nil { - t.Errorf("upgrade websocket failed: %v", err) - return - } - defer func() { - _ = conn.Close() - }() - - for { - var request map[string]any - if err := conn.ReadJSON(&request); err != nil { - return - } - responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10) - if err := conn.WriteJSON(map[string]any{ - "type": "response.completed", - "response": map[string]any{ - "id": responseID, - "model": "gpt-5.1", - "usage": map[string]any{ - "input_tokens": 1, - "output_tokens": 1, - }, - }, - }); err != nil { - return - } - } - })) - defer wsServer.Close() - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: &httpUpstreamRecorder{}, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - } - - account := &Account{ - ID: 80, - Name: "openai-store-false-reuse", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 2, - Credentials: map[string]any{ - "api_key": "sk-test", - "base_url": wsServer.URL, - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) - - rec1 := httptest.NewRecorder() - c1, _ := gin.CreateTestContext(rec1) - c1.Request = 
httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c1.Request.Header.Set("session_id", "session_store_false_reuse_a") - result1, err := svc.Forward(context.Background(), c1, account, body) - require.NoError(t, err) - require.NotNil(t, result1) - require.Equal(t, int64(1), upgradeCount.Load()) - - rec2 := httptest.NewRecorder() - c2, _ := gin.CreateTestContext(rec2) - c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c2.Request.Header.Set("session_id", "session_store_false_reuse_b") - result2, err := svc.Forward(context.Background(), c2, account, body) - require.NoError(t, err) - require.NotNil(t, result2) - require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接") -} - -func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) { - gin.SetMode(gin.TestMode) - - rec := httptest.NewRecorder() - c, _ := gin.CreateTestContext(rec) - c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) - c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") - - cfg := &config.Config{} - cfg.Security.URLAllowlist.Enabled = false - cfg.Security.URLAllowlist.AllowInsecureHTTP = true - cfg.Gateway.OpenAIWS.Enabled = true - cfg.Gateway.OpenAIWS.OAuthEnabled = true - cfg.Gateway.OpenAIWS.APIKeyEnabled = true - cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 - cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1 - cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 - - captureConn := &openAIWSCaptureConn{ - readDelays: []time.Duration{ - 700 * time.Millisecond, - 700 * time.Millisecond, - }, - events: [][]byte{ - []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`), - 
[]byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), - }, - } - captureDialer := &openAIWSCaptureDialer{conn: captureConn} - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(captureDialer) - - upstream := &httpUpstreamRecorder{ - resp: &http.Response{ - StatusCode: http.StatusOK, - Header: http.Header{"Content-Type": []string{"application/json"}}, - Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), - }, - } - - svc := &OpenAIGatewayService{ - cfg: cfg, - httpUpstream: upstream, - cache: &stubGatewayCache{}, - openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), - toolCorrector: NewCodexToolCorrector(), - openaiWSPool: pool, - } - - account := &Account{ - ID: 81, - Name: "openai-read-timeout", - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Status: StatusActive, - Schedulable: true, - Concurrency: 1, - Credentials: map[string]any{ - "api_key": "sk-test", - }, - Extra: map[string]any{ - "responses_websockets_v2_enabled": true, - }, - } - - body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) - result, err := svc.Forward(context.Background(), c, account, body) - require.NoError(t, err) - require.NotNil(t, result) - require.Equal(t, "resp_timeout_ok", result.RequestID) - require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP") -} - -type openAIWSCaptureDialer struct { - mu sync.Mutex - conn *openAIWSCaptureConn - lastHeaders http.Header - handshake http.Header - dialCount int -} - -func (d *openAIWSCaptureDialer) Dial( - ctx context.Context, - wsURL string, - headers http.Header, - proxyURL string, -) (openAIWSClientConn, int, http.Header, error) { - _ = ctx - _ = wsURL - _ = proxyURL - d.mu.Lock() - d.lastHeaders = cloneHeader(headers) - d.dialCount++ - respHeaders := cloneHeader(d.handshake) - d.mu.Unlock() - return d.conn, 
0, respHeaders, nil -} - -func (d *openAIWSCaptureDialer) DialCount() int { - d.mu.Lock() - defer d.mu.Unlock() - return d.dialCount -} - -type openAIWSCaptureConn struct { - mu sync.Mutex - readDelays []time.Duration - events [][]byte - lastWrite map[string]any - writes []map[string]any - closed bool -} - -func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error { - _ = ctx - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return errOpenAIWSConnClosed - } - switch payload := value.(type) { - case map[string]any: - c.lastWrite = cloneMapStringAny(payload) - c.writes = append(c.writes, cloneMapStringAny(payload)) - case json.RawMessage: - var parsed map[string]any - if err := json.Unmarshal(payload, &parsed); err == nil { - c.lastWrite = cloneMapStringAny(parsed) - c.writes = append(c.writes, cloneMapStringAny(parsed)) - } - case []byte: - var parsed map[string]any - if err := json.Unmarshal(payload, &parsed); err == nil { - c.lastWrite = cloneMapStringAny(parsed) - c.writes = append(c.writes, cloneMapStringAny(parsed)) - } - } - return nil -} - -func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { - if ctx == nil { - ctx = context.Background() - } - c.mu.Lock() - if c.closed { - c.mu.Unlock() - return nil, errOpenAIWSConnClosed - } - if len(c.events) == 0 { - c.mu.Unlock() - return nil, io.EOF - } - delay := time.Duration(0) - if len(c.readDelays) > 0 { - delay = c.readDelays[0] - c.readDelays = c.readDelays[1:] - } - event := c.events[0] - c.events = c.events[1:] - c.mu.Unlock() - if delay > 0 { - timer := time.NewTimer(delay) - defer timer.Stop() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - } - } - return event, nil -} - -func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { - _ = ctx - return nil -} - -func (c *openAIWSCaptureConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - c.closed = true - return nil -} - -func cloneMapStringAny(src map[string]any) 
map[string]any { - if src == nil { - return nil - } - dst := make(map[string]any, len(src)) - for k, v := range src { - dst[k] = v - } - return dst -} diff --git a/backend/internal/service/openai_ws_forwarder_turn_error_test.go b/backend/internal/service/openai_ws_forwarder_turn_error_test.go new file mode 100644 index 000000000..ea7ea8ccb --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_turn_error_test.go @@ -0,0 +1,62 @@ +package service + +import ( + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSIngressTurnPartialResult_NotTurnError(t *testing.T) { + result, ok := OpenAIWSIngressTurnPartialResult(errors.New("plain error")) + require.False(t, ok) + require.Nil(t, result) +} + +func TestOpenAIWSIngressTurnPartialResult_DeepCopy(t *testing.T) { + partial := &OpenAIForwardResult{ + RequestID: "resp_partial", + Usage: OpenAIUsage{ + InputTokens: 12, + OutputTokens: 34, + }, + PendingFunctionCallIDs: []string{"call_1", "call_2"}, + } + err := wrapOpenAIWSIngressTurnErrorWithPartial("read_upstream", errors.New("boom"), false, partial) + + got, ok := OpenAIWSIngressTurnPartialResult(err) + require.True(t, ok) + require.NotNil(t, got) + require.Equal(t, partial.RequestID, got.RequestID) + require.Equal(t, partial.Usage, got.Usage) + require.Equal(t, partial.PendingFunctionCallIDs, got.PendingFunctionCallIDs) + + // mutate returned copy should not affect stored partial result + got.PendingFunctionCallIDs[0] = "changed" + again, ok := OpenAIWSIngressTurnPartialResult(err) + require.True(t, ok) + require.Equal(t, "call_1", again.PendingFunctionCallIDs[0]) +} + +func TestOpenAIWSClientReadIdleTimeout_DefaultAndConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout()) + + svc.cfg = &config.Config{} + svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 1800 + require.Equal(t, 30*time.Minute, 
svc.openAIWSClientReadIdleTimeout()) + + svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 120 + require.Equal(t, 120*time.Second, svc.openAIWSClientReadIdleTimeout()) +} + +func TestOpenAIWSPassthroughIdleTimeout_DefaultAndConfig(t *testing.T) { + svc := &OpenAIGatewayService{} + require.Equal(t, time.Hour, svc.openAIWSPassthroughIdleTimeout()) + + svc.cfg = &config.Config{} + svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 120 + require.Equal(t, 120*time.Second, svc.openAIWSPassthroughIdleTimeout()) +} diff --git a/backend/internal/service/openai_ws_hotpath_perf_test.go b/backend/internal/service/openai_ws_hotpath_perf_test.go new file mode 100644 index 000000000..6339fc3d9 --- /dev/null +++ b/backend/internal/service/openai_ws_hotpath_perf_test.go @@ -0,0 +1,938 @@ +package service + +import ( + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + "unsafe" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// Mock conn for hotpath performance tests +// --------------------------------------------------------------------------- + +type openAIWSNoopConn struct{} + +func (c *openAIWSNoopConn) WriteJSON(context.Context, any) error { return nil } +func (c *openAIWSNoopConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil } +func (c *openAIWSNoopConn) Ping(context.Context) error { return nil } +func (c *openAIWSNoopConn) Close() error { return nil } + +// openAIWSIdentityConn is a distinct conn instance used to verify pointer identity. 
+type openAIWSIdentityConn struct{} + +func (c *openAIWSIdentityConn) WriteJSON(context.Context, any) error { return nil } +func (c *openAIWSIdentityConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil } +func (c *openAIWSIdentityConn) Ping(context.Context) error { return nil } +func (c *openAIWSIdentityConn) Close() error { return nil } + +func mustDefaultOpenAIWSStateStore(t *testing.T, raw OpenAIWSStateStore) *defaultOpenAIWSStateStore { + t.Helper() + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + return store +} + +// =================================================================== +// 1. maybeTouchLease throttle +// =================================================================== + +func TestMaybeTouchLease_NilReceiverDoesNotPanic(t *testing.T) { + var c *openAIWSIngressContext + require.NotPanics(t, func() { + c.maybeTouchLease(time.Minute) + }) +} + +func TestMaybeTouchLease_FirstCallAlwaysTouches(t *testing.T) { + c := &openAIWSIngressContext{} + require.Zero(t, c.lastTouchUnixNano.Load(), "precondition: lastTouchUnixNano should be zero") + + c.maybeTouchLease(5 * time.Minute) + + require.NotZero(t, c.lastTouchUnixNano.Load(), "first maybeTouchLease must update lastTouchUnixNano") + require.False(t, c.expiresAt().IsZero(), "first maybeTouchLease must set expiresAt") +} + +func TestMaybeTouchLease_SecondCallWithin1sIsSkipped(t *testing.T) { + c := &openAIWSIngressContext{} + + // First touch + c.maybeTouchLease(5 * time.Minute) + firstExpiry := c.expiresAt() + firstTouch := c.lastTouchUnixNano.Load() + require.NotZero(t, firstTouch) + + // Second touch immediately -- within 1s, should be skipped + c.maybeTouchLease(10 * time.Minute) + secondExpiry := c.expiresAt() + secondTouch := c.lastTouchUnixNano.Load() + + require.Equal(t, firstTouch, secondTouch, "lastTouchUnixNano should NOT change within 1s") + require.Equal(t, firstExpiry, secondExpiry, "expiresAt should NOT change within 1s") +} + +func 
TestMaybeTouchLease_CallAfter1sActuallyTouches(t *testing.T) { + c := &openAIWSIngressContext{} + + c.maybeTouchLease(5 * time.Minute) + firstExpiry := c.expiresAt() + + // Simulate 1s+ passing by backdating the lastTouchUnixNano + backdated := time.Now().Add(-2 * time.Second).UnixNano() + c.lastTouchUnixNano.Store(backdated) + // Also backdate expiresAt so we can observe the change + c.setExpiresAt(time.Now().Add(-time.Minute)) + expiryAfterBackdate := c.expiresAt() + require.True(t, expiryAfterBackdate.Before(firstExpiry), "precondition: expiresAt should be backdated") + + c.maybeTouchLease(5 * time.Minute) + touchAfter := c.lastTouchUnixNano.Load() + secondExpiry := c.expiresAt() + + require.Greater(t, touchAfter, backdated, "lastTouchUnixNano should advance past the backdated value") + require.True(t, secondExpiry.After(expiryAfterBackdate), "expiresAt should advance after 1s+ gap") +} + +func TestTouchLease_NilReceiverDoesNotPanic(t *testing.T) { + var c *openAIWSIngressContext + require.NotPanics(t, func() { + c.touchLease(time.Now(), 5*time.Minute) + }) +} + +func TestTouchLease_AlwaysUpdatesLastTouchUnixNano(t *testing.T) { + c := &openAIWSIngressContext{} + + now := time.Now() + c.touchLease(now, 5*time.Minute) + first := c.lastTouchUnixNano.Load() + require.NotZero(t, first) + + // touchLease (non-throttled) always updates, even if called again immediately. + time.Sleep(time.Millisecond) // ensure clock moves forward + now2 := time.Now() + c.touchLease(now2, 5*time.Minute) + second := c.lastTouchUnixNano.Load() + require.Greater(t, second, first, "touchLease must always update lastTouchUnixNano") +} + +// =================================================================== +// 2. 
activeConn cached connection +// =================================================================== + +func TestActiveConn_NilLeaseReturnsError(t *testing.T) { + var lease *openAIWSIngressContextLease + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_NilContextReturnsError(t *testing.T) { + lease := &openAIWSIngressContextLease{context: nil} + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_ReleasedLeaseReturnsError(t *testing.T) { + ctx := &openAIWSIngressContext{ + ownerID: "owner", + upstream: &openAIWSNoopConn{}, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner", + } + lease.released.Store(true) + + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_FirstCallPopulatesCachedConn(t *testing.T) { + upstream := &openAIWSIdentityConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_1", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_1", + } + + require.Nil(t, lease.cachedConn, "precondition: cachedConn should be nil") + + conn, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream, conn, "should return the upstream conn") + require.Equal(t, upstream, lease.cachedConn, "should populate cachedConn") +} + +func TestActiveConn_SecondCallReturnsCachedDirectly(t *testing.T) { + upstream1 := &openAIWSIdentityConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_cache", + upstream: upstream1, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_cache", + } + + // First call populates cache + conn1, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream1, conn1) + + // Swap the upstream -- cached path should NOT see the swap + upstream2 := &openAIWSIdentityConn{} + 
ctx.mu.Lock() + ctx.upstream = upstream2 + ctx.mu.Unlock() + + conn2, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream1, conn2, "second call should return cachedConn, not the swapped upstream") +} + +func TestActiveConn_OwnerMismatchReturnsError(t *testing.T) { + ctx := &openAIWSIngressContext{ + ownerID: "other_owner", + upstream: &openAIWSNoopConn{}, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "my_owner", + } + + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_NilUpstreamReturnsError(t *testing.T) { + ctx := &openAIWSIngressContext{ + ownerID: "owner", + upstream: nil, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner", + } + + conn, err := lease.activeConn() + require.Nil(t, conn) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestActiveConn_MarkBrokenClearsCachedConn(t *testing.T) { + upstream := &openAIWSNoopConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_mb", + upstream: upstream, + } + pool := &openAIWSIngressContextPool{ + idleTTL: 10 * time.Minute, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctx, + ownerID: "owner_mb", + } + + // Populate cache + conn, err := lease.activeConn() + require.NoError(t, err) + require.NotNil(t, conn) + require.NotNil(t, lease.cachedConn) + + lease.MarkBroken() + require.Nil(t, lease.cachedConn, "MarkBroken must clear cachedConn") +} + +func TestActiveConn_ReleaseClearsCachedConn(t *testing.T) { + upstream := &openAIWSNoopConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_rel", + upstream: upstream, + } + pool := &openAIWSIngressContextPool{ + idleTTL: 10 * time.Minute, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctx, + ownerID: "owner_rel", + } + + // Populate cache + conn, err := lease.activeConn() + require.NoError(t, err) + require.NotNil(t, conn) + require.NotNil(t, 
lease.cachedConn) + + lease.Release() + require.Nil(t, lease.cachedConn, "Release must clear cachedConn") +} + +func TestActiveConn_AfterClearCachedConn_ReacquiresViaMutex(t *testing.T) { + upstream1 := &openAIWSIdentityConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_reacq", + upstream: upstream1, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_reacq", + } + + // Populate cache with upstream1 + conn, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream1, conn) + + // Simulate a cleared cache (e.g., after recovery) + lease.cachedConn = nil + + // Swap upstream + upstream2 := &openAIWSIdentityConn{} + ctx.mu.Lock() + ctx.upstream = upstream2 + ctx.mu.Unlock() + + // Should now re-acquire via mutex and return upstream2 + conn2, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream2, conn2, "after clearing cachedConn, next call must re-acquire via mutex") + require.Equal(t, upstream2, lease.cachedConn, "should re-populate cachedConn with new upstream") +} + +// =================================================================== +// 3. 
Event type TrimSpace-free functions +// =================================================================== + +func TestIsOpenAIWSTerminalEvent(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + {"response.completed", true}, + {"response.done", true}, + {"response.failed", true}, + {"response.incomplete", true}, + {"response.cancelled", true}, + {"response.canceled", true}, + {"response.created", false}, + {"response.in_progress", false}, + {"response.output_text.delta", false}, + {"", false}, + {"unknown_event", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, isOpenAIWSTerminalEvent(tt.eventType)) + }) + } +} + +func TestShouldPersistOpenAIWSLastResponseID_HotpathPerf(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + {"response.completed", true}, + {"response.done", true}, + {"response.failed", false}, + {"response.incomplete", false}, + {"response.cancelled", false}, + {"", false}, + {"unknown_event", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, shouldPersistOpenAIWSLastResponseID(tt.eventType)) + }) + } +} + +func TestIsOpenAIWSTokenEvent(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + // Known false: structural events + {"response.created", false}, + {"response.in_progress", false}, + {"response.output_item.added", false}, + {"response.output_item.done", false}, + // Delta events + {"response.output_text.delta", true}, + {"response.content_part.delta", true}, + {"response.audio.delta", true}, + {"response.function_call_arguments.delta", true}, + // output_text prefix + {"response.output_text.done", true}, + {"response.output_text.annotation.added", true}, + // output prefix (but not output_item) + {"response.output.done", true}, + // Terminal events that are also token events + {"response.completed", true}, + {"response.done", true}, + // Empty and 
unknown + {"", false}, + {"unknown_event", false}, + {"session.created", false}, + {"session.updated", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, isOpenAIWSTokenEvent(tt.eventType)) + }) + } +} + +func TestOpenAIWSEventShouldParseUsage(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + {"response.completed", true}, + {"response.done", true}, + {"response.failed", true}, + {"", false}, + {"unknown", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, openAIWSEventShouldParseUsage(tt.eventType)) + }) + } +} + +func TestOpenAIWSEventMayContainToolCalls(t *testing.T) { + tests := []struct { + eventType string + want bool + }{ + // Explicit function_call / tool_call in name + {"response.function_call_arguments.delta", true}, + {"response.function_call_arguments.done", true}, + {"response.tool_call.delta", true}, + // Structural events that may contain tool output items + {"response.output_item.added", true}, + {"response.output_item.done", true}, + {"response.completed", true}, + {"response.done", true}, + // Non-tool events + {"response.output_text.delta", false}, + {"response.created", false}, + {"response.in_progress", false}, + {"", false}, + {"unknown", false}, + } + for _, tt := range tests { + t.Run(tt.eventType, func(t *testing.T) { + require.Equal(t, tt.want, openAIWSEventMayContainToolCalls(tt.eventType)) + }) + } +} + +// =================================================================== +// 4. 
parseOpenAIWSEventType (lightweight version) +// =================================================================== + +func TestParseOpenAIWSEventType_EmptyMessage(t *testing.T) { + eventType, responseID := parseOpenAIWSEventType(nil) + require.Empty(t, eventType) + require.Empty(t, responseID) + + eventType, responseID = parseOpenAIWSEventType([]byte{}) + require.Empty(t, eventType) + require.Empty(t, responseID) +} + +func TestParseOpenAIWSEventType_ResponseIDExtracted(t *testing.T) { + msg := []byte(`{"type":"response.completed","response":{"id":"resp_abc123"}}`) + eventType, responseID := parseOpenAIWSEventType(msg) + require.Equal(t, "response.completed", eventType) + require.Equal(t, "resp_abc123", responseID) +} + +func TestParseOpenAIWSEventType_FallbackToID(t *testing.T) { + msg := []byte(`{"type":"response.output_text.delta","id":"evt_fallback_id"}`) + eventType, responseID := parseOpenAIWSEventType(msg) + require.Equal(t, "response.output_text.delta", eventType) + require.Equal(t, "evt_fallback_id", responseID) +} + +func TestParseOpenAIWSEventType_ConsistentWithEnvelope(t *testing.T) { + testMessages := [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`), + []byte(`{"type":"response.output_text.delta","id":"evt_2"}`), + []byte(`{"type":"response.created","response":{"id":"resp_3"}}`), + []byte(`{"type":"error","error":{"message":"bad request"}}`), + []byte(`{"type":"response.done","id":"resp_4","response":{"id":"resp_4_inner"}}`), + []byte(`{}`), + []byte(`{"type":"session.created"}`), + } + for i, msg := range testMessages { + t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) { + typeLight, idLight := parseOpenAIWSEventType(msg) + typeEnv, idEnv, _ := parseOpenAIWSEventEnvelope(msg) + require.Equal(t, typeEnv, typeLight, "eventType must match parseOpenAIWSEventEnvelope") + require.Equal(t, idEnv, idLight, "responseID must match parseOpenAIWSEventEnvelope") + }) + } +} + +// 
=================================================================== +// 6. openAIWSResponseAccountCacheKey (xxhash, v2 prefix) +// =================================================================== + +func TestOpenAIWSResponseAccountCacheKey_Deterministic(t *testing.T) { + key1 := openAIWSResponseAccountCacheKey("resp_deterministic_test") + key2 := openAIWSResponseAccountCacheKey("resp_deterministic_test") + require.Equal(t, key1, key2, "same responseID must produce the same key") +} + +func TestOpenAIWSResponseAccountCacheKey_DifferentIDsDifferentKeys(t *testing.T) { + key1 := openAIWSResponseAccountCacheKey("resp_alpha") + key2 := openAIWSResponseAccountCacheKey("resp_beta") + require.NotEqual(t, key1, key2, "different responseIDs must produce different keys") +} + +func TestOpenAIWSResponseAccountCacheKey_V2Prefix(t *testing.T) { + key := openAIWSResponseAccountCacheKey("resp_v2_check") + require.True(t, strings.Contains(key, "v2:"), "key must contain v2: prefix for version compatibility") +} + +func TestOpenAIWSResponseAccountCacheKey_StartsWithCachePrefix(t *testing.T) { + key := openAIWSResponseAccountCacheKey("resp_prefix_check") + require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix), + "key must start with the standard cache prefix %q, got %q", openAIWSResponseAccountCachePrefix, key) +} + +func TestOpenAIWSResponseAccountCacheKey_HexLength(t *testing.T) { + key := openAIWSResponseAccountCacheKey("resp_hex_length") + // Expected format: "openai:response:v2:<16 hex chars>" + prefix := openAIWSResponseAccountCachePrefix + "v2:" + require.True(t, strings.HasPrefix(key, prefix)) + hexPart := strings.TrimPrefix(key, prefix) + require.Len(t, hexPart, 16, "xxhash hex digest should be zero-padded to 16 chars, got %q", hexPart) +} + +func TestOpenAIWSResponseAccountCacheKey_ManyInputs_AllPaddedTo16(t *testing.T) { + // Verify that all inputs produce exactly 16-char hex, testing many variations. 
+ prefix := openAIWSResponseAccountCachePrefix + "v2:" + for i := 0; i < 1000; i++ { + responseID := fmt.Sprintf("resp_%d", i) + key := openAIWSResponseAccountCacheKey(responseID) + hexPart := strings.TrimPrefix(key, prefix) + require.Len(t, hexPart, 16, "responseID=%q produced hex %q (len %d)", responseID, hexPart, len(hexPart)) + } +} + +// =================================================================== +// 7. openAIWSSessionTurnStateKey uses strconv +// =================================================================== + +func TestOpenAIWSSessionTurnStateKey_NormalCase(t *testing.T) { + key := openAIWSSessionTurnStateKey(123, "abc_hash") + require.Equal(t, "123:abc_hash", key) +} + +func TestOpenAIWSSessionTurnStateKey_EmptySessionHash(t *testing.T) { + key := openAIWSSessionTurnStateKey(123, "") + require.Equal(t, "", key) +} + +func TestOpenAIWSSessionTurnStateKey_WhitespaceOnlySessionHash(t *testing.T) { + key := openAIWSSessionTurnStateKey(123, " ") + require.Equal(t, "", key) +} + +func TestOpenAIWSSessionTurnStateKey_NegativeGroupID(t *testing.T) { + key := openAIWSSessionTurnStateKey(-1, "hash") + require.Equal(t, "-1:hash", key) +} + +func TestOpenAIWSSessionTurnStateKey_ZeroGroupID(t *testing.T) { + key := openAIWSSessionTurnStateKey(0, "hash") + require.Equal(t, "0:hash", key) +} + +// =================================================================== +// 8. 
openAIWSIngressContextSessionKey uses strconv +// =================================================================== + +func TestOpenAIWSIngressContextSessionKey_NormalCase(t *testing.T) { + key := openAIWSIngressContextSessionKey(456, "session_xyz") + require.Equal(t, "456:session_xyz", key) +} + +func TestOpenAIWSIngressContextSessionKey_EmptySessionHash(t *testing.T) { + key := openAIWSIngressContextSessionKey(456, "") + require.Equal(t, "", key) +} + +func TestOpenAIWSIngressContextSessionKey_WhitespaceOnlySessionHash(t *testing.T) { + key := openAIWSIngressContextSessionKey(456, " \t ") + require.Equal(t, "", key) +} + +func TestOpenAIWSIngressContextSessionKey_LargeGroupID(t *testing.T) { + key := openAIWSIngressContextSessionKey(9223372036854775807, "h") + require.Equal(t, "9223372036854775807:h", key) +} + +// =================================================================== +// 9. deriveOpenAISessionHash and deriveOpenAILegacySessionHash +// =================================================================== + +func TestDeriveOpenAISessionHash_EmptyReturnsEmpty(t *testing.T) { + require.Equal(t, "", deriveOpenAISessionHash("")) + require.Equal(t, "", deriveOpenAISessionHash(" ")) +} + +func TestDeriveOpenAISessionHash_ProducesXXHash16Chars(t *testing.T) { + hash := deriveOpenAISessionHash("test_session_id") + require.Len(t, hash, 16, "xxhash hex should be exactly 16 chars, got %q", hash) + // Verify it's valid hex + for _, ch := range hash { + require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'), + "hash should be lowercase hex, got char %c in %q", ch, hash) + } +} + +func TestDeriveOpenAISessionHash_Deterministic(t *testing.T) { + h1 := deriveOpenAISessionHash("session_abc") + h2 := deriveOpenAISessionHash("session_abc") + require.Equal(t, h1, h2) +} + +func TestDeriveOpenAILegacySessionHash_EmptyReturnsEmpty(t *testing.T) { + require.Equal(t, "", deriveOpenAILegacySessionHash("")) + require.Equal(t, "", deriveOpenAILegacySessionHash(" 
")) +} + +func TestDeriveOpenAILegacySessionHash_ProducesSHA256_64Chars(t *testing.T) { + hash := deriveOpenAILegacySessionHash("test_session_id") + require.Len(t, hash, 64, "SHA-256 hex should be exactly 64 chars, got %q", hash) + for _, ch := range hash { + require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'), + "hash should be lowercase hex, got char %c in %q", ch, hash) + } +} + +func TestDeriveOpenAILegacySessionHash_Deterministic(t *testing.T) { + h1 := deriveOpenAILegacySessionHash("session_xyz") + h2 := deriveOpenAILegacySessionHash("session_xyz") + require.Equal(t, h1, h2) +} + +func TestDeriveOpenAISessionHashes_MatchesIndividualFunctions(t *testing.T) { + sessionID := "test_combined_session" + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + + require.Equal(t, deriveOpenAISessionHash(sessionID), currentHash) + require.Equal(t, deriveOpenAILegacySessionHash(sessionID), legacyHash) +} + +func TestDeriveOpenAISessionHashes_EmptyReturnsEmpty(t *testing.T) { + currentHash, legacyHash := deriveOpenAISessionHashes("") + require.Equal(t, "", currentHash) + require.Equal(t, "", legacyHash) +} + +func TestDeriveOpenAISessionHashes_DifferentInputsDifferentOutputs(t *testing.T) { + h1Current, h1Legacy := deriveOpenAISessionHashes("session_A") + h2Current, h2Legacy := deriveOpenAISessionHashes("session_B") + require.NotEqual(t, h1Current, h2Current) + require.NotEqual(t, h1Legacy, h2Legacy) +} + +func TestDeriveOpenAISessionHash_DifferentFromLegacy(t *testing.T) { + // xxhash and SHA-256 produce completely different outputs for the same input + currentHash := deriveOpenAISessionHash("same_input") + legacyHash := deriveOpenAILegacySessionHash("same_input") + require.NotEqual(t, currentHash, legacyHash, "xxhash and SHA-256 should produce different results") + require.Len(t, currentHash, 16) + require.Len(t, legacyHash, 64) +} + +// =================================================================== +// 10. 
State store sharded lock (responseToConn) +// =================================================================== + +func TestConnShard_DistributesAcrossShards(t *testing.T) { + store := mustDefaultOpenAIWSStateStore(t, NewOpenAIWSStateStore(nil)) + + shardHits := make(map[int]int) + for i := 0; i < 256; i++ { + key := fmt.Sprintf("resp_%d", i) + shard := store.connShard(key) + // Find which shard index this is + for j := 0; j < openAIWSStateStoreConnShards; j++ { + if shard == &store.responseToConnShards[j] { + shardHits[j]++ + break + } + } + } + + // With 256 keys and 16 shards, each shard should get some keys. + // We don't require perfect uniformity, just that keys aren't all in one shard. + require.Greater(t, len(shardHits), 1, "keys must be distributed across multiple shards, got %d shards used", len(shardHits)) + require.GreaterOrEqual(t, len(shardHits), openAIWSStateStoreConnShards/2, + "keys should hit at least half the shards for reasonable distribution") +} + +func TestStateStore_ShardedBindGetDelete(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + + store.BindResponseConn("resp_shard_1", "conn_a", time.Minute) + store.BindResponseConn("resp_shard_2", "conn_b", time.Minute) + + conn1, ok1 := store.GetResponseConn("resp_shard_1") + require.True(t, ok1) + require.Equal(t, "conn_a", conn1) + + conn2, ok2 := store.GetResponseConn("resp_shard_2") + require.True(t, ok2) + require.Equal(t, "conn_b", conn2) + + store.DeleteResponseConn("resp_shard_1") + _, ok1After := store.GetResponseConn("resp_shard_1") + require.False(t, ok1After) + + // resp_shard_2 should still be accessible + conn2After, ok2After := store.GetResponseConn("resp_shard_2") + require.True(t, ok2After) + require.Equal(t, "conn_b", conn2After) +} + +func TestStateStore_ShardedConcurrentAccessNoRace(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + const goroutines = 32 + const opsPerGoroutine = 200 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := 0; g < goroutines; g++ { 
+ g := g + go func() { + defer wg.Done() + for i := 0; i < opsPerGoroutine; i++ { + key := fmt.Sprintf("resp_conc_%d_%d", g, i) + connID := fmt.Sprintf("conn_%d_%d", g, i) + + store.BindResponseConn(key, connID, time.Minute) + got, ok := store.GetResponseConn(key) + if ok { + _ = got + } + store.DeleteResponseConn(key) + } + }() + } + + wg.Wait() +} + +// =================================================================== +// 11. State store: Get paths don't call maybeCleanup +// =================================================================== + +func TestStateStore_GetPaths_DoNotTriggerCleanup(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store := mustDefaultOpenAIWSStateStore(t, raw) + + // Seed some data so Get paths have something to read + store.BindResponseConn("resp_get_noclean", "conn_1", time.Minute) + store.BindResponsePendingToolCalls(0, "resp_get_noclean", []string{"call_1"}, time.Minute) + store.BindSessionTurnState(1, "session_get_noclean", "state_1", time.Minute) + store.BindSessionConn(1, "session_get_noclean", "conn_1", time.Minute) + + // Record the lastCleanupUnixNano after the Bind calls + cleanupBefore := store.lastCleanupUnixNano.Load() + + // Set lastCleanup to the future to ensure no cleanup triggers from Binds + store.lastCleanupUnixNano.Store(time.Now().Add(time.Hour).UnixNano()) + cleanupFrozen := store.lastCleanupUnixNano.Load() + + // Perform many Get calls + for i := 0; i < 100; i++ { + store.GetResponseConn("resp_get_noclean") + store.GetResponsePendingToolCalls(0, "resp_get_noclean") + store.GetSessionTurnState(1, "session_get_noclean") + store.GetSessionConn(1, "session_get_noclean") + } + + cleanupAfterGets := store.lastCleanupUnixNano.Load() + require.Equal(t, cleanupFrozen, cleanupAfterGets, + "Get paths must NOT change lastCleanupUnixNano (was %d before, %d after)", cleanupBefore, cleanupAfterGets) +} + +func TestStateStore_MaybeCleanup_NilReceiverDoesNotPanic(t *testing.T) { + var store *defaultOpenAIWSStateStore + 
require.NotPanics(t, func() { + store.maybeCleanup() + }) +} + +func TestStateStore_BindPaths_MayTriggerCleanup(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store := mustDefaultOpenAIWSStateStore(t, raw) + + // Set lastCleanup to long ago to ensure cleanup triggers on next Bind + pastNano := time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano() + store.lastCleanupUnixNano.Store(pastNano) + + store.BindResponseConn("resp_bind_trigger", "conn_trigger", time.Minute) + + cleanupAfterBind := store.lastCleanupUnixNano.Load() + require.NotEqual(t, pastNano, cleanupAfterBind, + "Bind paths should trigger maybeCleanup when interval has elapsed") +} + +// =================================================================== +// 12. GetResponsePendingToolCalls returns internal slice directly +// =================================================================== + +func TestGetResponsePendingToolCalls_ReturnsInternalSlice(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store := mustDefaultOpenAIWSStateStore(t, raw) + + store.BindResponsePendingToolCalls(0, "resp_slice_identity", []string{"call_x", "call_y"}, time.Minute) + + callIDs, ok := store.GetResponsePendingToolCalls(0, "resp_slice_identity") + require.True(t, ok) + require.Equal(t, []string{"call_x", "call_y"}, callIDs) + + // Verify it's the same underlying slice as stored in the binding (pointer equality). + // The binding stores callIDs as a copied slice at bind time, but Get returns it directly. 
+ id := openAIWSResponsePendingToolCallsBindingKey(0, "resp_slice_identity") + store.responsePendingToolMu.RLock() + binding, exists := store.responsePendingTool[id] + store.responsePendingToolMu.RUnlock() + require.True(t, exists) + + // Check pointer equality of the underlying array via unsafe + gotHeader := (*[3]uintptr)(unsafe.Pointer(&callIDs)) + internalHeader := (*[3]uintptr)(unsafe.Pointer(&binding.callIDs)) + require.Equal(t, gotHeader[0], internalHeader[0], + "returned slice should share the same underlying array pointer as the internal binding (zero-copy)") +} + +// =================================================================== +// Additional edge-case tests for completeness +// =================================================================== + +func TestParseOpenAIWSEventType_MalformedJSON(t *testing.T) { + // Should not panic on malformed JSON + eventType, responseID := parseOpenAIWSEventType([]byte(`{not valid json`)) + // gjson returns empty for invalid JSON + _ = eventType + _ = responseID +} + +func TestOpenAIWSResponseAccountCacheKey_EmptyInput(t *testing.T) { + // Even empty string should produce a valid key + key := openAIWSResponseAccountCacheKey("") + require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix+"v2:")) + hexPart := strings.TrimPrefix(key, openAIWSResponseAccountCachePrefix+"v2:") + require.Len(t, hexPart, 16) +} + +func TestConnShard_SameKeyAlwaysSameShard(t *testing.T) { + store := mustDefaultOpenAIWSStateStore(t, NewOpenAIWSStateStore(nil)) + shard1 := store.connShard("resp_stable_key") + shard2 := store.connShard("resp_stable_key") + require.Equal(t, shard1, shard2, "same key must always map to the same shard") +} + +func TestMaybeTouchLease_ConcurrentSafe(t *testing.T) { + c := &openAIWSIngressContext{} + var wg sync.WaitGroup + const goroutines = 16 + + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + c.maybeTouchLease(5 * 
time.Minute) + } + }() + } + wg.Wait() + + require.NotZero(t, c.lastTouchUnixNano.Load()) + require.False(t, c.expiresAt().IsZero()) +} + +func TestActiveConn_SingleOwnerSequentialAccess(t *testing.T) { + // activeConn uses a non-synchronized cachedConn field by design. + // A lease is only used by a single goroutine (the forwarding loop). + // This test verifies sequential repeated calls from the same goroutine + // always return the same cached conn without error. + upstream := &openAIWSNoopConn{} + ctx := &openAIWSIngressContext{ + ownerID: "owner_seq", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + context: ctx, + ownerID: "owner_seq", + } + + for i := 0; i < 1000; i++ { + conn, err := lease.activeConn() + require.NoError(t, err) + require.Equal(t, upstream, conn) + } +} + +func TestOpenAIWSIngressContextSessionKey_ConsistentWithTurnStateKey(t *testing.T) { + // Both functions use the same pattern: strconv.FormatInt(groupID, 10) + ":" + hash + groupID := int64(42) + sessionHash := "test_hash" + + sessionKey := openAIWSIngressContextSessionKey(groupID, sessionHash) + turnStateKey := openAIWSSessionTurnStateKey(groupID, sessionHash) + + require.Equal(t, sessionKey, turnStateKey, + "openAIWSIngressContextSessionKey and openAIWSSessionTurnStateKey should produce identical keys for the same inputs") +} + +func TestStateStore_ShardedBindOverwrite(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + + store.BindResponseConn("resp_overwrite", "conn_old", time.Minute) + store.BindResponseConn("resp_overwrite", "conn_new", time.Minute) + + conn, ok := store.GetResponseConn("resp_overwrite") + require.True(t, ok) + require.Equal(t, "conn_new", conn, "later bind should overwrite earlier bind") +} + +func TestStateStore_ShardedTTLExpiry(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + + store.BindResponseConn("resp_ttl_shard", "conn_ttl", 30*time.Millisecond) + conn, ok := store.GetResponseConn("resp_ttl_shard") + require.True(t, ok) + 
require.Equal(t, "conn_ttl", conn) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponseConn("resp_ttl_shard") + require.False(t, ok, "entry should be expired after TTL") +} diff --git a/backend/internal/service/openai_ws_ingress_context_pool.go b/backend/internal/service/openai_ws_ingress_context_pool.go new file mode 100644 index 000000000..daa0ebb02 --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_context_pool.go @@ -0,0 +1,1594 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + errOpenAIWSIngressContextBusy = errors.New("openai ws ingress context is busy") +) + +const ( + openAIWSIngressScheduleLayerExact = "l0_exact" + openAIWSIngressScheduleLayerNew = "l1_new_context" + openAIWSIngressScheduleLayerMigration = "l2_migration" + openAIWSIngressAcquireMaxWaitRetries = 4096 + openAIWSIngressAcquireMaxQueueWait = 30 * time.Minute + + // openAIWSUpstreamConnMaxAge 是上游 WebSocket 连接的默认最大存活时间。 + // OpenAI 在 60 分钟后强制关闭连接,此处默认 55 分钟主动轮换以避免中途断连。 + openAIWSUpstreamConnMaxAge = 55 * time.Minute + + // openAIWSIngressDelayedPingAfterYield 是 yield 后延迟 Ping 探测的等待时间。 + // 在会话暂时空闲后提前发现死连接,避免下次 Acquire 时才发现。 + openAIWSIngressDelayedPingAfterYield = 5 * time.Second + + // openAIWSIngressPingTimeout 是后台 Ping 探测的超时时间。 + openAIWSIngressPingTimeout = 5 * time.Second +) + +const ( + openAIWSIngressStickinessWeak = "weak" + openAIWSIngressStickinessBalanced = "balanced" + openAIWSIngressStickinessStrong = "strong" +) + +type openAIWSIngressContextAcquireRequest struct { + Account *Account + GroupID int64 + SessionHash string + OwnerID string + WSURL string + Headers http.Header + ProxyURL string + Turn int + + HasPreviousResponseID bool + StrictAffinity bool + StoreDisabled bool +} + +type openAIWSIngressContextPool struct { + cfg *config.Config + dialer openAIWSClientDialer + + idleTTL time.Duration + sweepInterval 
time.Duration + upstreamMaxAge time.Duration + + seq atomic.Uint64 + + // schedulerStats provides load-aware signals (error rate, circuit breaker + // state) for migration scoring. When nil, scoring falls back to the + // existing idle-time + failure-streak heuristic. + schedulerStats *openAIAccountRuntimeStats + + mu sync.Mutex + accounts map[int64]*openAIWSIngressAccountPool + + stopCh chan struct{} + stopOnce sync.Once + workerWg sync.WaitGroup + closeOnce sync.Once +} + +type openAIWSIngressAccountPool struct { + mu sync.Mutex + + refs atomic.Int64 + + // dynamicCap 动态容量:初始 1,按需增长(L1 新建时 +1),空闲超时后缩减。 + // 实际容量为 min(dynamicCap, effectiveContextCapacity)。 + dynamicCap atomic.Int32 + + contexts map[string]*openAIWSIngressContext + bySession map[string]string +} + +type openAIWSIngressContext struct { + id string + groupID int64 + accountID int64 + sessionHash string + sessionKey string + + mu sync.Mutex + dialing bool + dialDone chan struct{} + releaseDone chan struct{} // ownerID 释放时发送单信号,唤醒一个等待者 + ownerID string + lastUsedAtUnix atomic.Int64 + expiresAtUnix atomic.Int64 + lastTouchUnixNano atomic.Int64 // throttle: skip touchLease if < 1s since last + broken bool + failureStreak int + lastFailureAt time.Time + migrationCount int + lastMigrationAt time.Time + upstream openAIWSClientConn + upstreamConnID string + upstreamConnCreatedAt atomic.Int64 // UnixNano; 0 表示未设置 + handshakeHeaders http.Header + prewarmed atomic.Bool + pendingPingTimer *time.Timer // 延迟 Ping 去重:同一 context 仅保留一个 pending ping +} + +type openAIWSIngressContextLease struct { + pool *openAIWSIngressContextPool + context *openAIWSIngressContext + ownerID string + queueWait time.Duration + connPick time.Duration + reused bool + scheduleLayer string + stickiness string + migrationUsed bool + released atomic.Bool + cachedConnMu sync.RWMutex + cachedConn openAIWSClientConn // fast path: avoid mutex on every activeConn() call +} + +func openAIWSTimeToUnixNano(ts time.Time) int64 { + if ts.IsZero() { + 
return 0 + } + return ts.UnixNano() +} + +func openAIWSUnixNanoToTime(ns int64) time.Time { + if ns <= 0 { + return time.Time{} + } + return time.Unix(0, ns) +} + +func (c *openAIWSIngressContext) setLastUsedAt(ts time.Time) { + if c == nil { + return + } + c.lastUsedAtUnix.Store(openAIWSTimeToUnixNano(ts)) +} + +func (c *openAIWSIngressContext) lastUsedAt() time.Time { + if c == nil { + return time.Time{} + } + return openAIWSUnixNanoToTime(c.lastUsedAtUnix.Load()) +} + +func (c *openAIWSIngressContext) setExpiresAt(ts time.Time) { + if c == nil { + return + } + c.expiresAtUnix.Store(openAIWSTimeToUnixNano(ts)) +} + +func (c *openAIWSIngressContext) expiresAt() time.Time { + if c == nil { + return time.Time{} + } + return openAIWSUnixNanoToTime(c.expiresAtUnix.Load()) +} + +// upstreamConnAge 返回上游连接已存活的时长。 +// 若 createdAt 未设置(零值)或 now 早于 createdAt(时钟回拨),返回 0。 +func (c *openAIWSIngressContext) upstreamConnAge(now time.Time) time.Duration { + if c == nil { + return 0 + } + ns := c.upstreamConnCreatedAt.Load() + if ns <= 0 { + return 0 + } + age := now.Sub(time.Unix(0, ns)) + if age < 0 { + return 0 + } + return age +} + +func (c *openAIWSIngressContext) touchLease(now time.Time, ttl time.Duration) { + if c == nil { + return + } + nowUnix := openAIWSTimeToUnixNano(now) + c.lastUsedAtUnix.Store(nowUnix) + c.expiresAtUnix.Store(openAIWSTimeToUnixNano(now.Add(ttl))) + c.lastTouchUnixNano.Store(nowUnix) +} + +// maybeTouchLease is a throttled version of touchLease. +// It skips the update if less than 1 second has passed since the last touch, +// avoiding redundant time.Now() + atomic stores on every hot-path message. 
+func (c *openAIWSIngressContext) maybeTouchLease(ttl time.Duration) { + if c == nil { + return + } + now := time.Now() + nowNano := now.UnixNano() + lastNano := c.lastTouchUnixNano.Load() + if lastNano > 0 && nowNano-lastNano < int64(time.Second) { + return + } + c.touchLease(now, ttl) +} + +func newOpenAIWSIngressContextPool(cfg *config.Config) *openAIWSIngressContextPool { + pool := &openAIWSIngressContextPool{ + cfg: cfg, + dialer: newDefaultOpenAIWSClientDialer(), + idleTTL: 10 * time.Minute, + sweepInterval: 30 * time.Second, + upstreamMaxAge: openAIWSUpstreamConnMaxAge, + accounts: make(map[int64]*openAIWSIngressAccountPool), + stopCh: make(chan struct{}), + } + if cfg != nil && cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + pool.idleTTL = time.Duration(cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + if cfg != nil && cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds >= 0 { + // 配置语义:0 表示禁用超龄轮换。 + pool.upstreamMaxAge = time.Duration(cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds) * time.Second + } + pool.startWorker() + return pool +} + +func (p *openAIWSIngressContextPool) setClientDialerForTest(dialer openAIWSClientDialer) { + if p == nil || dialer == nil { + return + } + p.dialer = dialer +} + +func (p *openAIWSIngressContextPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if p == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + if dialer, ok := p.dialer.(openAIWSTransportMetricsDialer); ok { + return dialer.SnapshotTransportMetrics() + } + return OpenAIWSTransportMetricsSnapshot{} +} + +func (p *openAIWSIngressContextPool) maxConnsHardCap() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 { + return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount + } + return 8 +} + +func (p *openAIWSIngressContextPool) effectiveContextCapacity(account *Account) int { + if account == nil || account.Concurrency <= 0 { + return 0 + } + capacity := account.Concurrency + hardCap := 
p.maxConnsHardCap() + if hardCap > 0 && capacity > hardCap { + return hardCap + } + return capacity +} + +func (p *openAIWSIngressContextPool) Close() { + if p == nil { + return + } + p.closeOnce.Do(func() { + p.stopOnce.Do(func() { + close(p.stopCh) + }) + p.workerWg.Wait() + + var toClose []openAIWSClientConn + p.mu.Lock() + accountPools := make([]*openAIWSIngressAccountPool, 0, len(p.accounts)) + for _, ap := range p.accounts { + if ap != nil { + accountPools = append(accountPools, ap) + } + } + p.accounts = make(map[int64]*openAIWSIngressAccountPool) + p.mu.Unlock() + + for _, ap := range accountPools { + ap.mu.Lock() + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + if ctx.upstream != nil { + toClose = append(toClose, ctx.upstream) + } + ctx.upstream = nil + ctx.upstreamConnCreatedAt.Store(0) + ctx.broken = true + ctx.ownerID = "" + ctx.handshakeHeaders = nil + ctx.mu.Unlock() + } + ap.contexts = make(map[string]*openAIWSIngressContext) + ap.bySession = make(map[string]string) + ap.mu.Unlock() + } + + for _, conn := range toClose { + if conn != nil { + _ = conn.Close() + } + } + }) +} + +func (p *openAIWSIngressContextPool) startWorker() { + if p == nil { + return + } + interval := p.sweepInterval + if interval <= 0 { + interval = 30 * time.Second + } + p.workerWg.Add(1) + go func() { + defer p.workerWg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-p.stopCh: + return + case <-ticker.C: + p.sweepExpiredIdleContexts() + } + } + }() +} + +func (p *openAIWSIngressContextPool) Acquire( + ctx context.Context, + req openAIWSIngressContextAcquireRequest, +) (*openAIWSIngressContextLease, error) { + if p == nil { + return nil, errors.New("openai ws ingress context pool is nil") + } + if req.Account == nil || req.Account.ID <= 0 { + return nil, errors.New("invalid account in ingress context acquire request") + } + ownerID := strings.TrimSpace(req.OwnerID) + if ownerID == "" { + return 
nil, errors.New("owner id is empty") + } + if strings.TrimSpace(req.WSURL) == "" { + return nil, errors.New("ws url is empty") + } + capacity := p.effectiveContextCapacity(req.Account) + if capacity <= 0 { + return nil, errOpenAIWSConnQueueFull + } + + sessionHash := strings.TrimSpace(req.SessionHash) + if sessionHash == "" { + // 无会话信号时退化为连接级上下文,避免跨连接串会话。 + sessionHash = "conn:" + ownerID + } + sessionKey := openAIWSIngressContextSessionKey(req.GroupID, sessionHash) + accountID := req.Account.ID + + start := time.Now() + queueWait := time.Duration(0) + waitRetries := 0 + + p.mu.Lock() + ap := p.getOrCreateAccountPoolLocked(accountID) + ap.refs.Add(1) + p.mu.Unlock() + defer ap.refs.Add(-1) + + calcConnPick := func() time.Duration { + connPick := time.Since(start) - queueWait + if connPick < 0 { + return 0 + } + return connPick + } + + for { + now := time.Now() + var ( + selected *openAIWSIngressContext + reusedContext bool + newlyCreated bool + ownerAssigned bool + migrationUsed bool + scheduleLayer string + oldUpstream openAIWSClientConn + deferredClose []openAIWSClientConn + ) + + ap.mu.Lock() + + stickiness := p.resolveStickinessLevelLocked(ap, sessionKey, req, now) + allowMigration, minMigrationScore := openAIWSIngressMigrationPolicyByStickiness(stickiness) + + if existingID, ok := ap.bySession[sessionKey]; ok { + if existing := ap.contexts[existingID]; existing != nil { + existing.mu.Lock() + switch existing.ownerID { + case "": + if existing.releaseDone != nil { + select { + case <-existing.releaseDone: + default: + } + } + existing.ownerID = ownerID + ownerAssigned = true + existing.touchLease(now, p.idleTTL) + selected = existing + reusedContext = true + scheduleLayer = openAIWSIngressScheduleLayerExact + case ownerID: + existing.touchLease(now, p.idleTTL) + selected = existing + reusedContext = true + scheduleLayer = openAIWSIngressScheduleLayerExact + default: + // 当前 context 被其他 owner 占用,等待其释放后重试(循环重试替代递归)。 + blockedByOwner := existing.ownerID + if 
existing.releaseDone == nil { + existing.releaseDone = make(chan struct{}, 1) + } + releaseDone := existing.releaseDone + existing.mu.Unlock() + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + + logOpenAIWSModeInfo( + "ctx_pool_owner_wait_begin account_id=%d ctx_id=%s owner_id=%s blocked_by=%s retry=%d", + accountID, existing.id, ownerID, + truncateOpenAIWSLogValue(blockedByOwner, openAIWSIDValueMaxLen), + waitRetries, + ) + waitStart := time.Now() + select { + case <-releaseDone: + queueWait += time.Since(waitStart) + waitRetries++ + if waitRetries >= openAIWSIngressAcquireMaxWaitRetries || queueWait >= openAIWSIngressAcquireMaxQueueWait { + logOpenAIWSModeInfo( + "ctx_pool_owner_wait_exhausted account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d", + accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(), + ) + return nil, errOpenAIWSIngressContextBusy + } + continue + case <-ctx.Done(): + queueWait += time.Since(waitStart) + logOpenAIWSModeInfo( + "ctx_pool_owner_wait_canceled account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d", + accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(), + ) + return nil, errOpenAIWSIngressContextBusy + } + } + existing.mu.Unlock() + } + } + + if selected == nil { + dynCap := p.effectiveDynamicCapacity(ap, capacity) + if len(ap.contexts) >= dynCap { + deferredClose = append(deferredClose, p.evictExpiredIdleLocked(ap, now)...) 
+ } + if len(ap.contexts) >= dynCap { + if dynCap < capacity { + // 动态扩容:尚未达到硬上限,增长 1 后创建新 context + ap.dynamicCap.Add(1) + } else if !allowMigration { + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + logOpenAIWSModeInfo( + "ctx_pool_full_no_migration account_id=%d capacity=%d stickiness=%s", + accountID, capacity, stickiness, + ) + return nil, errOpenAIWSConnQueueFull + } else { + recycle := p.pickMigrationCandidateLocked(ap, minMigrationScore, now) + if recycle == nil { + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + logOpenAIWSModeInfo( + "ctx_pool_no_migration_candidate account_id=%d capacity=%d min_score=%.1f", + accountID, capacity, minMigrationScore, + ) + return nil, errOpenAIWSConnQueueFull + } + recycle.mu.Lock() + oldSessionKey := recycle.sessionKey + oldUpstream = recycle.upstream + recycle.sessionHash = sessionHash + recycle.sessionKey = sessionKey + recycle.groupID = req.GroupID + if recycle.releaseDone != nil { + select { + case <-recycle.releaseDone: + default: + } + } + recycle.ownerID = ownerID + recycle.touchLease(now, p.idleTTL) + // 会话被回收复用时关闭旧上游,避免跨会话污染。 + recycle.upstream = nil + recycle.upstreamConnID = "" + recycle.upstreamConnCreatedAt.Store(0) + recycle.handshakeHeaders = nil + recycle.broken = false + recycle.migrationCount++ + recycle.lastMigrationAt = now + recycle.mu.Unlock() + + if oldSessionKey != "" { + if mapped, ok := ap.bySession[oldSessionKey]; ok && mapped == recycle.id { + delete(ap.bySession, oldSessionKey) + } + } + ap.bySession[sessionKey] = recycle.id + selected = recycle + reusedContext = true + migrationUsed = true + scheduleLayer = openAIWSIngressScheduleLayerMigration + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + if oldUpstream != nil { + _ = oldUpstream.Close() + } + reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req) + if ensureErr != nil { + p.releaseContext(selected, ownerID) + return nil, ensureErr + } + logOpenAIWSModeInfo( + "ctx_pool_migration 
account_id=%d ctx_id=%s old_session=%s new_session=%s group_id=%d session_hash=%s migration_count=%d", + accountID, selected.id, truncateOpenAIWSLogValue(oldSessionKey, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionKey, openAIWSIDValueMaxLen), selected.groupID, + truncateOpenAIWSLogValue(selected.sessionHash, openAIWSIDValueMaxLen), selected.migrationCount, + ) + return &openAIWSIngressContextLease{ + pool: p, + context: selected, + ownerID: ownerID, + queueWait: queueWait, + connPick: calcConnPick(), + reused: reusedContext && reusedConn, + scheduleLayer: scheduleLayer, + stickiness: stickiness, + migrationUsed: migrationUsed, + }, nil + } + } + + ctxID := fmt.Sprintf("ctx_%d_%d", accountID, p.seq.Add(1)) + created := &openAIWSIngressContext{ + id: ctxID, + groupID: req.GroupID, + accountID: accountID, + sessionHash: sessionHash, + sessionKey: sessionKey, + ownerID: ownerID, + } + created.touchLease(now, p.idleTTL) + ap.contexts[ctxID] = created + ap.bySession[sessionKey] = ctxID + selected = created + newlyCreated = true + ownerAssigned = true + scheduleLayer = openAIWSIngressScheduleLayerNew + } + ap.mu.Unlock() + closeOpenAIWSClientConns(deferredClose) + + reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req) + if ensureErr != nil { + if newlyCreated { + ap.mu.Lock() + delete(ap.contexts, selected.id) + if mapped, ok := ap.bySession[sessionKey]; ok && mapped == selected.id { + delete(ap.bySession, sessionKey) + } + ap.mu.Unlock() + } else if ownerAssigned { + p.releaseContext(selected, ownerID) + } + return nil, ensureErr + } + + return &openAIWSIngressContextLease{ + pool: p, + context: selected, + ownerID: ownerID, + queueWait: queueWait, + connPick: calcConnPick(), + reused: reusedContext && reusedConn, + scheduleLayer: scheduleLayer, + stickiness: stickiness, + migrationUsed: migrationUsed, + }, nil + } +} + +func (p *openAIWSIngressContextPool) resolveStickinessLevelLocked( + ap *openAIWSIngressAccountPool, + sessionKey string, + 
req openAIWSIngressContextAcquireRequest, + now time.Time, +) string { + if req.StrictAffinity { + return openAIWSIngressStickinessStrong + } + + level := openAIWSIngressStickinessWeak + switch { + case req.HasPreviousResponseID: + level = openAIWSIngressStickinessStrong + case req.StoreDisabled || req.Turn > 1: + level = openAIWSIngressStickinessBalanced + } + + if ap == nil { + return level + } + ctxID, ok := ap.bySession[sessionKey] + if !ok { + return level + } + existing := ap.contexts[ctxID] + if existing == nil { + return level + } + + existing.mu.Lock() + broken := existing.broken + failureStreak := existing.failureStreak + lastFailureAt := existing.lastFailureAt + lastUsedAt := existing.lastUsedAt() + existing.mu.Unlock() + + recentFailure := failureStreak > 0 && !lastFailureAt.IsZero() && now.Sub(lastFailureAt) <= 2*time.Minute + if broken || recentFailure { + return openAIWSIngressStickinessDowngrade(level) + } + if failureStreak == 0 && !lastUsedAt.IsZero() && now.Sub(lastUsedAt) <= 20*time.Second { + return openAIWSIngressStickinessUpgrade(level) + } + return level +} + +func openAIWSIngressMigrationPolicyByStickiness(stickiness string) (bool, float64) { + switch stickiness { + case openAIWSIngressStickinessStrong: + return false, 80 // was 85; lowered to allow migration away from degraded accounts + case openAIWSIngressStickinessBalanced: + return true, 65 // was 68; lowered to allow more aggressive migration to healthier accounts + default: + return true, 40 // was 45; lowered for weak stickiness + } +} + +func openAIWSIngressStickinessDowngrade(level string) string { + switch level { + case openAIWSIngressStickinessStrong: + return openAIWSIngressStickinessBalanced + case openAIWSIngressStickinessBalanced: + return openAIWSIngressStickinessWeak + default: + return openAIWSIngressStickinessWeak + } +} + +func openAIWSIngressStickinessUpgrade(level string) string { + switch level { + case openAIWSIngressStickinessWeak: + return 
openAIWSIngressStickinessBalanced + case openAIWSIngressStickinessBalanced: + return openAIWSIngressStickinessStrong + default: + return openAIWSIngressStickinessStrong + } +} + +func (p *openAIWSIngressContextPool) pickMigrationCandidateLocked( + ap *openAIWSIngressAccountPool, + minScore float64, + now time.Time, +) *openAIWSIngressContext { + if ap == nil { + return nil + } + var ( + selected *openAIWSIngressContext + selectedScore float64 + selectedAt time.Time + hasSelected bool + ) + + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + score, lastUsed, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, p.schedulerStats) + if !ok || score < minScore { + continue + } + if !hasSelected || score > selectedScore || (score == selectedScore && lastUsed.Before(selectedAt)) { + selected = ctx + selectedScore = score + selectedAt = lastUsed + hasSelected = true + } + } + return selected +} + +func scoreOpenAIWSIngressMigrationCandidate(c *openAIWSIngressContext, now time.Time, stats *openAIAccountRuntimeStats) (float64, time.Time, bool) { + if c == nil { + return 0, time.Time{}, false + } + c.mu.Lock() + defer c.mu.Unlock() + if strings.TrimSpace(c.ownerID) != "" { + return 0, time.Time{}, false + } + + score := 100.0 + if c.broken { + score -= 30 + } + if c.failureStreak > 0 { + score -= float64(minInt(c.failureStreak*12, 40)) + } + if !c.lastFailureAt.IsZero() && now.Sub(c.lastFailureAt) <= 2*time.Minute { + score -= 18 + } + if !c.lastMigrationAt.IsZero() && now.Sub(c.lastMigrationAt) <= time.Minute { + score -= 10 + } + if c.migrationCount > 0 { + score -= float64(minInt(c.migrationCount*4, 20)) + } + + lastUsedAt := c.lastUsedAt() + idleDuration := now.Sub(lastUsedAt) + switch { + case idleDuration <= 15*time.Second: + score -= 15 + case idleDuration >= 3*time.Minute: + score += 16 + default: + score += idleDuration.Seconds() / 12.0 + } + + // Load-aware factors: penalize contexts bound to accounts that the + // scheduler has flagged as 
degraded or circuit-open. When stats is nil + // (e.g. during tests or before scheduler init), these adjustments are + // silently skipped so existing behaviour is preserved. + if stats != nil && c.accountID > 0 { + errorRate, _, _ := stats.snapshot(c.accountID) + // errorRate is in [0,1]; a fully-erroring account subtracts up to 30 + // points, making it significantly harder for a migration to land on + // an unhealthy account. + score -= errorRate * 30 + + // Circuit-open accounts receive a harsh penalty (-50) that in + // practice drops the score below any reasonable minimum threshold, + // effectively blocking migration to that account. + if stats.isCircuitOpen(c.accountID) { + score -= 50 + } + } + + return score, lastUsedAt, true +} + +func minInt(a, b int) int { + if a <= b { + return a + } + return b +} + +func closeOpenAIWSClientConns(conns []openAIWSClientConn) { + for _, conn := range conns { + if conn != nil { + _ = conn.Close() + } + } +} + +func (p *openAIWSIngressContextPool) ensureContextUpstream( + ctx context.Context, + c *openAIWSIngressContext, + req openAIWSIngressContextAcquireRequest, +) (bool, error) { + if p == nil || c == nil { + return false, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + for { + c.mu.Lock() + if c.upstream != nil && !c.broken { + now := time.Now() + connAge := c.upstreamConnAge(now) + if p.upstreamMaxAge > 0 && connAge > 0 && connAge >= p.upstreamMaxAge { + // 主动轮换:关闭旧连接,不设 broken、不增 failureStreak + oldUpstream, oldConnID := c.upstream, c.upstreamConnID + c.upstream = nil + c.upstreamConnID = "" + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.prewarmed.Store(false) + c.mu.Unlock() + _ = oldUpstream.Close() + logOpenAIWSModeInfo( + "ctx_pool_upstream_max_age_rotate account_id=%d ctx_id=%s conn_id=%s conn_age_min=%.1f max_age_min=%.1f", + c.accountID, c.id, oldConnID, + connAge.Minutes(), p.upstreamMaxAge.Minutes(), + ) + continue // 回到 for 循环走 dialing 路径 + } + 
c.touchLease(now, p.idleTTL) + c.mu.Unlock() + return true, nil + } + if c.dialing { + dialDone := c.dialDone + c.mu.Unlock() + if dialDone == nil { + if err := ctx.Err(); err != nil { + return false, err + } + continue + } + select { + case <-dialDone: + continue + case <-ctx.Done(): + return false, ctx.Err() + } + } + oldUpstream := c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + c.broken = false + c.dialing = true + dialDone := make(chan struct{}) + c.dialDone = dialDone + c.mu.Unlock() + + if oldUpstream != nil { + _ = oldUpstream.Close() + } + + dialer := p.dialer + if dialer == nil { + c.mu.Lock() + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + c.dialing = false + if c.dialDone == dialDone { + c.dialDone = nil + } + close(dialDone) + c.mu.Unlock() + return false, errors.New("openai ws ingress context dialer is nil") + } + conn, statusCode, handshakeHeaders, err := dialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) + if err != nil { + wrappedErr := err + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) { + wrappedErr = &openAIWSDialError{ + StatusCode: statusCode, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + c.mu.Lock() + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + c.dialing = false + if c.dialDone == dialDone { + c.dialDone = nil + } + close(dialDone) + failureStreak := c.failureStreak + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_dial_fail account_id=%d ctx_id=%s status_code=%d failure_streak=%d cause=%s", + c.accountID, c.id, statusCode, failureStreak, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return false, wrappedErr + } + + c.mu.Lock() + now := time.Now() + c.upstream = conn + c.upstreamConnID = fmt.Sprintf("ctxws_%d_%d", c.accountID, p.seq.Add(1)) + c.upstreamConnCreatedAt.Store(now.UnixNano()) + c.handshakeHeaders = 
cloneHeader(handshakeHeaders) + c.prewarmed.Store(false) + c.touchLease(now, p.idleTTL) + c.broken = false + c.failureStreak = 0 + c.lastFailureAt = time.Time{} + c.dialing = false + if c.dialDone == dialDone { + c.dialDone = nil + } + close(dialDone) + connID := c.upstreamConnID + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_dial_ok account_id=%d ctx_id=%s conn_id=%s", + c.accountID, c.id, connID, + ) + return false, nil + } +} + +func (p *openAIWSIngressContextPool) yieldContext(c *openAIWSIngressContext, ownerID string) { + p.releaseContextWithPolicy(c, ownerID, false) + // yield 后延迟 Ping,提前发现死连接 + p.scheduleDelayedPing(c, openAIWSIngressDelayedPingAfterYield) +} + +func (p *openAIWSIngressContextPool) releaseContext(c *openAIWSIngressContext, ownerID string) { + p.releaseContextWithPolicy(c, ownerID, true) +} + +func (p *openAIWSIngressContextPool) releaseContextWithPolicy( + c *openAIWSIngressContext, + ownerID string, + closeUpstream bool, +) { + if p == nil || c == nil { + return + } + var upstream openAIWSClientConn + c.mu.Lock() + if c.ownerID == ownerID { + if closeUpstream { + // 会话结束或链路损坏时销毁上游连接,避免污染后续请求。 + upstream = c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + } + c.ownerID = "" + // 通知一个等待中的 Acquire 请求,避免 close 广播导致惊群。 + if c.releaseDone != nil { + select { + case c.releaseDone <- struct{}{}: + default: + } + } + now := time.Now() + c.setLastUsedAt(now) + c.setExpiresAt(now.Add(p.idleTTL)) + c.broken = false + } + c.mu.Unlock() + if upstream != nil { + _ = upstream.Close() + } +} + +func (p *openAIWSIngressContextPool) markContextBroken(c *openAIWSIngressContext) { + if c == nil { + return + } + c.mu.Lock() + upstream := c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + // 注意:此处不发送 
releaseDone 信号。ownerID 仍被占用,等待者被唤醒后 + // 会发现 owner 未释放而重新阻塞,造成信号浪费。实际释放由 Release() 完成。 + failureStreak := c.failureStreak + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d", + c.accountID, c.id, failureStreak, + ) + if upstream != nil { + _ = upstream.Close() + } +} + +// markContextBrokenIfConnMatch 仅在连接代次(connID)匹配时标记 broken。 +// 后台 Ping 在解锁期间执行,期间连接可能已被重建为新连接; +// 若 connID 已变则说明旧连接已被替换,放弃标记以避免误杀新连接。 +func (p *openAIWSIngressContextPool) markContextBrokenIfConnMatch(c *openAIWSIngressContext, expectedConnID string) { + if c == nil { + return + } + c.mu.Lock() + actualConnID := c.upstreamConnID + if actualConnID != expectedConnID { + // 连接已被重建(connID 变了),放弃标记 + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_bg_ping_skip_stale account_id=%d ctx_id=%s expected_conn_id=%s actual_conn_id=%s", + c.accountID, c.id, expectedConnID, actualConnID, + ) + return + } + ownerID := c.ownerID + dialing := c.dialing + if ownerID != "" || dialing { + // Ping 期间 context 可能被重新占用或进入建连流程,不应由后台探测路径误杀活跃连接。 + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_bg_ping_skip_busy account_id=%d ctx_id=%s conn_id=%s owner_id=%s dialing=%v", + c.accountID, + c.id, + actualConnID, + truncateOpenAIWSLogValue(ownerID, openAIWSIDValueMaxLen), + dialing, + ) + return + } + upstream := c.upstream + c.upstream = nil + c.upstreamConnCreatedAt.Store(0) + c.handshakeHeaders = nil + c.upstreamConnID = "" + c.prewarmed.Store(false) + c.broken = true + c.failureStreak++ + c.lastFailureAt = time.Now() + failureStreak := c.failureStreak + c.mu.Unlock() + logOpenAIWSModeInfo( + "ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d", + c.accountID, c.id, failureStreak, + ) + if upstream != nil { + _ = upstream.Close() + } +} + +func (p *openAIWSIngressContextPool) getOrCreateAccountPoolLocked(accountID int64) *openAIWSIngressAccountPool { + if ap, ok := p.accounts[accountID]; ok && ap != nil { + return ap + } + ap := &openAIWSIngressAccountPool{ 
+ contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ap.dynamicCap.Store(1) + p.accounts[accountID] = ap + return ap +} + +// effectiveDynamicCapacity 返回 min(dynamicCap, hardCap)。 +// dynamicCap 从 1 开始,按需增长,空闲时缩减;hardCap 由账户并发度和全局上限决定。 +func (p *openAIWSIngressContextPool) effectiveDynamicCapacity(ap *openAIWSIngressAccountPool, hardCap int) int { + if ap == nil || hardCap <= 0 { + return hardCap + } + dynCap := int(ap.dynamicCap.Load()) + if dynCap < 1 { + dynCap = 1 + ap.dynamicCap.Store(1) + } + if dynCap > hardCap { + return hardCap + } + return dynCap +} + +func (p *openAIWSIngressContextPool) evictExpiredIdleLocked( + ap *openAIWSIngressAccountPool, + now time.Time, +) []openAIWSClientConn { + if ap == nil { + return nil + } + var toClose []openAIWSClientConn + for id, ctx := range ap.contexts { + if ctx == nil { + delete(ap.contexts, id) + continue + } + ctx.mu.Lock() + expiresAt := ctx.expiresAt() + expired := ctx.ownerID == "" && !expiresAt.IsZero() && now.After(expiresAt) + upstream := ctx.upstream + if expired { + ctx.upstream = nil + ctx.upstreamConnCreatedAt.Store(0) + ctx.handshakeHeaders = nil + ctx.upstreamConnID = "" + } + ctx.mu.Unlock() + if !expired { + continue + } + delete(ap.contexts, id) + if mappedID, ok := ap.bySession[ctx.sessionKey]; ok && mappedID == id { + delete(ap.bySession, ctx.sessionKey) + } + if upstream != nil { + toClose = append(toClose, upstream) + } + } + return toClose +} + +func (p *openAIWSIngressContextPool) pickOldestIdleContextLocked(ap *openAIWSIngressAccountPool) *openAIWSIngressContext { + if ap == nil { + return nil + } + var ( + selected *openAIWSIngressContext + selectedAt time.Time + ) + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + idle := strings.TrimSpace(ctx.ownerID) == "" + lastUsed := ctx.lastUsedAt() + ctx.mu.Unlock() + if !idle { + continue + } + if selected == nil || lastUsed.Before(selectedAt) { + selected = ctx + 
selectedAt = lastUsed + } + } + return selected +} + +// closeAgedIdleUpstreamsLocked 关闭空闲且超龄的上游连接。 +// 只清理 upstream,保留 context 槽位(不删 context、不清 bySession)。 +// 不设 broken、不增 failureStreak。 +// 调用方必须持有 ap.mu。 +func (p *openAIWSIngressContextPool) closeAgedIdleUpstreamsLocked( + ap *openAIWSIngressAccountPool, + now time.Time, +) []openAIWSClientConn { + if ap == nil || p.upstreamMaxAge <= 0 { + return nil + } + var toClose []openAIWSClientConn + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + idle := ctx.ownerID == "" + hasUpstream := ctx.upstream != nil + connAge := ctx.upstreamConnAge(now) + aged := connAge > 0 && connAge >= p.upstreamMaxAge + if idle && hasUpstream && aged { + toClose = append(toClose, ctx.upstream) + ctx.upstream = nil + ctx.upstreamConnCreatedAt.Store(0) + ctx.upstreamConnID = "" + ctx.handshakeHeaders = nil + ctx.prewarmed.Store(false) + } + ctx.mu.Unlock() + } + return toClose +} + +// pingContextUpstream 对空闲 context 的上游连接发送 Ping 探测。 +// 若 Ping 失败则标记 context 为 broken,让后续 Acquire 走重建路径。 +// 调用方不需要持有任何锁。 +// +// 使用 connID 代次守卫:Ping 期间连接可能被重建,仅当 connID 未变时才标记 broken, +// 避免旧连接 Ping 失败误杀新连接。 +func (p *openAIWSIngressContextPool) pingContextUpstream(c *openAIWSIngressContext) { + if p == nil || c == nil { + return + } + c.mu.Lock() + idle := c.ownerID == "" + hasUpstream := c.upstream != nil + broken := c.broken + dialing := c.dialing + upstream := c.upstream + connID := c.upstreamConnID // 快照连接代次 + c.mu.Unlock() + if !idle || !hasUpstream || broken || dialing || upstream == nil { + return + } + + pingCtx, cancel := context.WithTimeout(context.Background(), openAIWSIngressPingTimeout) + defer cancel() + if err := upstream.Ping(pingCtx); err != nil { + p.markContextBrokenIfConnMatch(c, connID) + logOpenAIWSModeInfo( + "ctx_pool_bg_ping_fail account_id=%d ctx_id=%s conn_id=%s cause=%s", + c.accountID, c.id, connID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + } +} + +// pingIdleUpstreams 
对账户池内所有空闲且有上游连接的 context 发起 Ping 探测。 +// 先在锁内收集候选列表,再在锁外逐个 Ping,避免阻塞其他操作。 +func (p *openAIWSIngressContextPool) pingIdleUpstreams(ap *openAIWSIngressAccountPool) { + if p == nil || ap == nil { + return + } + ap.mu.Lock() + targets := make([]*openAIWSIngressContext, 0, len(ap.contexts)) + for _, ctx := range ap.contexts { + if ctx == nil { + continue + } + ctx.mu.Lock() + eligible := ctx.ownerID == "" && ctx.upstream != nil && !ctx.broken && !ctx.dialing + ctx.mu.Unlock() + if eligible { + targets = append(targets, ctx) + } + } + ap.mu.Unlock() + + for _, ctx := range targets { + p.pingContextUpstream(ctx) + } +} + +// scheduleDelayedPing 在 yield 后延迟一段时间对 context 发送 Ping 探测。 +// 通过 pendingPingTimer 去重:同一 context 同时只保留一个延迟 Ping, +// 连续 yield 只 Reset timer 而不创建新 goroutine,避免高并发下 goroutine 堆积。 +func (p *openAIWSIngressContextPool) scheduleDelayedPing(c *openAIWSIngressContext, delay time.Duration) { + if p == nil || c == nil || delay <= 0 { + return + } + c.mu.Lock() + if c.pendingPingTimer != nil { + // 已有 pending ping,只需 Reset timer 延迟窗口 + c.pendingPingTimer.Reset(delay) + c.mu.Unlock() + return + } + timer := time.NewTimer(delay) + c.pendingPingTimer = timer + c.mu.Unlock() + + go func() { + select { + case <-p.stopCh: + timer.Stop() + case <-timer.C: + p.pingContextUpstream(c) + } + c.mu.Lock() + if c.pendingPingTimer == timer { + c.pendingPingTimer = nil + } + c.mu.Unlock() + }() +} + +func (p *openAIWSIngressContextPool) sweepExpiredIdleContexts() { + if p == nil { + return + } + now := time.Now() + + type accountSnapshot struct { + accountID int64 + ap *openAIWSIngressAccountPool + } + + snapshots := make([]accountSnapshot, 0, len(p.accounts)) + p.mu.Lock() + for accountID, ap := range p.accounts { + if ap == nil { + delete(p.accounts, accountID) + continue + } + snapshots = append(snapshots, accountSnapshot{accountID: accountID, ap: ap}) + } + p.mu.Unlock() + + removable := make([]accountSnapshot, 0) + for _, item := range snapshots { + ap := item.ap + 
ap.mu.Lock() + toClose := p.evictExpiredIdleLocked(ap, now) + agedClose := p.closeAgedIdleUpstreamsLocked(ap, now) + empty := len(ap.contexts) == 0 + // 动态缩容:将 dynamicCap 收缩到 max(1, 当前 context 数量) + shrinkTarget := int32(len(ap.contexts)) + if shrinkTarget < 1 { + shrinkTarget = 1 + } + if ap.dynamicCap.Load() > shrinkTarget { + ap.dynamicCap.Store(shrinkTarget) + } + ap.mu.Unlock() + closeOpenAIWSClientConns(toClose) + closeOpenAIWSClientConns(agedClose) + // 后台 Ping 探测:对剩余空闲连接发送 Ping,及时剔除死连接 + if !empty { + p.pingIdleUpstreams(ap) + } + if empty && ap.refs.Load() == 0 { + removable = append(removable, item) + } + } + + if len(removable) == 0 { + return + } + + p.mu.Lock() + for _, item := range removable { + existing := p.accounts[item.accountID] + if existing != item.ap || existing == nil { + continue + } + if existing.refs.Load() != 0 { + continue + } + delete(p.accounts, item.accountID) + } + p.mu.Unlock() +} + +func openAIWSIngressContextSessionKey(groupID int64, sessionHash string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + return "" + } + return strconv.FormatInt(groupID, 10) + ":" + hash +} + +func (l *openAIWSIngressContextLease) ConnID() string { + if l == nil || l.context == nil { + return "" + } + l.context.mu.Lock() + defer l.context.mu.Unlock() + return strings.TrimSpace(l.context.upstreamConnID) +} + +func (l *openAIWSIngressContextLease) QueueWaitDuration() time.Duration { + if l == nil { + return 0 + } + return l.queueWait +} + +func (l *openAIWSIngressContextLease) ConnPickDuration() time.Duration { + if l == nil { + return 0 + } + return l.connPick +} + +func (l *openAIWSIngressContextLease) Reused() bool { + if l == nil { + return false + } + return l.reused +} + +func (l *openAIWSIngressContextLease) ScheduleLayer() string { + if l == nil { + return "" + } + return strings.TrimSpace(l.scheduleLayer) +} + +func (l *openAIWSIngressContextLease) StickinessLevel() string { + if l == nil { + return "" + } + return 
strings.TrimSpace(l.stickiness) +} + +func (l *openAIWSIngressContextLease) MigrationUsed() bool { + if l == nil { + return false + } + return l.migrationUsed +} + +func (l *openAIWSIngressContextLease) HandshakeHeader(name string) string { + if l == nil || l.context == nil { + return "" + } + l.context.mu.Lock() + defer l.context.mu.Unlock() + if l.context.handshakeHeaders == nil { + return "" + } + return strings.TrimSpace(l.context.handshakeHeaders.Get(strings.TrimSpace(name))) +} + +func (l *openAIWSIngressContextLease) IsPrewarmed() bool { + if l == nil || l.context == nil { + return false + } + return l.context.prewarmed.Load() +} + +func (l *openAIWSIngressContextLease) MarkPrewarmed() { + if l == nil || l.context == nil { + return + } + l.context.prewarmed.Store(true) +} + +func (l *openAIWSIngressContextLease) activeConn() (openAIWSClientConn, error) { + if l == nil || l.context == nil || l.released.Load() { + return nil, errOpenAIWSConnClosed + } + // Fast path: return cached conn without mutex if lease is still valid. + l.cachedConnMu.RLock() + cc := l.cachedConn + l.cachedConnMu.RUnlock() + if cc != nil { + return cc, nil + } + // Slow path: acquire mutex, validate ownership, cache result. 
+ l.context.mu.Lock() + defer l.context.mu.Unlock() + if l.context.ownerID != l.ownerID { + return nil, errOpenAIWSConnClosed + } + if l.context.upstream == nil { + return nil, errOpenAIWSConnClosed + } + l.cachedConnMu.Lock() + l.cachedConn = l.context.upstream + l.cachedConnMu.Unlock() + return l.context.upstream, nil +} + +func (l *openAIWSIngressContextLease) invalidateCachedConnOnIOError(err error) { + if l == nil || err == nil { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + if l.pool != nil && l.context != nil && isOpenAIWSClientDisconnectError(err) { + l.pool.markContextBroken(l.context) + } +} + +func (l *openAIWSIngressContextLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + writeCtx := ctx + if writeCtx == nil { + writeCtx = context.Background() + } + if timeout > 0 { + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(writeCtx, timeout) + defer cancel() + } + if err := conn.WriteJSON(writeCtx, value); err != nil { + l.invalidateCachedConnOnIOError(err) + return err + } + l.context.maybeTouchLease(l.pool.idleTTL) + return nil +} + +func (l *openAIWSIngressContextLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + readCtx := ctx + if readCtx == nil { + readCtx = context.Background() + } + if timeout > 0 { + var cancel context.CancelFunc + readCtx, cancel = context.WithTimeout(readCtx, timeout) + defer cancel() + } + payload, err := conn.ReadMessage(readCtx) + if err != nil { + l.invalidateCachedConnOnIOError(err) + return nil, err + } + l.context.maybeTouchLease(l.pool.idleTTL) + return payload, nil +} + +func (l *openAIWSIngressContextLease) PingWithTimeout(timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + pingTimeout 
:= timeout + if pingTimeout <= 0 { + pingTimeout = openAIWSConnHealthCheckTO + } + pingCtx, cancel := context.WithTimeout(context.Background(), pingTimeout) + defer cancel() + if err := conn.Ping(pingCtx); err != nil { + l.invalidateCachedConnOnIOError(err) + return err + } + l.context.maybeTouchLease(l.pool.idleTTL) + return nil +} + +func (l *openAIWSIngressContextLease) MarkBroken() { + if l == nil || l.pool == nil || l.context == nil || l.released.Load() { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + l.pool.markContextBroken(l.context) +} + +func (l *openAIWSIngressContextLease) Release() { + if l == nil || l.context == nil || l.pool == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + l.pool.releaseContext(l.context, l.ownerID) +} + +func (l *openAIWSIngressContextLease) Yield() { + if l == nil || l.context == nil || l.pool == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.cachedConnMu.Lock() + l.cachedConn = nil + l.cachedConnMu.Unlock() + l.pool.yieldContext(l.context, l.ownerID) +} diff --git a/backend/internal/service/openai_ws_ingress_context_pool_test.go b/backend/internal/service/openai_ws_ingress_context_pool_test.go new file mode 100644 index 000000000..f5a6802c5 --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_context_pool_test.go @@ -0,0 +1,2496 @@ +package service + +import ( + "context" + "errors" + "net" + "net/http" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSIngressContextPool_Acquire_HardCapacityEqualsConcurrency(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + 
&openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 801, Concurrency: 1} + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 2, + SessionHash: "session_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 2, + SessionHash: "session_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "并发=1 时第二个并发 owner 不应获取到 context") + + lease1.Release() + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 2, + SessionHash: "session_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, openAIWSIngressScheduleLayerMigration, lease2.ScheduleLayer()) + require.Equal(t, openAIWSIngressStickinessWeak, lease2.StickinessLevel()) + require.True(t, lease2.MigrationUsed()) + lease2.Release() + + require.Equal(t, 2, dialer.DialCount(), "会话回收复用 context 后应重建上游连接,避免跨会话污染") +} + +func TestOpenAIWSIngressContextPool_Acquire_RespectsGlobalHardCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 802, Concurrency: 10} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, 
openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.NoError(t, err) + require.NotNil(t, lease1) + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_c", + OwnerID: "owner_c", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "账号并发高于全局硬上限时,context pool 仍应被硬上限约束") + + lease1.Release() + lease2.Release() + require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Acquire_DoesNotCrossAccount(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + accountA := &Account{ID: 901, Concurrency: 1} + accountB := &Account{ID: 902, Concurrency: 1} + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + leaseA, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: accountA, + GroupID: 5, + SessionHash: "same_session_hash", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, leaseA) + leaseA.Release() + + leaseB, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: accountB, + GroupID: 5, + SessionHash: "same_session_hash", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + 
require.NoError(t, err) + require.NotNil(t, leaseB) + leaseB.Release() + + require.Equal(t, 2, dialer.DialCount(), "相同 session_hash 在不同账号下必须使用不同 context,不允许跨账号复用") +} + +func TestOpenAIWSIngressContextPool_Acquire_StrongStickinessDisablesMigration(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1001, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 9, + SessionHash: "session_keep_strong_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 9, + SessionHash: "session_keep_strong_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "strong 粘连不应迁移其它 session 的 context") +} + +func TestOpenAIWSIngressContextPool_Acquire_AdaptiveStickinessDowngradesAfterFailure(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1002, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 11, + SessionHash: "session_adaptive", + 
OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.MarkBroken() + lease1.Release() + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 11, + SessionHash: "session_adaptive", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + HasPreviousResponseID: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, openAIWSIngressScheduleLayerExact, lease2.ScheduleLayer()) + require.Equal(t, openAIWSIngressStickinessBalanced, lease2.StickinessLevel(), "失败后应从 strong 自适应降级到 balanced") + lease2.Release() + require.Equal(t, 2, dialer.DialCount(), "故障后重连同一 context 应重新建立上游连接") +} + +func TestOpenAIWSIngressContextPool_Acquire_EnsureFailureReleasesOwner(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + initialDialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(initialDialer) + + account := &Account{ID: 1101, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 12, + SessionHash: "session_owner_release", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + failDialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(failDialer) + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 12, + SessionHash: "session_owner_release", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.Error(t, err) + require.NotErrorIs(t, err, errOpenAIWSIngressContextBusy, "ensure 上游失败后不应遗留 owner 导致 context 长时间 busy") + + recoverDialer := &openAIWSQueueDialer{ + 
conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(recoverDialer) + + lease3, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 12, + SessionHash: "session_owner_release", + OwnerID: "owner_c", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err, "owner 回滚后应允许后续会话重新获取同一 context") + require.NotNil(t, lease3) + lease3.Release() + require.Equal(t, 1, failDialer.DialCount()) + require.Equal(t, 1, recoverDialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Release_ClosesUpstreamAndForcesRedial(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstreamConn1 := &openAIWSCaptureConn{} + upstreamConn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + upstreamConn1, + upstreamConn2, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1102, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 13, + SessionHash: "session_same", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + connID1 := lease1.ConnID() + require.NotEmpty(t, connID1) + lease1.Release() + + upstreamConn1.mu.Lock() + closed1 := upstreamConn1.closed + upstreamConn1.mu.Unlock() + require.True(t, closed1, "客户端会话结束后应关闭对应上游连接,防止复用污染") + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 13, + SessionHash: "session_same", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + connID2 := lease2.ConnID() + require.NotEmpty(t, connID2) + require.NotEqual(t, connID1, connID2, "下一次会话必须重新建立上游连接") + 
lease2.Release() + + upstreamConn2.mu.Lock() + closed2 := upstreamConn2.closed + upstreamConn2.mu.Unlock() + require.True(t, closed2) + require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Yield_ReleasesOwnerKeepsUpstream(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstreamConn := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{upstreamConn}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1103, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 14, + SessionHash: "session_yield", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + connID1 := lease1.ConnID() + require.NotEmpty(t, connID1) + + lease1.Yield() + upstreamConn.mu.Lock() + closedAfterYield := upstreamConn.closed + upstreamConn.mu.Unlock() + require.False(t, closedAfterYield, "yield 只应释放 owner,不应关闭上游连接") + + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 14, + SessionHash: "session_yield", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, connID1, lease2.ConnID(), "yield 后应复用同一上游连接") + require.Equal(t, 1, dialer.DialCount(), "yield 后重新获取不应触发重拨号") + + lease2.Release() + upstreamConn.mu.Lock() + closedAfterRelease := upstreamConn.closed + upstreamConn.mu.Unlock() + require.True(t, closedAfterRelease, "release 仍需关闭上游连接") +} + +func TestOpenAIWSIngressContextPool_EvictExpiredIdleLocked_ClosesUpstream(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer 
pool.Close() + + upstreamConn := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + expiredCtx := &openAIWSIngressContext{ + id: "ctx_expired_1", + groupID: 21, + accountID: 1201, + sessionHash: "session_expired", + sessionKey: openAIWSIngressContextSessionKey(21, "session_expired"), + upstream: upstreamConn, + upstreamConnID: "ctxws_1201_1", + handshakeHeaders: map[string][]string{"x-test": {"ok"}}, + } + expiredCtx.setExpiresAt(time.Now().Add(-2 * time.Second)) + ap.contexts[expiredCtx.id] = expiredCtx + ap.bySession[expiredCtx.sessionKey] = expiredCtx.id + + ap.mu.Lock() + toClose := pool.evictExpiredIdleLocked(ap, time.Now()) + ap.mu.Unlock() + closeOpenAIWSClientConns(toClose) + + require.Empty(t, ap.contexts, "过期 idle context 应被清理") + require.Empty(t, ap.bySession, "过期 context 的 session 索引应同步清理") + upstreamConn.mu.Lock() + closed := upstreamConn.closed + upstreamConn.mu.Unlock() + require.True(t, closed, "清理过期 context 时应关闭残留上游连接,避免泄漏") +} + +func TestOpenAIWSIngressContextPool_ScoreAndStickinessHelpers(t *testing.T) { + now := time.Now() + + require.Equal(t, 1, minInt(1, 2)) + require.Equal(t, 2, minInt(3, 2)) + + require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessStrong)) + require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessBalanced)) + require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade("unknown")) + + require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessWeak)) + require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessBalanced)) + require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade("unknown")) + + allowStrong, scoreStrong := 
openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessStrong) + require.False(t, allowStrong) + require.Equal(t, 80.0, scoreStrong) + allowBalanced, scoreBalanced := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessBalanced) + require.True(t, allowBalanced) + require.Equal(t, 65.0, scoreBalanced) + allowWeak, scoreWeak := openAIWSIngressMigrationPolicyByStickiness("weak_or_unknown") + require.True(t, allowWeak) + require.Equal(t, 40.0, scoreWeak) + + busyCtx := &openAIWSIngressContext{ownerID: "owner_busy"} + _, _, ok := scoreOpenAIWSIngressMigrationCandidate(busyCtx, now, nil) + require.False(t, ok, "owner 占用中的 context 不应作为迁移候选") + + oldIdle := &openAIWSIngressContext{} + oldIdle.setLastUsedAt(now.Add(-5 * time.Minute)) + recentIdle := &openAIWSIngressContext{} + recentIdle.setLastUsedAt(now.Add(-10 * time.Second)) + scoreOld, _, okOld := scoreOpenAIWSIngressMigrationCandidate(oldIdle, now, nil) + scoreRecent, _, okRecent := scoreOpenAIWSIngressMigrationCandidate(recentIdle, now, nil) + require.True(t, okOld) + require.True(t, okRecent) + require.Greater(t, scoreOld, scoreRecent, "更久未使用的空闲 context 应该更易被迁移") + + penalized := &openAIWSIngressContext{ + broken: true, + failureStreak: 2, + lastFailureAt: now.Add(-30 * time.Second), + migrationCount: 2, + lastMigrationAt: now.Add(-10 * time.Second), + } + penalized.setLastUsedAt(now.Add(-5 * time.Minute)) + scorePenalized, _, okPenalized := scoreOpenAIWSIngressMigrationCandidate(penalized, now, nil) + require.True(t, okPenalized) + require.Less(t, scorePenalized, scoreOld, "近期失败和频繁迁移应降低迁移分数") +} + +func TestOpenAIWSIngressContextPool_EvictPickAndSweep(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + now := time.Now() + expiredConn := &openAIWSCaptureConn{} + expiredCtx := &openAIWSIngressContext{ + id: "ctx_expired", + sessionKey: "1:expired", + upstream: expiredConn, + 
upstreamConnID: "ctxws_expired", + } + expiredCtx.setLastUsedAt(now.Add(-20 * time.Minute)) + expiredCtx.setExpiresAt(now.Add(-time.Minute)) + + idleNewCtx := &openAIWSIngressContext{ + id: "ctx_idle_new", + sessionKey: "1:idle_new", + } + idleNewCtx.setLastUsedAt(now.Add(-30 * time.Second)) + idleNewCtx.setExpiresAt(now.Add(time.Minute)) + + busyCtx := &openAIWSIngressContext{ + id: "ctx_busy", + sessionKey: "1:busy", + ownerID: "active_owner", + } + busyCtx.setLastUsedAt(now.Add(-40 * time.Minute)) + busyCtx.setExpiresAt(now.Add(-time.Minute)) + + ap := &openAIWSIngressAccountPool{ + contexts: map[string]*openAIWSIngressContext{ + "ctx_expired": expiredCtx, + "ctx_idle_new": idleNewCtx, + "ctx_busy": busyCtx, + }, + bySession: map[string]string{ + "1:expired": "ctx_expired", + "1:idle_new": "ctx_idle_new", + "1:busy": "ctx_busy", + }, + } + + ap.mu.Lock() + oldestIdle := pool.pickOldestIdleContextLocked(ap) + ap.mu.Unlock() + require.NotNil(t, oldestIdle) + require.Equal(t, "ctx_expired", oldestIdle.id, "应选择最旧的空闲 context") + + ap.mu.Lock() + toClose := pool.evictExpiredIdleLocked(ap, now) + ap.mu.Unlock() + closeOpenAIWSClientConns(toClose) + require.NotContains(t, ap.contexts, "ctx_expired") + require.NotContains(t, ap.bySession, "1:expired") + require.Contains(t, ap.contexts, "ctx_idle_new", "未过期空闲 context 应保留") + require.Contains(t, ap.contexts, "ctx_busy", "有 owner 的 context 不应被 idle 过期清理") + expiredConn.mu.Lock() + expiredClosed := expiredConn.closed + expiredConn.mu.Unlock() + require.True(t, expiredClosed, "清理过期 idle context 时应关闭上游连接") + + expiredInPoolConn := &openAIWSCaptureConn{} + pool.mu.Lock() + pool.accounts[5001] = ap + poolExpiredCtx := &openAIWSIngressContext{ + id: "ctx_pool_expired", + sessionKey: "2:expired", + upstream: expiredInPoolConn, + } + poolExpiredCtx.setExpiresAt(now.Add(-time.Minute)) + pool.accounts[5002] = &openAIWSIngressAccountPool{ + contexts: map[string]*openAIWSIngressContext{ + "ctx_pool_expired": poolExpiredCtx, + }, + 
bySession: map[string]string{ + "2:expired": "ctx_pool_expired", + }, + } + pool.mu.Unlock() + + pool.sweepExpiredIdleContexts() + + pool.mu.Lock() + _, account2Exists := pool.accounts[5002] + account1 := pool.accounts[5001] + pool.mu.Unlock() + require.False(t, account2Exists, "sweep 后空账号应被移除") + require.NotNil(t, account1, "非空账号应保留") + expiredInPoolConn.mu.Lock() + sweptClosed := expiredInPoolConn.closed + expiredInPoolConn.mu.Unlock() + require.True(t, sweptClosed) +} + +func TestOpenAIWSIngressContextLease_AccessorsAndPingGuards(t *testing.T) { + var nilLease *openAIWSIngressContextLease + require.Equal(t, "", nilLease.ConnID()) + require.Zero(t, nilLease.QueueWaitDuration()) + require.Zero(t, nilLease.ConnPickDuration()) + require.False(t, nilLease.Reused()) + require.Equal(t, "", nilLease.ScheduleLayer()) + require.Equal(t, "", nilLease.StickinessLevel()) + require.False(t, nilLease.MigrationUsed()) + require.Equal(t, "", nilLease.HandshakeHeader("x-test")) + require.ErrorIs(t, nilLease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ctxItem := &openAIWSIngressContext{ + id: "ctx_lease", + ownerID: "owner_ok", + upstream: &openAIWSFakeConn{}, + handshakeHeaders: http.Header{"X-Test": []string{"ok"}}, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctxItem, + ownerID: "owner_ok", + queueWait: 5 * time.Millisecond, + connPick: 8 * time.Millisecond, + reused: true, + scheduleLayer: openAIWSIngressScheduleLayerExact, + stickiness: openAIWSIngressStickinessBalanced, + migrationUsed: true, + } + + require.Equal(t, "ok", lease.HandshakeHeader("x-test")) + require.Equal(t, 5*time.Millisecond, lease.QueueWaitDuration()) + require.Equal(t, 8*time.Millisecond, lease.ConnPickDuration()) + require.True(t, lease.Reused()) + require.Equal(t, openAIWSIngressScheduleLayerExact, 
lease.ScheduleLayer()) + require.Equal(t, openAIWSIngressStickinessBalanced, lease.StickinessLevel()) + require.True(t, lease.MigrationUsed()) + require.NoError(t, lease.PingWithTimeout(0), "timeout=0 应回退默认 ping 超时") + + lease.released.Store(true) + require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed) + lease.released.Store(false) + + ctxItem.mu.Lock() + ctxItem.ownerID = "other_owner" + ctxItem.mu.Unlock() + lease.cachedConn = nil // clear cache to force re-validation (simulates migration) + require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed, "owner 不匹配时应拒绝访问") + + ctxItem.mu.Lock() + ctxItem.ownerID = "owner_ok" + ctxItem.upstream = &openAIWSPingFailConn{} + ctxItem.mu.Unlock() + lease.cachedConn = nil // clear cache to pick up new upstream + require.Error(t, lease.PingWithTimeout(time.Millisecond), "上游 ping 失败应透传错误") + + lease.Release() + lease.Release() + require.Equal(t, "", lease.context.ownerID, "重复 Release 应幂等且不会 panic") +} + +func TestOpenAIWSIngressContextPool_EnsureContextUpstreamBranches(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ctxItem := &openAIWSIngressContext{ + id: "ctx_ensure", + accountID: 1, + ownerID: "owner", + upstream: &openAIWSFakeConn{}, + } + + reused, err := pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{ + WSURL: "ws://test", + }) + require.NoError(t, err) + require.True(t, reused, "已有可用 upstream 时应直接复用") + + pool.dialer = nil + ctxItem.mu.Lock() + ctxItem.broken = true + ctxItem.mu.Unlock() + _, err = pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{ + WSURL: "ws://test", + }) + require.ErrorContains(t, err, "dialer is nil") + + failDialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(failDialer) + _, err = 
pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{ + WSURL: "ws://test", + }) + require.Error(t, err) + var dialErr *openAIWSDialError + require.ErrorAs(t, err, &dialErr, "dial 失败应包装为 openAIWSDialError") + require.Equal(t, 503, dialErr.StatusCode) + ctxItem.mu.Lock() + broken := ctxItem.broken + failureStreak := ctxItem.failureStreak + ctxItem.mu.Unlock() + require.True(t, broken) + require.GreaterOrEqual(t, failureStreak, 1, "dial 失败后应累计 failure_streak") +} + +func TestOpenAIWSIngressContextPool_MarkBrokenDoesNotSignalWaiterBeforeRelease(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstream := &openAIWSCaptureConn{} + ctxItem := &openAIWSIngressContext{ + id: "ctx_mark_broken", + ownerID: "owner_broken", + releaseDone: make(chan struct{}, 1), + upstream: upstream, + } + + pool.markContextBroken(ctxItem) + + select { + case <-ctxItem.releaseDone: + t.Fatal("markContextBroken should not wake waiters before owner is released") + default: + } + + ctxItem.mu.Lock() + require.True(t, ctxItem.broken) + require.Equal(t, "owner_broken", ctxItem.ownerID) + require.Nil(t, ctxItem.upstream) + ctxItem.mu.Unlock() + + upstream.mu.Lock() + require.True(t, upstream.closed, "mark broken should close current upstream connection") + upstream.mu.Unlock() + + pool.releaseContext(ctxItem, "owner_broken") + + select { + case <-ctxItem.releaseDone: + case <-time.After(200 * time.Millisecond): + t.Fatal("releaseContext should signal one waiting acquire after owner is released") + } + + ctxItem.mu.Lock() + require.Equal(t, "", ctxItem.ownerID) + require.False(t, ctxItem.broken) + ctxItem.mu.Unlock() +} + +type openAIWSWriteDisconnectConn struct{} + +func (c *openAIWSWriteDisconnectConn) WriteJSON(context.Context, any) error { + return net.ErrClosed +} + +func (c *openAIWSWriteDisconnectConn) ReadMessage(context.Context) 
([]byte, error) { + return nil, net.ErrClosed +} + +func (c *openAIWSWriteDisconnectConn) Ping(context.Context) error { + return net.ErrClosed +} + +func (c *openAIWSWriteDisconnectConn) Close() error { + return nil +} + +type openAIWSWriteGenericFailConn struct{} + +func (c *openAIWSWriteGenericFailConn) WriteJSON(context.Context, any) error { + return errors.New("writer failed") +} + +func (c *openAIWSWriteGenericFailConn) ReadMessage(context.Context) ([]byte, error) { + return nil, errors.New("reader failed") +} + +func (c *openAIWSWriteGenericFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSWriteGenericFailConn) Close() error { + return nil +} + +func TestOpenAIWSIngressContextLease_IOErrorInvalidatesCachedConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstream := &openAIWSWriteDisconnectConn{} + ctxItem := &openAIWSIngressContext{ + id: "ctx_io_err", + accountID: 7, + ownerID: "owner_io", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctxItem, + ownerID: "owner_io", + } + lease.cachedConn = upstream + + err := lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.Error(t, err) + require.ErrorIs(t, err, net.ErrClosed) + + lease.cachedConnMu.RLock() + cached := lease.cachedConn + lease.cachedConnMu.RUnlock() + require.Nil(t, cached, "write failure should invalidate cachedConn") + + ctxItem.mu.Lock() + require.True(t, ctxItem.broken, "disconnect-style IO failure should mark context broken") + require.Nil(t, ctxItem.upstream, "broken context should drop upstream reference") + ctxItem.mu.Unlock() +} + +func TestOpenAIWSIngressContextLease_GenericIOErrorKeepsContextButInvalidatesCache(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool 
:= newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + upstream := &openAIWSWriteGenericFailConn{} + ctxItem := &openAIWSIngressContext{ + id: "ctx_generic_err", + accountID: 8, + ownerID: "owner_generic", + upstream: upstream, + } + lease := &openAIWSIngressContextLease{ + pool: pool, + context: ctxItem, + ownerID: "owner_generic", + } + lease.cachedConn = upstream + + err := lease.PingWithTimeout(time.Second) + require.Error(t, err) + + lease.cachedConnMu.RLock() + cached := lease.cachedConn + lease.cachedConnMu.RUnlock() + require.Nil(t, cached, "generic IO failure should still invalidate cachedConn") + + ctxItem.mu.Lock() + require.False(t, ctxItem.broken, "non-disconnect IO failure should not force-broken context") + require.Equal(t, upstream, ctxItem.upstream, "upstream should remain for non-disconnect errors") + ctxItem.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_EnsureContextUpstream_SerializesConcurrentDial(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + releaseDial := make(chan struct{}) + blockingDialer := &openAIWSBlockingDialer{ + release: releaseDial, + dialStarted: make(chan struct{}, 4), + } + pool.setClientDialerForTest(blockingDialer) + + account := &Account{ID: 1301, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + type acquireResult struct { + lease *openAIWSIngressContextLease + err error + } + resultCh := make(chan acquireResult, 2) + acquireOnce := func() { + lease, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 23, + SessionHash: "session_same_owner", + OwnerID: "owner_same", + WSURL: "ws://test-upstream", + }) + resultCh <- acquireResult{lease: lease, err: err} + } + + go acquireOnce() + select { + case <-blockingDialer.dialStarted: + case <-time.After(500 * time.Millisecond): + t.Fatal("首个 dial 
未按预期启动") + } + go acquireOnce() + + select { + case <-blockingDialer.dialStarted: + t.Fatal("同一 context 并发 acquire 不应触发第二次 dial") + case <-time.After(120 * time.Millisecond): + } + + close(releaseDial) + + results := make([]acquireResult, 0, 2) + for i := 0; i < 2; i++ { + select { + case result := <-resultCh: + require.NoError(t, result.err) + require.NotNil(t, result.lease) + results = append(results, result) + case <-time.After(2 * time.Second): + t.Fatal("等待并发 acquire 结果超时") + } + } + + for _, result := range results { + result.lease.Release() + } + require.Equal(t, 1, blockingDialer.DialCount(), "同一 context 并发获取应只发生一次上游拨号") +} + +func TestOpenAIWSIngressContextPool_EnsureContextUpstream_WaiterTimeoutDoesNotReleaseOwner(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + releaseDial := make(chan struct{}) + blockingDialer := &openAIWSBlockingDialer{ + release: releaseDial, + dialStarted: make(chan struct{}, 4), + } + pool.setClientDialerForTest(blockingDialer) + + account := &Account{ID: 1302, Concurrency: 1} + baseReq := openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 24, + SessionHash: "session_waiter_timeout", + OwnerID: "owner_same", + WSURL: "ws://test-upstream", + } + + longCtx, longCancel := context.WithTimeout(context.Background(), 3*time.Second) + defer longCancel() + type acquireResult struct { + lease *openAIWSIngressContextLease + err error + } + firstResultCh := make(chan acquireResult, 1) + go func() { + lease, err := pool.Acquire(longCtx, baseReq) + firstResultCh <- acquireResult{lease: lease, err: err} + }() + + select { + case <-blockingDialer.dialStarted: + case <-time.After(500 * time.Millisecond): + t.Fatal("首个 dial 未按预期启动") + } + + shortCtx, shortCancel := context.WithTimeout(context.Background(), 60*time.Millisecond) + defer shortCancel() + _, waiterErr := pool.Acquire(shortCtx, baseReq) + 
require.ErrorIs(t, waiterErr, context.DeadlineExceeded, "等待中的 acquire 超时应返回 context deadline exceeded") + + close(releaseDial) + + select { + case first := <-firstResultCh: + require.NoError(t, first.err) + require.NotNil(t, first.lease) + require.NoError(t, first.lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "ping"}, time.Second), "等待方超时不应释放已建连 owner") + first.lease.Release() + case <-time.After(2 * time.Second): + t.Fatal("等待首个 acquire 结果超时") + } + + require.Equal(t, 1, blockingDialer.DialCount()) +} + +func TestOpenAIWSIngressContextPool_Acquire_QueueWaitDurationRecorded(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 1303, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 25, + SessionHash: "session_queue_wait", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + + type acquireResult struct { + lease *openAIWSIngressContextLease + err error + } + waiterCh := make(chan acquireResult, 1) + go func() { + lease, acquireErr := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 25, + SessionHash: "session_queue_wait", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + waiterCh <- acquireResult{lease: lease, err: acquireErr} + }() + + time.Sleep(120 * time.Millisecond) + lease1.Release() + + select { + case waiter := <-waiterCh: + require.NoError(t, waiter.err) + require.NotNil(t, waiter.lease) + require.GreaterOrEqual(t, waiter.lease.QueueWaitDuration(), 
100*time.Millisecond) + waiter.lease.Release() + case <-time.After(2 * time.Second): + t.Fatal("等待第二个 acquire 结果超时") + } +} + +type openAIWSBlockingDialer struct { + mu sync.Mutex + release <-chan struct{} + dialStarted chan struct{} + dialCount int +} + +func (d *openAIWSBlockingDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = wsURL + _ = headers + _ = proxyURL + if ctx == nil { + ctx = context.Background() + } + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + select { + case d.dialStarted <- struct{}{}: + default: + } + if d.release != nil { + select { + case <-d.release: + case <-ctx.Done(): + return nil, 0, nil, ctx.Err() + } + } + return &openAIWSCaptureConn{}, 0, nil, nil +} + +func (d *openAIWSBlockingDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +// --------------------------------------------------------------------------- +// Load-aware migration scoring tests +// --------------------------------------------------------------------------- + +func TestScoreOpenAIWSIngressMigrationCandidate_HighErrorRatePenalty(t *testing.T) { + now := time.Now() + stats := newOpenAIAccountRuntimeStats() + accountID := int64(9001) + + // Report a pattern of failures interspersed with occasional successes. + // This pushes the error rate high without tripping the circuit breaker + // (consecutive failures stay below the default threshold of 5). 
+ for i := 0; i < 6; i++ { + stats.report(accountID, false, nil, "", 0) + stats.report(accountID, false, nil, "", 0) + stats.report(accountID, false, nil, "", 0) + stats.report(accountID, true, nil, "", 0) // reset consecutive fail counter + } + require.False(t, stats.isCircuitOpen(accountID), "circuit breaker should remain closed for this test") + + ctx := &openAIWSIngressContext{accountID: accountID} + ctx.setLastUsedAt(now.Add(-5 * time.Minute)) + + scoreWithStats, _, okStats := scoreOpenAIWSIngressMigrationCandidate(ctx, now, stats) + require.True(t, okStats) + + // Score the same context without stats (nil) for comparison. + scoreWithout, _, okWithout := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil) + require.True(t, okWithout) + + require.Less(t, scoreWithStats, scoreWithout, + "a context on a high-error-rate account should receive a lower migration score") + + // The error rate penalty should be approximately errorRate * 30. + // Since the circuit breaker is not open, the only load-aware penalty is + // errorRate * 30. + errorRate, _, _ := stats.snapshot(accountID) + expectedPenalty := errorRate * 30 + require.InDelta(t, expectedPenalty, scoreWithout-scoreWithStats, 1.0, + "penalty should be approximately errorRate * 30") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_CircuitOpenPenalty(t *testing.T) { + now := time.Now() + stats := newOpenAIAccountRuntimeStats() + accountID := int64(9002) + + // Trip the circuit breaker by reporting consecutive failures beyond the + // default threshold (5). 
+ for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ { + stats.report(accountID, false, nil, "", 0) + } + require.True(t, stats.isCircuitOpen(accountID), "circuit breaker should be open after many failures") + + ctx := &openAIWSIngressContext{accountID: accountID} + ctx.setLastUsedAt(now.Add(-5 * time.Minute)) + + scoreCircuitOpen, _, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, stats) + require.True(t, ok) + + // Score without stats for comparison. + scoreBaseline, _, okBase := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil) + require.True(t, okBase) + + // The circuit-open penalty is -50, plus errorRate*30, so the score should + // be substantially lower. + require.Less(t, scoreCircuitOpen, scoreBaseline-45, + "a context on a circuit-open account should have a very low migration score") + + // In practice, the combined penalty should bring the score below any + // reasonable minimum migration threshold (weakest = 40). + _, weakMin := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessWeak) + require.Less(t, scoreCircuitOpen, weakMin, + "circuit-open accounts should score below even the weakest migration threshold") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_NilStatsFallback(t *testing.T) { + now := time.Now() + + ctx := &openAIWSIngressContext{accountID: 9003} + ctx.setLastUsedAt(now.Add(-5 * time.Minute)) + + scoreNil, _, okNil := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil) + require.True(t, okNil) + + // Create stats but report nothing for this account -- snapshot returns 0. + emptyStats := newOpenAIAccountRuntimeStats() + scoreEmpty, _, okEmpty := scoreOpenAIWSIngressMigrationCandidate(ctx, now, emptyStats) + require.True(t, okEmpty) + + // With no data for the account, the load-aware path should add zero + // penalty, yielding the same score as nil stats. 
+ require.InDelta(t, scoreNil, scoreEmpty, 0.001, + "when scheduler stats have no data for the account, score should match nil-stats baseline") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_NilContext(t *testing.T) { + now := time.Now() + score, _, ok := scoreOpenAIWSIngressMigrationCandidate(nil, now, nil) + require.False(t, ok) + require.Equal(t, 0.0, score) +} + +func TestScoreOpenAIWSIngressMigrationCandidate_IdleDurationBranches(t *testing.T) { + now := time.Now() + + // Very recently used (≤15s): penalty of -15 + recentCtx := &openAIWSIngressContext{} + recentCtx.setLastUsedAt(now.Add(-5 * time.Second)) + scoreRecent, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 100.0-15.0, scoreRecent, 0.5, "very recently used should get -15 penalty") + + // Medium idle (between 15s and 3min): bonus = idleDuration.Seconds() / 12 + mediumCtx := &openAIWSIngressContext{} + mediumCtx.setLastUsedAt(now.Add(-90 * time.Second)) // 90s idle + scoreMedium, _, ok := scoreOpenAIWSIngressMigrationCandidate(mediumCtx, now, nil) + require.True(t, ok) + expectedBonus := 90.0 / 12.0 // 7.5 + require.InDelta(t, 100.0+expectedBonus, scoreMedium, 0.5, "medium idle should get proportional bonus") + + // Long idle (≥3min): bonus of +16 + longCtx := &openAIWSIngressContext{} + longCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreLong, _, ok := scoreOpenAIWSIngressMigrationCandidate(longCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 100.0+16.0, scoreLong, 0.5, "long idle should get +16 bonus") + + // Verify ordering: long > medium > recent + require.Greater(t, scoreLong, scoreMedium, "longer idle should score higher than medium") + require.Greater(t, scoreMedium, scoreRecent, "medium idle should score higher than very recent") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_BrokenAndFailures(t *testing.T) { + now := time.Now() + + // Broken context: -30 + brokenCtx := &openAIWSIngressContext{broken: 
true} + brokenCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreBroken, _, ok := scoreOpenAIWSIngressMigrationCandidate(brokenCtx, now, nil) + require.True(t, ok) + + cleanCtx := &openAIWSIngressContext{} + cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreClean, _, ok := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 30.0, scoreClean-scoreBroken, 0.5, "broken should subtract 30") + + // High failure streak (capped at 40) + highFailCtx := &openAIWSIngressContext{failureStreak: 5} + highFailCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreHighFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(highFailCtx, now, nil) + require.True(t, ok) + // 5*12=60 but capped at 40 + require.InDelta(t, 40.0, scoreClean-scoreHighFail, 0.5, "failure streak penalty should cap at 40") + + // Recent failure (within 2 min): -18 + recentFailCtx := &openAIWSIngressContext{lastFailureAt: now.Add(-30 * time.Second)} + recentFailCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreRecentFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentFailCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 18.0, scoreClean-scoreRecentFail, 0.5, "recent failure should subtract 18") + + // Old failure (>2 min): no penalty + oldFailCtx := &openAIWSIngressContext{lastFailureAt: now.Add(-5 * time.Minute)} + oldFailCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreOldFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(oldFailCtx, now, nil) + require.True(t, ok) + require.InDelta(t, scoreClean, scoreOldFail, 0.5, "old failure should have no penalty") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_MigrationPenalties(t *testing.T) { + now := time.Now() + + cleanCtx := &openAIWSIngressContext{} + cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreClean, _, _ := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil) + + // Recent migration (within 1 min): -10 + recentMigCtx := 
&openAIWSIngressContext{lastMigrationAt: now.Add(-30 * time.Second)} + recentMigCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreRecentMig, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentMigCtx, now, nil) + require.True(t, ok) + require.InDelta(t, 10.0, scoreClean-scoreRecentMig, 0.5, "recent migration should subtract 10") + + // High migration count (capped at 20) + highMigCtx := &openAIWSIngressContext{migrationCount: 6} + highMigCtx.setLastUsedAt(now.Add(-5 * time.Minute)) + scoreHighMig, _, ok := scoreOpenAIWSIngressMigrationCandidate(highMigCtx, now, nil) + require.True(t, ok) + // 6*4=24 but capped at 20 + require.InDelta(t, 20.0, scoreClean-scoreHighMig, 0.5, "migration count penalty should cap at 20") +} + +func TestScoreOpenAIWSIngressMigrationCandidate_CombinedPenalties(t *testing.T) { + now := time.Now() + // All penalties combined: broken(-30) + failStreak 1*12(-12) + recentFail(-18) + recentMig(-10) + migCount 1*4(-4) + recentIdle(-15) + worstCtx := &openAIWSIngressContext{ + broken: true, + failureStreak: 1, + lastFailureAt: now.Add(-30 * time.Second), + migrationCount: 1, + lastMigrationAt: now.Add(-30 * time.Second), + } + worstCtx.setLastUsedAt(now.Add(-5 * time.Second)) + score, _, ok := scoreOpenAIWSIngressMigrationCandidate(worstCtx, now, nil) + require.True(t, ok) + expected := 100.0 - 30.0 - 12.0 - 18.0 - 10.0 - 4.0 - 15.0 // = 11.0 + require.InDelta(t, expected, score, 0.5, "all penalties should stack correctly") +} + +func TestOpenAIWSIngressContextPool_MigrationBlockedByCircuitBreaker(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + stats := newOpenAIAccountRuntimeStats() + pool.schedulerStats = stats + + accountID := int64(9004) + + // Trip circuit breaker for this account. 
+ for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ { + stats.report(accountID, false, nil, "", 0) + } + require.True(t, stats.isCircuitOpen(accountID)) + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: accountID, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // Acquire the only slot. + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 30, + SessionHash: "session_cb_a", + OwnerID: "owner_a", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + // Now try a different session -- migration should fail because the only + // candidate context is on a circuit-open account, whose score will be + // below the minimum threshold. + _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 30, + SessionHash: "session_cb_b", + OwnerID: "owner_b", + WSURL: "ws://test-upstream", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, + "migration to a circuit-open account should be blocked") +} + +// ---------- 连接生命周期管理(超龄轮换)测试 ---------- + +func TestOpenAIWSIngressContext_UpstreamConnAge_ZeroValue(t *testing.T) { + ctx := &openAIWSIngressContext{} + // 未设置 createdAt 时,connAge 应返回 0 + require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now())) +} + +func TestOpenAIWSIngressContext_UpstreamConnAge_Normal(t *testing.T) { + ctx := &openAIWSIngressContext{} + past := time.Now().Add(-10 * time.Minute) + ctx.upstreamConnCreatedAt.Store(past.UnixNano()) + age := ctx.upstreamConnAge(time.Now()) + require.True(t, age >= 10*time.Minute-time.Second, "connAge 应约为 10 分钟,实际=%v", age) + require.True(t, age < 11*time.Minute, "connAge 不应过大,实际=%v", age) +} + +func TestOpenAIWSIngressContext_UpstreamConnAge_NilSafe(t *testing.T) { + var ctx 
*openAIWSIngressContext + require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now())) +} + +func TestOpenAIWSIngressContext_UpstreamConnAge_ClockSkew(t *testing.T) { + ctx := &openAIWSIngressContext{} + future := time.Now().Add(10 * time.Minute) + ctx.upstreamConnCreatedAt.Store(future.UnixNano()) + // now 早于 createdAt(时钟回拨),应返回 0 + require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now())) +} + +func TestNewOpenAIWSIngressContextPool_UpstreamMaxAge_ZeroDisablesRotation(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds = 0 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + require.Equal(t, time.Duration(0), pool.upstreamMaxAge) +} + +func TestOpenAIWSIngressContextPool_EnsureUpstream_MaxAgeRotate(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 1 * time.Second // 设置极短的 maxAge 以便测试 + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + conn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1, conn2}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 901, Concurrency: 2} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 第一次 Acquire:建连 + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_age", + OwnerID: "owner_age_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + require.Equal(t, 1, dialer.DialCount(), "首次 Acquire 应 dial 一次") + + // Yield 保留连接 + lease1.Yield() + + // 等待超过 maxAge + time.Sleep(1200 * time.Millisecond) + + // 第二次 Acquire:应触发超龄轮换 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_age", + OwnerID: "owner_age_2", + WSURL: "ws://test-upstream", + }) 
+ require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, 2, dialer.DialCount(), "超龄轮换应触发重新 dial") + require.True(t, conn1.closed, "旧连接应被关闭") + lease2.Release() +} + +func TestOpenAIWSIngressContextPool_EnsureUpstream_YoungConnNotRotated(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 10 * time.Minute // 远大于测试时间 + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 902, Concurrency: 2} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_young", + OwnerID: "owner_young_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + lease1.Yield() + + // 立即重新 Acquire:连接年轻,不应轮换 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_young", + OwnerID: "owner_young_2", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, 1, dialer.DialCount(), "年轻连接不应触发重新 dial") + require.True(t, lease2.Reused(), "年轻连接应复用") + require.False(t, conn1.closed, "年轻连接不应被关闭") + lease2.Release() +} + +func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_ClosesAgedIdle(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 1 * time.Second + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + + ctx := &openAIWSIngressContext{ + id: "ctx_aged_1", + accountID: 903, + 
upstream: conn1, + } + // 设 createdAt 为 2 秒前 + ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano()) + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_aged_1"] = ctx + + now := time.Now() + ap.mu.Lock() + toClose := pool.closeAgedIdleUpstreamsLocked(ap, now) + ap.mu.Unlock() + + require.Len(t, toClose, 1, "应关闭超龄空闲连接") + closeOpenAIWSClientConns(toClose) + require.True(t, conn1.closed) + + // upstream 应已清空 + ctx.mu.Lock() + require.Nil(t, ctx.upstream) + require.Equal(t, int64(0), ctx.upstreamConnCreatedAt.Load()) + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsOwnedContext(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 1 * time.Second + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + + ctx := &openAIWSIngressContext{ + id: "ctx_owned_1", + accountID: 904, + ownerID: "active_owner", + upstream: conn1, + } + ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano()) + ap.contexts["ctx_owned_1"] = ctx + + now := time.Now() + ap.mu.Lock() + toClose := pool.closeAgedIdleUpstreamsLocked(ap, now) + ap.mu.Unlock() + + require.Len(t, toClose, 0, "有 owner 的超龄连接不应被关闭") + require.False(t, conn1.closed) +} + +func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsYoungConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 10 * time.Minute + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + + ctx := &openAIWSIngressContext{ + id: "ctx_young_1", + accountID: 905, + 
upstream: conn1, + } + ctx.upstreamConnCreatedAt.Store(time.Now().UnixNano()) + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_young_1"] = ctx + + now := time.Now() + ap.mu.Lock() + toClose := pool.closeAgedIdleUpstreamsLocked(ap, now) + ap.mu.Unlock() + + require.Len(t, toClose, 0, "年轻连接不应被关闭") + require.False(t, conn1.closed) +} + +func TestOpenAIWSIngressContextPool_E2E_AcquireYieldAgedReconnect(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + + pool := newOpenAIWSIngressContextPool(cfg) + pool.upstreamMaxAge = 55 * time.Minute // 使用实际默认值 + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + conn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1, conn2}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 906, Concurrency: 2} + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Acquire → Yield + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_e2e", + OwnerID: "owner_e2e_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Yield() + + // 手动设置 createdAt 为 56 分钟前以模拟超龄 + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + require.NotNil(t, ap) + + ap.mu.Lock() + for _, c := range ap.contexts { + c.upstreamConnCreatedAt.Store(time.Now().Add(-56 * time.Minute).UnixNano()) + } + ap.mu.Unlock() + + // 重新 Acquire:应检测到超龄并重连 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_e2e", + OwnerID: "owner_e2e_2", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease2) + require.Equal(t, 2, dialer.DialCount(), "超龄 56 分钟的连接应触发重连") + require.True(t, conn1.closed, "旧的超龄连接应被关闭") + lease2.Release() +} + +// 回归测试:容量满 + 存在过期 context 时,Acquire 仍能通过 evict 
腾出空间后正常分配。 +func TestOpenAIWSIngressContextPool_Acquire_EvictsExpiredWhenFull(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + conn2 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1, conn2}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 901, Concurrency: 1} + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // 获取一个 lease 占满容量 + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_old", + OwnerID: "owner_old", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + // 手动令该 context 过期 + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + ap.mu.Lock() + for _, c := range ap.contexts { + c.setExpiresAt(time.Now().Add(-2 * time.Second)) + } + ap.mu.Unlock() + + // 容量满(1个过期 context),新 session 的 Acquire 应通过 evict 成功 + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 3, + SessionHash: "session_new", + OwnerID: "owner_new", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err, "容量满但有过期 context 时 Acquire 应成功") + require.NotNil(t, lease2) + lease2.Release() +} + +// 回归测试:Acquire 找到已过期但仍在 bySession 映射中的 context 时,能正确取得所有权并刷新租约。 +func TestOpenAIWSIngressContextPool_Acquire_ReusesExpiredContextBySession(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60 + + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + conn1 := &openAIWSCaptureConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{conn1}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 902, Concurrency: 2} + ctx, cancel := 
context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + // 第一次获取 context + lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 4, + SessionHash: "session_reuse", + OwnerID: "owner_1", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + require.NotNil(t, lease1) + lease1.Release() + + // 令 context 过期(但不清理,模拟热路径不再清理的行为) + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + ap.mu.Lock() + for _, c := range ap.contexts { + c.setExpiresAt(time.Now().Add(-1 * time.Second)) + } + ap.mu.Unlock() + + // 同 session 再次 Acquire,应能复用过期但未清理的 context + lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 4, + SessionHash: "session_reuse", + OwnerID: "owner_2", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err, "过期但未清理的 context 应被同 session 的 Acquire 复用") + require.NotNil(t, lease2) + + // 验证租约已刷新(expiresAt 应在未来) + pool.mu.Lock() + ap2 := pool.accounts[account.ID] + pool.mu.Unlock() + ap2.mu.Lock() + for _, c := range ap2.contexts { + c.mu.Lock() + ea := c.expiresAt() + c.mu.Unlock() + require.True(t, ea.After(time.Now()), "复用后租约应被刷新到未来") + } + ap2.mu.Unlock() + + lease2.Release() +} + +// === P3: 后台主动 Ping 检测测试 === + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_MarksBrokenOnPingFailure(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_fail_1", + accountID: 2001, + upstream: failConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_fail_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + 
ctx.mu.Unlock() + require.True(t, broken, "Ping 失败应标记 context 为 broken") + require.Equal(t, 1, streak, "Ping 失败应增加 failureStreak") +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsOwnedContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_owned_1", + accountID: 2002, + ownerID: "active_owner", + upstream: failConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_owned_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + broken := ctx.broken + ctx.mu.Unlock() + require.False(t, broken, "有 owner 的 context 不应被 Ping 探测") +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsBrokenContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_broken_1", + accountID: 2003, + upstream: failConn, + broken: true, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_broken_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + streak := ctx.failureStreak + ctx.mu.Unlock() + require.Equal(t, 0, streak, "已 broken 的 context 不应被再次 Ping") +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_HealthyConnStaysHealthy(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + healthyConn := 
&openAIWSCaptureConn{} + ap := &openAIWSIngressAccountPool{ + contexts: make(map[string]*openAIWSIngressContext), + bySession: make(map[string]string), + } + ctx := &openAIWSIngressContext{ + id: "ctx_ping_healthy_1", + accountID: 2004, + upstream: healthyConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_ping_healthy_1"] = ctx + + pool.pingIdleUpstreams(ap) + + ctx.mu.Lock() + broken := ctx.broken + hasUpstream := ctx.upstream != nil + ctx.mu.Unlock() + require.False(t, broken, "Ping 成功的 context 不应被标记 broken") + require.True(t, hasUpstream, "Ping 成功的 upstream 应保持") +} + +func TestOpenAIWSIngressContextPool_SweepTriggersPingOnIdleContexts(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + healthyConn := &openAIWSCaptureConn{} + + pool.mu.Lock() + ap := pool.getOrCreateAccountPoolLocked(3001) + pool.mu.Unlock() + + ap.mu.Lock() + ctxFail := &openAIWSIngressContext{ + id: "ctx_sweep_ping_fail", + accountID: 3001, + upstream: failConn, + } + ctxFail.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_sweep_ping_fail"] = ctxFail + + ctxOk := &openAIWSIngressContext{ + id: "ctx_sweep_ping_ok", + accountID: 3001, + upstream: healthyConn, + } + ctxOk.touchLease(time.Now(), pool.idleTTL) + ap.contexts["ctx_sweep_ping_ok"] = ctxOk + ap.dynamicCap.Store(2) + ap.mu.Unlock() + + // 手动触发 sweep + pool.sweepExpiredIdleContexts() + + ctxFail.mu.Lock() + failBroken := ctxFail.broken + ctxFail.mu.Unlock() + require.True(t, failBroken, "sweep 后 Ping 失败的空闲 context 应被标记 broken") + + ctxOk.mu.Lock() + okBroken := ctxOk.broken + ctxOk.mu.Unlock() + require.False(t, okBroken, "sweep 后 Ping 成功的空闲 context 应保持健康") +} + +func TestOpenAIWSIngressContextPool_YieldSchedulesDelayedPing(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := 
newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{failConn}, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 4001, Concurrency: 2} + bCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + lease, err := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{ + Account: account, + GroupID: 1, + SessionHash: "session_yield_ping", + OwnerID: "owner_yield_ping", + WSURL: "ws://test-upstream", + }) + require.NoError(t, err) + + ingressCtx := lease.context + // Yield 触发延迟 Ping + lease.Yield() + + // 等待延迟 Ping 执行完毕(默认 5s + 余量) + time.Sleep(6 * time.Second) + + ingressCtx.mu.Lock() + broken := ingressCtx.broken + ingressCtx.mu.Unlock() + require.True(t, broken, "Yield 后延迟 Ping 应检测到 PingFailConn 并标记 broken") +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_CancelledOnPoolClose(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_delayed_cancel_1", + accountID: 5001, + upstream: failConn, + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 安排 10s 延迟(远大于测试等待时间) + pool.scheduleDelayedPing(ctx, 10*time.Second) + + // 立刻关闭 pool,应取消延迟 Ping + pool.Close() + time.Sleep(200 * time.Millisecond) + + ctx.mu.Lock() + broken := ctx.broken + ctx.mu.Unlock() + require.False(t, broken, "pool 关闭后延迟 Ping 不应执行") +} + +// === effectiveDynamicCapacity 边界测试 === + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NilAccountPool(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + require.Equal(t, 4, pool.effectiveDynamicCapacity(nil, 4), "ap==nil 时应返回 hardCap") + require.Equal(t, 0, pool.effectiveDynamicCapacity(nil, 0), "ap==nil && hardCap==0") +} + +func 
TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_ZeroHardCap(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(3) + require.Equal(t, 0, pool.effectiveDynamicCapacity(ap, 0), "hardCap<=0 应返回 hardCap") + require.Equal(t, -1, pool.effectiveDynamicCapacity(ap, -1), "hardCap<0 应返回 hardCap") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapBelowOne(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(0) // 异常值 + result := pool.effectiveDynamicCapacity(ap, 4) + require.Equal(t, 1, result, "dynCap<1 应自动修复为 1") + require.Equal(t, int32(1), ap.dynamicCap.Load(), "dynCap 应被修复为 1") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapExceedsHardCap(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(10) + require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap>hardCap 应 clamp 到 hardCap") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapEqualsHardCap(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(4) + require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap==hardCap 应返回 hardCap") +} + +func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NormalPath(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{} + ap := &openAIWSIngressAccountPool{} + ap.dynamicCap.Store(2) + require.Equal(t, 2, pool.effectiveDynamicCapacity(ap, 8), "正常 dynCap= 2, "第二次 Acquire 应触发 dynamicCap 增长") + + lease1.Release() + lease2.Release() +} + +func TestOpenAIWSIngressContextPool_Sweeper_ShrinksDynamicCap(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1 // 1 秒 TTL 让 context 快速过期 + pool := 
newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{ + &openAIWSCaptureConn{}, &openAIWSCaptureConn{}, &openAIWSCaptureConn{}, + }, + } + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 6002, Concurrency: 4} + bCtx := context.Background() + + // 创建两个 context + lease1, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{ + Account: account, GroupID: 1, SessionHash: "s1", OwnerID: "o1", WSURL: "ws://t", + }) + lease2, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{ + Account: account, GroupID: 1, SessionHash: "s2", OwnerID: "o2", WSURL: "ws://t", + }) + lease1.Release() + lease2.Release() + + pool.mu.Lock() + ap := pool.accounts[account.ID] + pool.mu.Unlock() + require.True(t, ap.dynamicCap.Load() >= 2) + + // 等待 context 过期 + time.Sleep(2 * time.Second) + + // 手动 sweep + pool.sweepExpiredIdleContexts() + + // sweep 后 dynamicCap 应缩减 + ap.mu.Lock() + ctxCount := len(ap.contexts) + ap.mu.Unlock() + dynCap := ap.dynamicCap.Load() + // 如果所有 context 都被 evict,dynamicCap 应缩到 1(min) + if ctxCount == 0 { + require.Equal(t, int32(1), dynCap, "context 全部 evict 后 dynamicCap 应缩至 1") + } else { + require.LessOrEqual(t, dynCap, int32(ctxCount), "dynamicCap 应缩至当前 context 数") + } +} + +// === Ping 额外边界测试 === + +func TestOpenAIWSIngressContextPool_PingContextUpstream_NilPoolOrContext(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})} + // nil context 不应 panic + pool.pingContextUpstream(nil) + // nil pool 不应 panic + var nilPool *openAIWSIngressContextPool + nilPool.pingContextUpstream(&openAIWSIngressContext{}) +} + +func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsDialingContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_dialing_1", accountID: 7001, + 
upstream: failConn, dialing: true, + } + pool.pingContextUpstream(ctx) + ctx.mu.Lock() + require.False(t, ctx.broken, "dialing 中的 context 不应被 Ping") + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsNoUpstream(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ctx := &openAIWSIngressContext{ + id: "ctx_no_upstream", accountID: 7002, upstream: nil, + } + pool.pingContextUpstream(ctx) + ctx.mu.Lock() + require.False(t, ctx.broken, "无 upstream 的 context 不应被 Ping") + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_NilPoolOrAP(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})} + pool.pingIdleUpstreams(nil) + var nilPool *openAIWSIngressContextPool + nilPool.pingIdleUpstreams(&openAIWSIngressAccountPool{}) +} + +func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsNilContext(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + ap := &openAIWSIngressAccountPool{ + contexts: map[string]*openAIWSIngressContext{"nil_ctx": nil}, + bySession: make(map[string]string), + } + // 不应 panic + pool.pingIdleUpstreams(ap) +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ZeroDelay(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_zero_delay", accountID: 8001, upstream: failConn, + } + // delay <= 0 应为 no-op + pool.scheduleDelayedPing(ctx, 0) + pool.scheduleDelayedPing(ctx, -1*time.Second) + time.Sleep(100 * time.Millisecond) + ctx.mu.Lock() + require.False(t, ctx.broken, "delay<=0 不应触发 Ping") + ctx.mu.Unlock() +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_NilParams(t *testing.T) { + t.Parallel() + pool := &openAIWSIngressContextPool{stopCh: 
make(chan struct{})} + // nil context + pool.scheduleDelayedPing(nil, 5*time.Second) + // nil pool + var nilPool *openAIWSIngressContextPool + nilPool.scheduleDelayedPing(&openAIWSIngressContext{}, 5*time.Second) +} + +// === P1 并发回归:旧连接 Ping 失败不应误杀新连接 === + +func TestOpenAIWSIngressContextPool_PingFailDoesNotKillNewConn(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + // 旧连接:Ping 带 200ms 延迟后失败 + oldConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond) + // 新连接:正常的 Ping + newConn := &openAIWSCaptureConn{} + + ctx := &openAIWSIngressContext{ + id: "ctx_race_test", + accountID: 9001, + upstream: oldConn, + upstreamConnID: "old_conn_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 在后台对旧连接发起 Ping 探测 + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + pool.pingContextUpstream(ctx) + }() + + // 等待 Ping 开始执行 + <-oldConn.pingDone + + // 模拟连接重建:在 Ping 执行期间将 upstream 替换为新连接 + ctx.mu.Lock() + ctx.upstream = newConn + ctx.upstreamConnID = "new_conn_2" + ctx.broken = false + ctx.failureStreak = 0 + ctx.mu.Unlock() + + // 等待 Ping goroutine 完成 + wg.Wait() + + // 验证:新连接不应被标记为 broken + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + upstream := ctx.upstream + connID := ctx.upstreamConnID + ctx.mu.Unlock() + + require.False(t, broken, "新连接不应被旧 Ping 失败标记为 broken") + require.Equal(t, 0, streak, "failureStreak 不应增加") + require.Equal(t, newConn, upstream, "upstream 应仍是新连接") + require.Equal(t, "new_conn_2", connID, "upstreamConnID 应仍是新连接的 ID") + require.False(t, newConn.Closed(), "新连接不应被关闭") +} + +func TestOpenAIWSIngressContextPool_PingFailKillsSameConn(t *testing.T) { + // 对照测试:connID 未变时应正常标记 broken + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + 
ctx := &openAIWSIngressContext{ + id: "ctx_same_conn", + accountID: 9002, + upstream: failConn, + upstreamConnID: "conn_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + pool.pingContextUpstream(ctx) + + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + connID := ctx.upstreamConnID + ctx.mu.Unlock() + + require.True(t, broken, "同一连接 Ping 失败应标记 broken") + require.Equal(t, 1, streak, "failureStreak 应为 1") + require.Empty(t, connID, "upstreamConnID 应被清空") +} + +func TestOpenAIWSIngressContextPool_PingFailDoesNotKillBusyConn(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond) + ctx := &openAIWSIngressContext{ + id: "ctx_busy_conn", + accountID: 9005, + upstream: failConn, + upstreamConnID: "conn_busy_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + pool.pingContextUpstream(ctx) + }() + + <-failConn.pingDone + + ctx.mu.Lock() + ctx.ownerID = "active_owner" + ctx.mu.Unlock() + + wg.Wait() + + ctx.mu.Lock() + broken := ctx.broken + streak := ctx.failureStreak + upstream := ctx.upstream + connID := ctx.upstreamConnID + ownerID := ctx.ownerID + ctx.mu.Unlock() + + require.False(t, broken, "busy context 不应被后台 Ping 标记 broken") + require.Equal(t, 0, streak, "failureStreak 不应增加") + require.Equal(t, failConn, upstream, "busy context 的 upstream 不应被替换") + require.Equal(t, "conn_busy_1", connID, "busy context 的 connID 不应变化") + require.Equal(t, "active_owner", ownerID, "owner 应保持不变") + require.False(t, failConn.Closed(), "busy context 的连接不应被后台 Ping 关闭") +} + +// === P2a 去重:连续多次 Yield 仅触发一次延迟 Ping === + +func TestOpenAIWSIngressContextPool_ConsecutiveYieldsOnlyOnePing(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600 + pool := 
newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + // 使用 Ping 失败连接以便观察是否被标记 broken + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_yield_dedup", + accountID: 9003, + upstream: failConn, + upstreamConnID: "conn_dedup_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 连续调用 5 次 scheduleDelayedPing + for i := 0; i < 5; i++ { + pool.scheduleDelayedPing(ctx, 100*time.Millisecond) + } + + // 验证:只有一个 pendingPingTimer(不应堆积多个 goroutine) + ctx.mu.Lock() + hasPending := ctx.pendingPingTimer != nil + ctx.mu.Unlock() + require.True(t, hasPending, "应有一个 pendingPingTimer") + + // 等待 timer 到期并执行 Ping + time.Sleep(300 * time.Millisecond) + + // Ping 失败应标记 broken(证明延迟 Ping 确实执行了) + ctx.mu.Lock() + broken := ctx.broken + pendingTimer := ctx.pendingPingTimer + ctx.mu.Unlock() + + require.True(t, broken, "延迟 Ping 应已执行并标记 broken") + require.Nil(t, pendingTimer, "Ping 执行后 pendingPingTimer 应被清理") +} + +func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ResetExtendsDelay(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + pool := newOpenAIWSIngressContextPool(cfg) + defer pool.Close() + + failConn := &openAIWSPingFailConn{} + ctx := &openAIWSIngressContext{ + id: "ctx_reset_delay", + accountID: 9004, + upstream: failConn, + upstreamConnID: "conn_reset_1", + } + ctx.touchLease(time.Now(), pool.idleTTL) + + // 第一次调度 200ms 延迟 + pool.scheduleDelayedPing(ctx, 200*time.Millisecond) + + // 100ms 后再次调度 200ms(应 Reset timer,从此刻起再等 200ms) + time.Sleep(100 * time.Millisecond) + pool.scheduleDelayedPing(ctx, 200*time.Millisecond) + + // 150ms 后(距第一次 250ms,距 Reset 150ms)应未执行 + time.Sleep(150 * time.Millisecond) + ctx.mu.Lock() + broken := ctx.broken + ctx.mu.Unlock() + require.False(t, broken, "Reset 后 150ms 不应触发 Ping") + + // 再等 100ms(距 Reset 250ms)应已执行 + time.Sleep(100 * time.Millisecond) + ctx.mu.Lock() + broken = ctx.broken + ctx.mu.Unlock() + require.True(t, broken, "Reset 后 250ms 应已触发 Ping") +} diff --git 
a/backend/internal/service/openai_ws_ingress_normalizer.go b/backend/internal/service/openai_ws_ingress_normalizer.go new file mode 100644 index 000000000..74cd67374 --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_normalizer.go @@ -0,0 +1,42 @@ +package service + +type openAIWSIngressPreSendNormalizeInput struct { + accountID int64 + turn int + connID string + + currentPayload []byte + currentPayloadBytes int + currentPreviousResponseID string + expectedPreviousResponse string + pendingExpectedCallIDs []string +} + +type openAIWSIngressPreSendNormalizeOutput struct { + currentPayload []byte + currentPayloadBytes int + currentPreviousResponseID string + expectedPreviousResponseID string + pendingExpectedCallIDs []string + functionCallOutputCallIDs []string + hasFunctionCallOutputCallID bool +} + +// normalizeOpenAIWSIngressPayloadBeforeSend 纯透传 + callID 提取。 +// proxy 只负责转发、认证替换、计费,所有边缘场景由 recoverIngressPrevResponseNotFound 兜底。 +func normalizeOpenAIWSIngressPayloadBeforeSend(input openAIWSIngressPreSendNormalizeInput) openAIWSIngressPreSendNormalizeOutput { + _ = input.accountID + _ = input.turn + _ = input.connID + callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(input.currentPayload) + + return openAIWSIngressPreSendNormalizeOutput{ + currentPayload: input.currentPayload, + currentPayloadBytes: input.currentPayloadBytes, + currentPreviousResponseID: input.currentPreviousResponseID, + expectedPreviousResponseID: input.expectedPreviousResponse, + pendingExpectedCallIDs: input.pendingExpectedCallIDs, + functionCallOutputCallIDs: callIDs, + hasFunctionCallOutputCallID: len(callIDs) > 0, + } +} diff --git a/backend/internal/service/openai_ws_ingress_normalizer_test.go b/backend/internal/service/openai_ws_ingress_normalizer_test.go new file mode 100644 index 000000000..ff68f4b5b --- /dev/null +++ b/backend/internal/service/openai_ws_ingress_normalizer_test.go @@ -0,0 +1,193 @@ +package service + +import ( + "testing" + + 
"github.com/stretchr/testify/require" +) + +// === 纯透传 normalizer 行为测试 === + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_BasicPassthrough(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_prev", + "input":[{"type":"input_text","text":"hello"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 1, + turn: 2, + connID: "conn_1", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_prev", + expectedPreviousResponse: "resp_expected", + pendingExpectedCallIDs: []string{"call_1"}, + }) + + require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传") + require.Equal(t, len(payload), out.currentPayloadBytes) + require.Equal(t, "resp_prev", out.currentPreviousResponseID, "currentPreviousResponseID 应原样透传") + require.Equal(t, "resp_expected", out.expectedPreviousResponseID, "expectedPreviousResponseID 应原样透传") + require.Equal(t, []string{"call_1"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样透传") + require.False(t, out.hasFunctionCallOutputCallID, "无 FCO 时应为 false") + require.Empty(t, out.functionCallOutputCallIDs) +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ExtractsCallID(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_prev", + "input":[{"type":"function_call_output","call_id":"call_abc","output":"{}"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 2, + turn: 3, + connID: "conn_2", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_prev", + expectedPreviousResponse: "resp_prev", + }) + + require.True(t, out.hasFunctionCallOutputCallID, "有 FCO 时应为 true") + require.Equal(t, []string{"call_abc"}, out.functionCallOutputCallIDs, "应正确提取 call_id") 
+} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_NoFCO(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"input_text","text":"hello"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 3, + turn: 1, + connID: "conn_3", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "", + }) + + require.False(t, out.hasFunctionCallOutputCallID) + require.Empty(t, out.functionCallOutputCallIDs) +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_MultipleFCO(t *testing.T) { + t.Parallel() + + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_ok", + "input":[ + {"type":"function_call_output","call_id":"call_a","output":"{}"}, + {"type":"function_call_output","call_id":"call_b","output":"{}"}, + {"type":"function_call_output","call_id":"call_c","output":"{}"} + ] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 4, + turn: 2, + connID: "conn_4", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_ok", + expectedPreviousResponse: "resp_ok", + }) + + require.True(t, out.hasFunctionCallOutputCallID) + require.ElementsMatch(t, []string{"call_a", "call_b", "call_c"}, out.functionCallOutputCallIDs, "应提取所有 call_id") +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_EmptyInput(t *testing.T) { + t.Parallel() + + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 5, + turn: 1, + connID: "conn_5", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "", + }) + + require.JSONEq(t, string(payload), 
string(out.currentPayload), "空 input 不应 panic") + require.False(t, out.hasFunctionCallOutputCallID) + require.Empty(t, out.functionCallOutputCallIDs) +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ESCInterruptPassthrough(t *testing.T) { + t.Parallel() + + // 场景:ESC 中断后客户端有意不传 previous_response_id,有 pendingCallIDs。 + // 透传不补 prev、不注入 aborted output。 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "input":[{"type":"input_text","text":"new task after ESC"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 6, + turn: 5, + connID: "conn_esc", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "", + expectedPreviousResponse: "resp_prev_turn4", + pendingExpectedCallIDs: []string{"call_pending_1", "call_pending_2"}, + }) + + require.Empty(t, out.currentPreviousResponseID, "透传不应补 previous_response_id") + require.Equal(t, "resp_prev_turn4", out.expectedPreviousResponseID) + require.False(t, out.hasFunctionCallOutputCallID, "透传不应注入 function_call_output") + require.Empty(t, out.functionCallOutputCallIDs) + require.Equal(t, []string{"call_pending_1", "call_pending_2"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样传递") + require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传") +} + +func TestNormalizeOpenAIWSIngressPayloadBeforeSend_StalePrevPassthrough(t *testing.T) { + t.Parallel() + + // 场景:客户端传了过期 previous_response_id,透传不对齐。 + // 由下游 recoverIngressPrevResponseNotFound 处理。 + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "previous_response_id":"resp_stale", + "input":[{"type":"function_call_output","call_id":"call_1","output":"{}"}] + }`) + + out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{ + accountID: 7, + turn: 4, + connID: "conn_stale", + currentPayload: payload, + currentPayloadBytes: len(payload), + currentPreviousResponseID: "resp_stale", 
+ expectedPreviousResponse: "resp_latest", + }) + + require.Equal(t, "resp_stale", out.currentPreviousResponseID, "透传不应对齐 previous_response_id") + require.Equal(t, "resp_latest", out.expectedPreviousResponseID) + require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传") + require.True(t, out.hasFunctionCallOutputCallID) + require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs) +} diff --git a/backend/internal/service/openai_ws_passthrough_layout_test.go b/backend/internal/service/openai_ws_passthrough_layout_test.go new file mode 100644 index 000000000..0cf611faa --- /dev/null +++ b/backend/internal/service/openai_ws_passthrough_layout_test.go @@ -0,0 +1,39 @@ +package service + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSPassthroughDataPlaneLayout(t *testing.T) { + t.Parallel() + + serviceDir, err := os.Getwd() + require.NoError(t, err) + v2Dir := filepath.Join(serviceDir, "openai_ws_v2") + forwarderFile := filepath.Join(serviceDir, "openai_ws_forwarder.go") + + requiredFiles := []string{ + filepath.Join(v2Dir, "entry.go"), + filepath.Join(v2Dir, "caddy_adapter.go"), + filepath.Join(v2Dir, "passthrough_relay.go"), + } + for _, file := range requiredFiles { + info, err := os.Stat(file) + require.NoError(t, err) + require.False(t, info.IsDir()) + } + + content, err := os.ReadFile(forwarderFile) + require.NoError(t, err) + forwarder := string(content) + + // openai_ws_forwarder 允许分流入口,不承载 passthrough 数据面函数实现。 + require.Contains(t, forwarder, "proxyResponsesWebSocketV2Passthrough(") + require.NotContains(t, forwarder, "func runUpstreamToClient(") + require.NotContains(t, forwarder, "func observeUpstreamMessage(") + require.NotContains(t, forwarder, "func parseUsageAndAccumulate(") +} diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go deleted file mode 100644 index db6a96a7d..000000000 --- 
a/backend/internal/service/openai_ws_pool.go +++ /dev/null @@ -1,1706 +0,0 @@ -package service - -import ( - "context" - "errors" - "fmt" - "math" - "net/http" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "golang.org/x/sync/errgroup" -) - -const ( - openAIWSConnMaxAge = 60 * time.Minute - openAIWSConnHealthCheckIdle = 90 * time.Second - openAIWSConnHealthCheckTO = 2 * time.Second - openAIWSConnPrewarmExtraDelay = 2 * time.Second - openAIWSAcquireCleanupInterval = 3 * time.Second - openAIWSBackgroundPingInterval = 30 * time.Second - openAIWSBackgroundSweepTicker = 30 * time.Second - - openAIWSPrewarmFailureWindow = 30 * time.Second - openAIWSPrewarmFailureSuppress = 2 -) - -var ( - errOpenAIWSConnClosed = errors.New("openai ws connection closed") - errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full") - errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") -) - -type openAIWSDialError struct { - StatusCode int - ResponseHeaders http.Header - Err error -} - -func (e *openAIWSDialError) Error() string { - if e == nil { - return "" - } - if e.StatusCode > 0 { - return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err) - } - return fmt.Sprintf("openai ws dial failed: %v", e.Err) -} - -func (e *openAIWSDialError) Unwrap() error { - if e == nil { - return nil - } - return e.Err -} - -type openAIWSAcquireRequest struct { - Account *Account - WSURL string - Headers http.Header - ProxyURL string - PreferredConnID string - // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。 - ForceNewConn bool - // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。 - ForcePreferredConn bool -} - -type openAIWSConnLease struct { - pool *openAIWSConnPool - accountID int64 - conn *openAIWSConn - queueWait time.Duration - connPick time.Duration - reused bool - released atomic.Bool -} - -func (l *openAIWSConnLease) activeConn() 
(*openAIWSConn, error) { - if l == nil || l.conn == nil { - return nil, errOpenAIWSConnClosed - } - if l.released.Load() { - return nil, errOpenAIWSConnClosed - } - return l.conn, nil -} - -func (l *openAIWSConnLease) ConnID() string { - if l == nil || l.conn == nil { - return "" - } - return l.conn.id -} - -func (l *openAIWSConnLease) QueueWaitDuration() time.Duration { - if l == nil { - return 0 - } - return l.queueWait -} - -func (l *openAIWSConnLease) ConnPickDuration() time.Duration { - if l == nil { - return 0 - } - return l.connPick -} - -func (l *openAIWSConnLease) Reused() bool { - if l == nil { - return false - } - return l.reused -} - -func (l *openAIWSConnLease) HandshakeHeader(name string) string { - if l == nil || l.conn == nil { - return "" - } - return l.conn.handshakeHeader(name) -} - -func (l *openAIWSConnLease) IsPrewarmed() bool { - if l == nil || l.conn == nil { - return false - } - return l.conn.isPrewarmed() -} - -func (l *openAIWSConnLease) MarkPrewarmed() { - if l == nil || l.conn == nil { - return - } - l.conn.markPrewarmed() -} - -func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error { - conn, err := l.activeConn() - if err != nil { - return err - } - return conn.writeJSONWithTimeout(context.Background(), value, timeout) -} - -func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { - conn, err := l.activeConn() - if err != nil { - return err - } - return conn.writeJSONWithTimeout(ctx, value, timeout) -} - -func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error { - conn, err := l.activeConn() - if err != nil { - return err - } - return conn.writeJSON(value, ctx) -} - -func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) { - conn, err := l.activeConn() - if err != nil { - return nil, err - } - return conn.readMessageWithTimeout(timeout) -} - -func (l *openAIWSConnLease) ReadMessageContext(ctx 
context.Context) ([]byte, error) { - conn, err := l.activeConn() - if err != nil { - return nil, err - } - return conn.readMessage(ctx) -} - -func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { - conn, err := l.activeConn() - if err != nil { - return nil, err - } - return conn.readMessageWithContextTimeout(ctx, timeout) -} - -func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error { - conn, err := l.activeConn() - if err != nil { - return err - } - return conn.pingWithTimeout(timeout) -} - -func (l *openAIWSConnLease) MarkBroken() { - if l == nil || l.pool == nil || l.conn == nil || l.released.Load() { - return - } - l.pool.evictConn(l.accountID, l.conn.id) -} - -func (l *openAIWSConnLease) Release() { - if l == nil || l.conn == nil { - return - } - if !l.released.CompareAndSwap(false, true) { - return - } - l.conn.release() -} - -type openAIWSConn struct { - id string - ws openAIWSClientConn - - handshakeHeaders http.Header - - leaseCh chan struct{} - closedCh chan struct{} - closeOnce sync.Once - - readMu sync.Mutex - writeMu sync.Mutex - - waiters atomic.Int32 - createdAtNano atomic.Int64 - lastUsedNano atomic.Int64 - prewarmed atomic.Bool -} - -func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn { - now := time.Now() - conn := &openAIWSConn{ - id: id, - ws: ws, - handshakeHeaders: cloneHeader(handshakeHeaders), - leaseCh: make(chan struct{}, 1), - closedCh: make(chan struct{}), - } - conn.leaseCh <- struct{}{} - conn.createdAtNano.Store(now.UnixNano()) - conn.lastUsedNano.Store(now.UnixNano()) - return conn -} - -func (c *openAIWSConn) tryAcquire() bool { - if c == nil { - return false - } - select { - case <-c.closedCh: - return false - default: - } - select { - case <-c.leaseCh: - select { - case <-c.closedCh: - c.release() - return false - default: - } - return true - default: - return false - } -} - -func (c 
*openAIWSConn) acquire(ctx context.Context) error { - if c == nil { - return errOpenAIWSConnClosed - } - for { - select { - case <-ctx.Done(): - return ctx.Err() - case <-c.closedCh: - return errOpenAIWSConnClosed - case <-c.leaseCh: - select { - case <-c.closedCh: - c.release() - return errOpenAIWSConnClosed - default: - } - return nil - } - } -} - -func (c *openAIWSConn) release() { - if c == nil { - return - } - select { - case c.leaseCh <- struct{}{}: - default: - } - c.touch() -} - -func (c *openAIWSConn) close() { - if c == nil { - return - } - c.closeOnce.Do(func() { - close(c.closedCh) - if c.ws != nil { - _ = c.ws.Close() - } - select { - case c.leaseCh <- struct{}{}: - default: - } - }) -} - -func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error { - if c == nil { - return errOpenAIWSConnClosed - } - select { - case <-c.closedCh: - return errOpenAIWSConnClosed - default: - } - - writeCtx := parent - if writeCtx == nil { - writeCtx = context.Background() - } - if timeout <= 0 { - return c.writeJSON(value, writeCtx) - } - var cancel context.CancelFunc - writeCtx, cancel = context.WithTimeout(writeCtx, timeout) - defer cancel() - return c.writeJSON(value, writeCtx) -} - -func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error { - c.writeMu.Lock() - defer c.writeMu.Unlock() - if c.ws == nil { - return errOpenAIWSConnClosed - } - if writeCtx == nil { - writeCtx = context.Background() - } - if err := c.ws.WriteJSON(writeCtx, value); err != nil { - return err - } - c.touch() - return nil -} - -func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) { - return c.readMessageWithContextTimeout(context.Background(), timeout) -} - -func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) { - if c == nil { - return nil, errOpenAIWSConnClosed - } - select { - case <-c.closedCh: - return nil, 
errOpenAIWSConnClosed - default: - } - - if parent == nil { - parent = context.Background() - } - if timeout <= 0 { - return c.readMessage(parent) - } - readCtx, cancel := context.WithTimeout(parent, timeout) - defer cancel() - return c.readMessage(readCtx) -} - -func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) { - c.readMu.Lock() - defer c.readMu.Unlock() - if c.ws == nil { - return nil, errOpenAIWSConnClosed - } - if readCtx == nil { - readCtx = context.Background() - } - payload, err := c.ws.ReadMessage(readCtx) - if err != nil { - return nil, err - } - c.touch() - return payload, nil -} - -func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error { - if c == nil { - return errOpenAIWSConnClosed - } - select { - case <-c.closedCh: - return errOpenAIWSConnClosed - default: - } - - c.writeMu.Lock() - defer c.writeMu.Unlock() - if c.ws == nil { - return errOpenAIWSConnClosed - } - if timeout <= 0 { - timeout = openAIWSConnHealthCheckTO - } - pingCtx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - if err := c.ws.Ping(pingCtx); err != nil { - return err - } - return nil -} - -func (c *openAIWSConn) touch() { - if c == nil { - return - } - c.lastUsedNano.Store(time.Now().UnixNano()) -} - -func (c *openAIWSConn) createdAt() time.Time { - if c == nil { - return time.Time{} - } - nano := c.createdAtNano.Load() - if nano <= 0 { - return time.Time{} - } - return time.Unix(0, nano) -} - -func (c *openAIWSConn) lastUsedAt() time.Time { - if c == nil { - return time.Time{} - } - nano := c.lastUsedNano.Load() - if nano <= 0 { - return time.Time{} - } - return time.Unix(0, nano) -} - -func (c *openAIWSConn) idleDuration(now time.Time) time.Duration { - if c == nil { - return 0 - } - last := c.lastUsedAt() - if last.IsZero() { - return 0 - } - return now.Sub(last) -} - -func (c *openAIWSConn) age(now time.Time) time.Duration { - if c == nil { - return 0 - } - created := c.createdAt() - if created.IsZero() { - 
return 0 - } - return now.Sub(created) -} - -func (c *openAIWSConn) isLeased() bool { - if c == nil { - return false - } - return len(c.leaseCh) == 0 -} - -func (c *openAIWSConn) handshakeHeader(name string) string { - if c == nil || c.handshakeHeaders == nil { - return "" - } - return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name))) -} - -func (c *openAIWSConn) isPrewarmed() bool { - if c == nil { - return false - } - return c.prewarmed.Load() -} - -func (c *openAIWSConn) markPrewarmed() { - if c == nil { - return - } - c.prewarmed.Store(true) -} - -type openAIWSAccountPool struct { - mu sync.Mutex - conns map[string]*openAIWSConn - pinnedConns map[string]int - creating int - lastCleanupAt time.Time - lastAcquire *openAIWSAcquireRequest - prewarmActive bool - prewarmUntil time.Time - prewarmFails int - prewarmFailAt time.Time -} - -type OpenAIWSPoolMetricsSnapshot struct { - AcquireTotal int64 - AcquireReuseTotal int64 - AcquireCreateTotal int64 - AcquireQueueWaitTotal int64 - AcquireQueueWaitMsTotal int64 - ConnPickTotal int64 - ConnPickMsTotal int64 - ScaleUpTotal int64 - ScaleDownTotal int64 -} - -type openAIWSPoolMetrics struct { - acquireTotal atomic.Int64 - acquireReuseTotal atomic.Int64 - acquireCreateTotal atomic.Int64 - acquireQueueWaitTotal atomic.Int64 - acquireQueueWaitMs atomic.Int64 - connPickTotal atomic.Int64 - connPickMs atomic.Int64 - scaleUpTotal atomic.Int64 - scaleDownTotal atomic.Int64 -} - -type openAIWSConnPool struct { - cfg *config.Config - // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。 - clientDialer openAIWSClientDialer - - accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool - seq atomic.Uint64 - - metrics openAIWSPoolMetrics - - workerStopCh chan struct{} - workerWg sync.WaitGroup - closeOnce sync.Once -} - -func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool { - pool := &openAIWSConnPool{ - cfg: cfg, - clientDialer: newDefaultOpenAIWSClientDialer(), - workerStopCh: make(chan struct{}), - } 
- pool.startBackgroundWorkers() - return pool -} - -func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot { - if p == nil { - return OpenAIWSPoolMetricsSnapshot{} - } - return OpenAIWSPoolMetricsSnapshot{ - AcquireTotal: p.metrics.acquireTotal.Load(), - AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(), - AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(), - AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(), - AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(), - ConnPickTotal: p.metrics.connPickTotal.Load(), - ConnPickMsTotal: p.metrics.connPickMs.Load(), - ScaleUpTotal: p.metrics.scaleUpTotal.Load(), - ScaleDownTotal: p.metrics.scaleDownTotal.Load(), - } -} - -func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { - if p == nil { - return OpenAIWSTransportMetricsSnapshot{} - } - if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok { - return dialer.SnapshotTransportMetrics() - } - return OpenAIWSTransportMetricsSnapshot{} -} - -func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) { - if p == nil || dialer == nil { - return - } - p.clientDialer = dialer -} - -// Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。 -func (p *openAIWSConnPool) Close() { - if p == nil { - return - } - p.closeOnce.Do(func() { - if p.workerStopCh != nil { - close(p.workerStopCh) - } - p.workerWg.Wait() - // 遍历所有账户池,关闭全部空闲连接。 - p.accounts.Range(func(key, value any) bool { - ap, ok := value.(*openAIWSAccountPool) - if !ok || ap == nil { - return true - } - ap.mu.Lock() - for _, conn := range ap.conns { - if conn != nil && !conn.isLeased() { - conn.close() - } - } - ap.mu.Unlock() - return true - }) - }) -} - -func (p *openAIWSConnPool) startBackgroundWorkers() { - if p == nil || p.workerStopCh == nil { - return - } - p.workerWg.Add(2) - go func() { - defer p.workerWg.Done() - p.runBackgroundPingWorker() - }() - go func() { - defer p.workerWg.Done() - 
p.runBackgroundCleanupWorker() - }() -} - -type openAIWSIdlePingCandidate struct { - accountID int64 - conn *openAIWSConn -} - -func (p *openAIWSConnPool) runBackgroundPingWorker() { - if p == nil { - return - } - ticker := time.NewTicker(openAIWSBackgroundPingInterval) - defer ticker.Stop() - for { - select { - case <-ticker.C: - p.runBackgroundPingSweep() - case <-p.workerStopCh: - return - } - } -} - -func (p *openAIWSConnPool) runBackgroundPingSweep() { - if p == nil { - return - } - candidates := p.snapshotIdleConnsForPing() - var g errgroup.Group - g.SetLimit(10) - for _, item := range candidates { - item := item - if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 { - continue - } - g.Go(func() error { - if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - p.evictConn(item.accountID, item.conn.id) - } - return nil - }) - } - _ = g.Wait() -} - -func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate { - if p == nil { - return nil - } - candidates := make([]openAIWSIdlePingCandidate, 0) - p.accounts.Range(func(key, value any) bool { - accountID, ok := key.(int64) - if !ok || accountID <= 0 { - return true - } - ap, ok := value.(*openAIWSAccountPool) - if !ok || ap == nil { - return true - } - ap.mu.Lock() - for _, conn := range ap.conns { - if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 { - continue - } - candidates = append(candidates, openAIWSIdlePingCandidate{ - accountID: accountID, - conn: conn, - }) - } - ap.mu.Unlock() - return true - }) - return candidates -} - -func (p *openAIWSConnPool) runBackgroundCleanupWorker() { - if p == nil { - return - } - ticker := time.NewTicker(openAIWSBackgroundSweepTicker) - defer ticker.Stop() - for { - select { - case <-ticker.C: - p.runBackgroundCleanupSweep(time.Now()) - case <-p.workerStopCh: - return - } - } -} - -func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) { - if p == nil { - return - } - type 
cleanupResult struct { - evicted []*openAIWSConn - } - results := make([]cleanupResult, 0) - p.accounts.Range(func(_ any, value any) bool { - ap, ok := value.(*openAIWSAccountPool) - if !ok || ap == nil { - return true - } - maxConns := p.maxConnsHardCap() - ap.mu.Lock() - if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { - maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) - } - evicted := p.cleanupAccountLocked(ap, now, maxConns) - ap.lastCleanupAt = now - ap.mu.Unlock() - if len(evicted) > 0 { - results = append(results, cleanupResult{evicted: evicted}) - } - return true - }) - for _, result := range results { - closeOpenAIWSConns(result.evicted) - } -} - -func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) { - if p != nil { - p.metrics.acquireTotal.Add(1) - } - return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0) -} - -func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) { - if p == nil || req.Account == nil || req.Account.ID <= 0 { - return nil, errors.New("invalid ws acquire request") - } - if stringsTrim(req.WSURL) == "" { - return nil, errors.New("ws url is empty") - } - - accountID := req.Account.ID - effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account) - if effectiveMaxConns <= 0 { - return nil, errOpenAIWSConnQueueFull - } - var evicted []*openAIWSConn - ap := p.getOrCreateAccountPool(accountID) - ap.mu.Lock() - ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req) - now := time.Now() - if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval { - evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns) - ap.lastCleanupAt = now - } - pickStartedAt := time.Now() - allowReuse := !req.ForceNewConn - preferredConnID := stringsTrim(req.PreferredConnID) - forcePreferredConn := allowReuse && req.ForcePreferredConn - - if allowReuse { - if forcePreferredConn 
{ - if preferredConnID == "" { - p.recordConnPickDuration(time.Since(pickStartedAt)) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - return nil, errOpenAIWSPreferredConnUnavailable - } - preferredConn, ok := ap.conns[preferredConnID] - if !ok || preferredConn == nil { - p.recordConnPickDuration(time.Since(pickStartedAt)) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - return nil, errOpenAIWSPreferredConnUnavailable - } - if preferredConn.tryAcquire() { - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - if p.shouldHealthCheckConn(preferredConn) { - if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - preferredConn.close() - p.evictConn(accountID, preferredConn.id) - if retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - } - lease := &openAIWSConnLease{ - pool: p, - accountID: accountID, - conn: preferredConn, - connPick: connPick, - reused: true, - } - p.metrics.acquireReuseTotal.Add(1) - p.ensureTargetIdleAsync(accountID) - return lease, nil - } - - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() { - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - return nil, errOpenAIWSConnQueueFull - } - preferredConn.waiters.Add(1) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - defer preferredConn.waiters.Add(-1) - waitStart := time.Now() - p.metrics.acquireQueueWaitTotal.Add(1) - - if err := preferredConn.acquire(ctx); err != nil { - if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - if p.shouldHealthCheckConn(preferredConn) { - if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - preferredConn.release() - preferredConn.close() - p.evictConn(accountID, preferredConn.id) - if retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - } - - 
queueWait := time.Since(waitStart) - p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) - lease := &openAIWSConnLease{ - pool: p, - accountID: accountID, - conn: preferredConn, - queueWait: queueWait, - connPick: connPick, - reused: true, - } - p.metrics.acquireReuseTotal.Add(1) - p.ensureTargetIdleAsync(accountID) - return lease, nil - } - - if preferredConnID != "" { - if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() { - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - if p.shouldHealthCheckConn(conn) { - if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - conn.close() - p.evictConn(accountID, conn.id) - if retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - } - lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} - p.metrics.acquireReuseTotal.Add(1) - p.ensureTargetIdleAsync(accountID) - return lease, nil - } - } - - best := p.pickLeastBusyConnLocked(ap, "") - if best != nil && best.tryAcquire() { - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - if p.shouldHealthCheckConn(best) { - if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - best.close() - p.evictConn(accountID, best.id) - if retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - } - lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true} - p.metrics.acquireReuseTotal.Add(1) - p.ensureTargetIdleAsync(accountID) - return lease, nil - } - for _, conn := range ap.conns { - if conn == nil || conn == best { - continue - } - if conn.tryAcquire() { - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - if p.shouldHealthCheckConn(conn) { - if err := 
conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - conn.close() - p.evictConn(accountID, conn.id) - if retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - } - lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} - p.metrics.acquireReuseTotal.Add(1) - p.ensureTargetIdleAsync(accountID) - return lease, nil - } - } - } - - if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns { - if idle := p.pickOldestIdleConnLocked(ap); idle != nil { - delete(ap.conns, idle.id) - evicted = append(evicted, idle) - p.metrics.scaleDownTotal.Add(1) - } - } - - if len(ap.conns)+ap.creating < effectiveMaxConns { - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - ap.creating++ - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - - conn, dialErr := p.dialConn(ctx, req) - - ap = p.getOrCreateAccountPool(accountID) - ap.mu.Lock() - ap.creating-- - if dialErr != nil { - ap.prewarmFails++ - ap.prewarmFailAt = time.Now() - ap.mu.Unlock() - return nil, dialErr - } - ap.conns[conn.id] = conn - ap.prewarmFails = 0 - ap.prewarmFailAt = time.Time{} - ap.mu.Unlock() - p.metrics.acquireCreateTotal.Add(1) - - if !conn.tryAcquire() { - if err := conn.acquire(ctx); err != nil { - conn.close() - p.evictConn(accountID, conn.id) - return nil, err - } - } - lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick} - p.ensureTargetIdleAsync(accountID) - return lease, nil - } - - if req.ForceNewConn { - p.recordConnPickDuration(time.Since(pickStartedAt)) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - return nil, errOpenAIWSConnQueueFull - } - - target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID) - connPick := time.Since(pickStartedAt) - p.recordConnPickDuration(connPick) - if target == nil { - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - return nil, errOpenAIWSConnClosed - } - if int(target.waiters.Load()) >= p.queueLimitPerConn() { - 
ap.mu.Unlock() - closeOpenAIWSConns(evicted) - return nil, errOpenAIWSConnQueueFull - } - target.waiters.Add(1) - ap.mu.Unlock() - closeOpenAIWSConns(evicted) - defer target.waiters.Add(-1) - waitStart := time.Now() - p.metrics.acquireQueueWaitTotal.Add(1) - - if err := target.acquire(ctx); err != nil { - if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - if p.shouldHealthCheckConn(target) { - if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { - target.release() - target.close() - p.evictConn(accountID, target.id) - if retry < 1 { - return p.acquire(ctx, req, retry+1) - } - return nil, err - } - } - - queueWait := time.Since(waitStart) - p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) - lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true} - p.metrics.acquireReuseTotal.Add(1) - p.ensureTargetIdleAsync(accountID) - return lease, nil -} - -func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) { - if p == nil { - return - } - if duration < 0 { - duration = 0 - } - p.metrics.connPickTotal.Add(1) - p.metrics.connPickMs.Add(duration.Milliseconds()) -} - -func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn { - if ap == nil || len(ap.conns) == 0 { - return nil - } - var oldest *openAIWSConn - for _, conn := range ap.conns { - if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { - continue - } - if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) { - oldest = conn - } - } - return oldest -} - -func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool { - if p == nil || accountID <= 0 { - return nil - } - if existing, ok := p.accounts.Load(accountID); ok { - if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil { - return ap - } - } - ap 
:= &openAIWSAccountPool{ - conns: make(map[string]*openAIWSConn), - pinnedConns: make(map[string]int), - } - actual, _ := p.accounts.LoadOrStore(accountID, ap) - if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil { - return typed - } - return ap -} - -// ensureAccountPoolLocked 兼容旧调用。 -func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool { - return p.getOrCreateAccountPool(accountID) -} - -func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) { - if p == nil || accountID <= 0 { - return nil, false - } - value, ok := p.accounts.Load(accountID) - if !ok || value == nil { - return nil, false - } - ap, typed := value.(*openAIWSAccountPool) - return ap, typed && ap != nil -} - -func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool { - if ap == nil || connID == "" || len(ap.pinnedConns) == 0 { - return false - } - return ap.pinnedConns[connID] > 0 -} - -func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn { - if ap == nil { - return nil - } - maxAge := p.maxConnAge() - - evicted := make([]*openAIWSConn, 0) - for id, conn := range ap.conns { - if conn == nil { - delete(ap.conns, id) - if len(ap.pinnedConns) > 0 { - delete(ap.pinnedConns, id) - } - continue - } - select { - case <-conn.closedCh: - delete(ap.conns, id) - if len(ap.pinnedConns) > 0 { - delete(ap.pinnedConns, id) - } - evicted = append(evicted, conn) - continue - default: - } - if p.isConnPinnedLocked(ap, id) { - continue - } - if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge { - delete(ap.conns, id) - if len(ap.pinnedConns) > 0 { - delete(ap.pinnedConns, id) - } - evicted = append(evicted, conn) - } - } - - if maxConns <= 0 { - maxConns = p.maxConnsHardCap() - } - maxIdle := p.maxIdlePerAccount() - if maxIdle < 0 || maxIdle > maxConns { - maxIdle = maxConns - } - if maxIdle >= 0 && len(ap.conns) > maxIdle { - 
idleConns := make([]*openAIWSConn, 0, len(ap.conns)) - for id, conn := range ap.conns { - if conn == nil { - delete(ap.conns, id) - if len(ap.pinnedConns) > 0 { - delete(ap.pinnedConns, id) - } - continue - } - // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。 - if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { - continue - } - idleConns = append(idleConns, conn) - } - sort.SliceStable(idleConns, func(i, j int) bool { - return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt()) - }) - redundant := len(ap.conns) - maxIdle - if redundant > len(idleConns) { - redundant = len(idleConns) - } - for i := 0; i < redundant; i++ { - conn := idleConns[i] - delete(ap.conns, conn.id) - if len(ap.pinnedConns) > 0 { - delete(ap.pinnedConns, conn.id) - } - evicted = append(evicted, conn) - } - if redundant > 0 { - p.metrics.scaleDownTotal.Add(int64(redundant)) - } - } - - return evicted -} - -func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn { - if ap == nil || len(ap.conns) == 0 { - return nil - } - preferredConnID = stringsTrim(preferredConnID) - if preferredConnID != "" { - if conn, ok := ap.conns[preferredConnID]; ok { - return conn - } - } - var best *openAIWSConn - var bestWaiters int32 - var bestLastUsed time.Time - for _, conn := range ap.conns { - if conn == nil { - continue - } - waiters := conn.waiters.Load() - lastUsed := conn.lastUsedAt() - if best == nil || - waiters < bestWaiters || - (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) { - best = conn - bestWaiters = waiters - bestLastUsed = lastUsed - } - } - return best -} - -func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) { - if ap == nil { - return 0, 0 - } - for _, conn := range ap.conns { - if conn == nil { - continue - } - if conn.isLeased() { - inflight++ - } - waiters += int(conn.waiters.Load()) - } - return inflight, waiters -} - -// AccountPoolLoad 
返回指定账号连接池的并发与排队快照。 -func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) { - if p == nil || accountID <= 0 { - return 0, 0, 0 - } - ap, ok := p.getAccountPool(accountID) - if !ok || ap == nil { - return 0, 0, 0 - } - ap.mu.Lock() - defer ap.mu.Unlock() - inflight, waiters = accountPoolLoadLocked(ap) - return inflight, waiters, len(ap.conns) -} - -func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) { - if p == nil || accountID <= 0 { - return - } - - var req openAIWSAcquireRequest - need := 0 - ap, ok := p.getAccountPool(accountID) - if !ok || ap == nil { - return - } - ap.mu.Lock() - defer ap.mu.Unlock() - if ap.lastAcquire == nil { - return - } - if ap.prewarmActive { - return - } - now := time.Now() - if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) { - return - } - if p.shouldSuppressPrewarmLocked(ap, now) { - return - } - effectiveMaxConns := p.maxConnsHardCap() - if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { - effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) - } - target := p.targetConnCountLocked(ap, effectiveMaxConns) - current := len(ap.conns) + ap.creating - if current >= target { - return - } - need = target - current - if need <= 0 { - return - } - req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire) - ap.prewarmActive = true - if cooldown := p.prewarmCooldown(); cooldown > 0 { - ap.prewarmUntil = now.Add(cooldown) - } - ap.creating += need - p.metrics.scaleUpTotal.Add(int64(need)) - - go p.prewarmConns(accountID, req, need) -} - -func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int { - if ap == nil { - return 0 - } - - if maxConns <= 0 { - return 0 - } - - minIdle := p.minIdlePerAccount() - if minIdle < 0 { - minIdle = 0 - } - if minIdle > maxConns { - minIdle = maxConns - } - - inflight, waiters := accountPoolLoadLocked(ap) - utilization := p.targetUtilization() - demand := inflight + waiters - if demand <= 0 
{ - return minIdle - } - - target := 1 - if demand > 1 { - target = int(math.Ceil(float64(demand) / utilization)) - } - if waiters > 0 && target < len(ap.conns)+1 { - target = len(ap.conns) + 1 - } - if target < minIdle { - target = minIdle - } - if target > maxConns { - target = maxConns - } - return target -} - -func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) { - defer func() { - if ap, ok := p.getAccountPool(accountID); ok && ap != nil { - ap.mu.Lock() - ap.prewarmActive = false - ap.mu.Unlock() - } - }() - - for i := 0; i < total; i++ { - ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay) - conn, err := p.dialConn(ctx, req) - cancel() - - ap, ok := p.getAccountPool(accountID) - if !ok || ap == nil { - if conn != nil { - conn.close() - } - return - } - ap.mu.Lock() - if ap.creating > 0 { - ap.creating-- - } - if err != nil { - ap.prewarmFails++ - ap.prewarmFailAt = time.Now() - ap.mu.Unlock() - continue - } - if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) { - ap.mu.Unlock() - conn.close() - continue - } - ap.conns[conn.id] = conn - ap.prewarmFails = 0 - ap.prewarmFailAt = time.Time{} - ap.mu.Unlock() - } -} - -func (p *openAIWSConnPool) evictConn(accountID int64, connID string) { - if p == nil || accountID <= 0 || stringsTrim(connID) == "" { - return - } - var conn *openAIWSConn - ap, ok := p.getAccountPool(accountID) - if ok && ap != nil { - ap.mu.Lock() - if c, exists := ap.conns[connID]; exists { - conn = c - delete(ap.conns, connID) - if len(ap.pinnedConns) > 0 { - delete(ap.pinnedConns, connID) - } - } - ap.mu.Unlock() - } - if conn != nil { - conn.close() - } -} - -func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool { - if p == nil || accountID <= 0 { - return false - } - connID = stringsTrim(connID) - if connID == "" { - return false - } - ap, ok := p.getAccountPool(accountID) - if !ok || ap == nil { - return false - } - 
ap.mu.Lock() - defer ap.mu.Unlock() - if _, exists := ap.conns[connID]; !exists { - return false - } - if ap.pinnedConns == nil { - ap.pinnedConns = make(map[string]int) - } - ap.pinnedConns[connID]++ - return true -} - -func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) { - if p == nil || accountID <= 0 { - return - } - connID = stringsTrim(connID) - if connID == "" { - return - } - ap, ok := p.getAccountPool(accountID) - if !ok || ap == nil { - return - } - ap.mu.Lock() - defer ap.mu.Unlock() - if len(ap.pinnedConns) == 0 { - return - } - count := ap.pinnedConns[connID] - if count <= 1 { - delete(ap.pinnedConns, connID) - return - } - ap.pinnedConns[connID] = count - 1 -} - -func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) { - if p == nil || p.clientDialer == nil { - return nil, errors.New("openai ws client dialer is nil") - } - conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) - if err != nil { - return nil, &openAIWSDialError{ - StatusCode: status, - ResponseHeaders: cloneHeader(handshakeHeaders), - Err: err, - } - } - if conn == nil { - return nil, &openAIWSDialError{ - StatusCode: status, - ResponseHeaders: cloneHeader(handshakeHeaders), - Err: errors.New("openai ws dialer returned nil connection"), - } - } - id := p.nextConnID(req.Account.ID) - return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil -} - -func (p *openAIWSConnPool) nextConnID(accountID int64) string { - seq := p.seq.Add(1) - buf := make([]byte, 0, 32) - buf = append(buf, "oa_ws_"...) 
- buf = strconv.AppendInt(buf, accountID, 10) - buf = append(buf, '_') - buf = strconv.AppendUint(buf, seq, 10) - return string(buf) -} - -func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool { - if conn == nil { - return false - } - return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle -} - -func (p *openAIWSConnPool) maxConnsHardCap() int { - if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 { - return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount - } - return 8 -} - -func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool { - if p != nil && p.cfg != nil { - return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled - } - return false -} - -func (p *openAIWSConnPool) modeRouterV2Enabled() bool { - if p != nil && p.cfg != nil { - return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled - } - return false -} - -func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 { - if p == nil || p.cfg == nil || account == nil { - return 1.0 - } - switch account.Type { - case AccountTypeOAuth: - if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 { - return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor - } - case AccountTypeAPIKey: - if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 { - return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor - } - } - return 1.0 -} - -func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int { - hardCap := p.maxConnsHardCap() - if hardCap <= 0 { - return 0 - } - if p.modeRouterV2Enabled() { - if account == nil { - return hardCap - } - if account.Concurrency <= 0 { - return 0 - } - return account.Concurrency - } - if account == nil || !p.dynamicMaxConnsEnabled() { - return hardCap - } - if account.Concurrency <= 0 { - // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。 - return hardCap - } - factor := p.maxConnsFactorByAccount(account) - if factor <= 0 { - factor = 1.0 - } - effective := int(math.Ceil(float64(account.Concurrency) * factor)) - if effective < 1 { - 
effective = 1 - } - if effective > hardCap { - effective = hardCap - } - return effective -} - -func (p *openAIWSConnPool) minIdlePerAccount() int { - if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 { - return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount - } - return 0 -} - -func (p *openAIWSConnPool) maxIdlePerAccount() int { - if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 { - return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount - } - return 4 -} - -func (p *openAIWSConnPool) maxConnAge() time.Duration { - return openAIWSConnMaxAge -} - -func (p *openAIWSConnPool) queueLimitPerConn() int { - if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 { - return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn - } - return 256 -} - -func (p *openAIWSConnPool) targetUtilization() float64 { - if p != nil && p.cfg != nil { - ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization - if ratio > 0 && ratio <= 1 { - return ratio - } - } - return 0.7 -} - -func (p *openAIWSConnPool) prewarmCooldown() time.Duration { - if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 { - return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond - } - return 0 -} - -func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool { - if ap == nil { - return true - } - if ap.prewarmFails <= 0 { - return false - } - if ap.prewarmFailAt.IsZero() { - ap.prewarmFails = 0 - return false - } - if now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow { - ap.prewarmFails = 0 - ap.prewarmFailAt = time.Time{} - return false - } - return ap.prewarmFails >= openAIWSPrewarmFailureSuppress -} - -func (p *openAIWSConnPool) dialTimeout() time.Duration { - if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { - return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second - } - return 10 * time.Second -} - -func 
cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest { - copied := req - copied.Headers = cloneHeader(req.Headers) - copied.WSURL = stringsTrim(req.WSURL) - copied.ProxyURL = stringsTrim(req.ProxyURL) - copied.PreferredConnID = stringsTrim(req.PreferredConnID) - return copied -} - -func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest { - if req == nil { - return nil - } - copied := cloneOpenAIWSAcquireRequest(*req) - return &copied -} - -func cloneHeader(src http.Header) http.Header { - if src == nil { - return nil - } - dst := make(http.Header, len(src)) - for k, vals := range src { - if len(vals) == 0 { - dst[k] = nil - continue - } - copied := make([]string, len(vals)) - copy(copied, vals) - dst[k] = copied - } - return dst -} - -func closeOpenAIWSConns(conns []*openAIWSConn) { - if len(conns) == 0 { - return - } - for _, conn := range conns { - if conn == nil { - continue - } - conn.close() - } -} - -func stringsTrim(value string) string { - return strings.TrimSpace(value) -} diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go deleted file mode 100644 index bff74b626..000000000 --- a/backend/internal/service/openai_ws_pool_benchmark_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package service - -import ( - "context" - "errors" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/config" -) - -func BenchmarkOpenAIWSPoolAcquire(b *testing.B) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 - - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(&openAIWSCountingDialer{}) - - account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - req := openAIWSAcquireRequest{ - Account: account, - 
WSURL: "wss://example.com/v1/responses", - } - ctx := context.Background() - - lease, err := pool.Acquire(ctx, req) - if err != nil { - b.Fatalf("warm acquire failed: %v", err) - } - lease.Release() - - b.ReportAllocs() - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - var ( - got *openAIWSConnLease - acquireErr error - ) - for retry := 0; retry < 3; retry++ { - got, acquireErr = pool.Acquire(ctx, req) - if acquireErr == nil { - break - } - if !errors.Is(acquireErr, errOpenAIWSConnClosed) { - break - } - } - if acquireErr != nil { - b.Fatalf("acquire failed: %v", acquireErr) - } - got.Release() - } - }) -} diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go deleted file mode 100644 index b2683ee04..000000000 --- a/backend/internal/service/openai_ws_pool_test.go +++ /dev/null @@ -1,1709 +0,0 @@ -package service - -import ( - "context" - "errors" - "net/http" - "strings" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/Wei-Shaw/sub2api/internal/config" - "github.com/stretchr/testify/require" -) - -func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 - pool := newOpenAIWSConnPool(cfg) - - accountID := int64(10) - ap := pool.getOrCreateAccountPool(accountID) - - stale := newOpenAIWSConn("stale", accountID, nil, nil) - stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) - stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) - - idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil) - idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) - - idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil) - idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) - - ap.conns[stale.id] = stale - ap.conns[idleOld.id] = idleOld - ap.conns[idleNew.id] = idleNew - - evicted := pool.cleanupAccountLocked(ap, time.Now(), 
pool.maxConnsHardCap()) - closeOpenAIWSConns(evicted) - - require.Nil(t, ap.conns["stale"], "stale connection should be rotated") - require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle") - require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept") -} - -func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) { - pool := newOpenAIWSConnPool(&config.Config{}) - id1 := pool.nextConnID(42) - id2 := pool.nextConnID(42) - - require.True(t, strings.HasPrefix(id1, "oa_ws_42_")) - require.True(t, strings.HasPrefix(id2, "oa_ws_42_")) - require.NotEqual(t, id1, id2) - require.Equal(t, "oa_ws_42_1", id1) - require.Equal(t, "oa_ws_42_2", id2) -} - -func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) { - require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval) - require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker) -} - -func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) { - conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil) - lease := &openAIWSConnLease{conn: conn} - require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0)) - - var nilLease *openAIWSConnLease - err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - - err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) - require.ErrorIs(t, err, errOpenAIWSConnClosed) -} - -func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) { - probe := &openAIWSContextProbeConn{} - conn := newOpenAIWSConn("ctx_probe", 1, probe, nil) - require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0)) - require.NotNil(t, probe.lastWriteCtx) -} - -func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) { - cfg := &config.Config{} - 
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5 - - pool := newOpenAIWSConnPool(cfg) - ap := pool.getOrCreateAccountPool(88) - - conn1 := newOpenAIWSConn("c1", 88, nil, nil) - conn2 := newOpenAIWSConn("c2", 88, nil, nil) - require.True(t, conn1.tryAcquire()) - require.True(t, conn2.tryAcquire()) - conn1.waiters.Store(1) - conn2.waiters.Store(1) - - ap.conns[conn1.id] = conn1 - ap.conns[conn2.id] = conn2 - - target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) - require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限") - - conn1.release() - conn2.release() - conn1.waiters.Store(0) - conn2.waiters.Store(0) - target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) - require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接") -} - -func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 - - pool := newOpenAIWSConnPool(cfg) - ap := pool.getOrCreateAccountPool(66) - - target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) - require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0") -} - -func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 - cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 - - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(&openAIWSFakeDialer{}) - - accountID := int64(77) - account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := pool.getOrCreateAccountPool(accountID) - ap.mu.Lock() - ap.lastAcquire = &openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - } - ap.mu.Unlock() - - 
pool.ensureTargetIdleAsync(accountID) - - require.Eventually(t, func() bool { - ap, ok := pool.getAccountPool(accountID) - if !ok || ap == nil { - return false - } - ap.mu.Lock() - defer ap.mu.Unlock() - return len(ap.conns) >= 2 - }, 2*time.Second, 20*time.Millisecond) - - metrics := pool.SnapshotMetrics() - require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2)) -} - -func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 - cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 - cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500 - - pool := newOpenAIWSConnPool(cfg) - dialer := &openAIWSCountingDialer{} - pool.setClientDialerForTest(dialer) - - accountID := int64(178) - account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := pool.getOrCreateAccountPool(accountID) - ap.mu.Lock() - ap.lastAcquire = &openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - } - ap.mu.Unlock() - - pool.ensureTargetIdleAsync(accountID) - require.Eventually(t, func() bool { - ap, ok := pool.getAccountPool(accountID) - if !ok || ap == nil { - return false - } - ap.mu.Lock() - defer ap.mu.Unlock() - return len(ap.conns) >= 2 && !ap.prewarmActive - }, 2*time.Second, 20*time.Millisecond) - firstDialCount := dialer.DialCount() - require.GreaterOrEqual(t, firstDialCount, 2) - - // 人工制造缺口触发新一轮预热需求。 - ap, ok := pool.getAccountPool(accountID) - require.True(t, ok) - require.NotNil(t, ap) - ap.mu.Lock() - for id := range ap.conns { - delete(ap.conns, id) - break - } - ap.mu.Unlock() - - pool.ensureTargetIdleAsync(accountID) - time.Sleep(120 * time.Millisecond) - require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热") - - time.Sleep(450 * time.Millisecond) - pool.ensureTargetIdleAsync(accountID) - require.Eventually(t, func() bool { - return 
dialer.DialCount() > firstDialCount - }, 2*time.Second, 20*time.Millisecond) -} - -func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 - cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 - cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0 - - pool := newOpenAIWSConnPool(cfg) - dialer := &openAIWSAlwaysFailDialer{} - pool.setClientDialerForTest(dialer) - - accountID := int64(279) - account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := pool.getOrCreateAccountPool(accountID) - ap.mu.Lock() - ap.lastAcquire = &openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - } - ap.mu.Unlock() - - pool.ensureTargetIdleAsync(accountID) - require.Eventually(t, func() bool { - ap, ok := pool.getAccountPool(accountID) - if !ok || ap == nil { - return false - } - ap.mu.Lock() - defer ap.mu.Unlock() - return !ap.prewarmActive - }, 2*time.Second, 20*time.Millisecond) - - pool.ensureTargetIdleAsync(accountID) - require.Eventually(t, func() bool { - ap, ok := pool.getAccountPool(accountID) - if !ok || ap == nil { - return false - } - ap.mu.Lock() - defer ap.mu.Unlock() - return !ap.prewarmActive - }, 2*time.Second, 20*time.Millisecond) - require.Equal(t, 2, dialer.DialCount()) - - // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。 - pool.ensureTargetIdleAsync(accountID) - time.Sleep(120 * time.Millisecond) - require.Equal(t, 2, dialer.DialCount()) -} - -func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 - - pool := newOpenAIWSConnPool(cfg) - accountID := int64(99) - account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - conn := 
newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil) - require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队 - - ap := pool.ensureAccountPoolLocked(accountID) - ap.mu.Lock() - ap.conns[conn.id] = conn - ap.lastAcquire = &openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - } - ap.mu.Unlock() - - go func() { - time.Sleep(60 * time.Millisecond) - conn.release() - }() - - lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - }) - require.NoError(t, err) - require.NotNil(t, lease) - require.True(t, lease.Reused()) - require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond) - lease.Release() - - metrics := pool.SnapshotMetrics() - require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1)) - require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0)) - require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) -} - -func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - - pool := newOpenAIWSConnPool(cfg) - dialer := &openAIWSCountingDialer{} - pool.setClientDialerForTest(dialer) - - account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - - lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - }) - require.NoError(t, err) - require.NotNil(t, lease1) - lease1.Release() - - lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - ForceNewConn: true, - }) - require.NoError(t, err) - require.NotNil(t, lease2) - lease2.Release() - - require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接") -} - -func 
TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - - pool := newOpenAIWSConnPool(cfg) - account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := pool.getOrCreateAccountPool(account.ID) - otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil) - ap.mu.Lock() - ap.conns[otherConn.id] = otherConn - ap.mu.Unlock() - - _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - ForcePreferredConn: true, - }) - require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) - - _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - PreferredConnID: "missing_conn", - ForcePreferredConn: true, - }) - require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) -} - -func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 - - pool := newOpenAIWSConnPool(cfg) - account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := pool.getOrCreateAccountPool(account.ID) - preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil) - otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil) - require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取") - ap.mu.Lock() - ap.conns[preferredConn.id] = preferredConn - ap.conns[otherConn.id] = otherConn - ap.lastCleanupAt = time.Now() - ap.mu.Unlock() - - go func() { - time.Sleep(60 * time.Millisecond) - preferredConn.release() - 
}() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - defer cancel() - lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - PreferredConnID: preferredConn.id, - ForcePreferredConn: true, - }) - require.NoError(t, err) - require.NotNil(t, lease) - require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移") - require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond) - lease.Release() - require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占") - otherConn.release() -} - -func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 - - pool := newOpenAIWSConnPool(cfg) - account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := pool.getOrCreateAccountPool(account.ID) - preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil) - otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil) - ap.mu.Lock() - ap.conns[preferredConn.id] = preferredConn - ap.conns[otherConn.id] = otherConn - ap.lastCleanupAt = time.Now() - ap.mu.Unlock() - - lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - PreferredConnID: preferredConn.id, - ForcePreferredConn: true, - }) - require.NoError(t, err) - require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中") - lease.Release() - - require.True(t, preferredConn.tryAcquire()) - preferredConn.waiters.Store(1) - _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - PreferredConnID: preferredConn.id, - 
ForcePreferredConn: true, - }) - require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移") - preferredConn.waiters.Store(0) - preferredConn.release() -} - -func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 - - pool := newOpenAIWSConnPool(cfg) - accountID := int64(126) - ap := pool.getOrCreateAccountPool(accountID) - pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil) - idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) - ap.mu.Lock() - ap.conns[pinnedConn.id] = pinnedConn - ap.conns[idleConn.id] = idleConn - ap.mu.Unlock() - - require.True(t, pool.PinConn(accountID, pinnedConn.id)) - evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) - closeOpenAIWSConns(evicted) - - ap.mu.Lock() - _, pinnedExists := ap.conns[pinnedConn.id] - _, idleExists := ap.conns[idleConn.id] - ap.mu.Unlock() - require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收") - require.False(t, idleExists, "非绑定的空闲连接应被回收") - - pool.UnpinConn(accountID, pinnedConn.id) - evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) - closeOpenAIWSConns(evicted) - ap.mu.Lock() - _, pinnedExists = ap.conns[pinnedConn.id] - ap.mu.Unlock() - require.False(t, pinnedExists, "解绑后连接应可被正常回收") -} - -func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) { - var nilPool *openAIWSConnPool - require.False(t, nilPool.PinConn(1, "x")) - nilPool.UnpinConn(1, "x") - - cfg := &config.Config{} - pool := newOpenAIWSConnPool(cfg) - accountID := int64(128) - ap := &openAIWSAccountPool{ - conns: map[string]*openAIWSConn{}, - } - pool.accounts.Store(accountID, ap) - - require.False(t, pool.PinConn(0, "x")) - require.False(t, pool.PinConn(999, "x")) - require.False(t, pool.PinConn(accountID, "")) - require.False(t, pool.PinConn(accountID, "missing")) - - conn := 
newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil) - ap.mu.Lock() - ap.conns[conn.id] = conn - ap.mu.Unlock() - require.True(t, pool.PinConn(accountID, conn.id)) - require.True(t, pool.PinConn(accountID, conn.id)) - - ap.mu.Lock() - require.Equal(t, 2, ap.pinnedConns[conn.id]) - ap.mu.Unlock() - - pool.UnpinConn(accountID, conn.id) - ap.mu.Lock() - require.Equal(t, 1, ap.pinnedConns[conn.id]) - ap.mu.Unlock() - - pool.UnpinConn(accountID, conn.id) - ap.mu.Lock() - _, exists := ap.pinnedConns[conn.id] - ap.mu.Unlock() - require.False(t, exists) - - pool.UnpinConn(accountID, conn.id) - pool.UnpinConn(accountID, "") - pool.UnpinConn(0, conn.id) - pool.UnpinConn(999, conn.id) -} - -func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 - cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true - cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 - cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 - - pool := newOpenAIWSConnPool(cfg) - - oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10} - require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束") - - oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3} - require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow)) - - apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10} - require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放") - - apiKeyLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1} - require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1") - - unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} - require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限") - - require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限") -} - 
-func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 - cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false - cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 - cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0 - - pool := newOpenAIWSConnPool(cfg) - account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2} - require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为") -} - -func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 - cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true - cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3 - cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 - - pool := newOpenAIWSConnPool(cfg) - - high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20} - require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限") - - nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0} - require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度") -} - -func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 - pool := newOpenAIWSConnPool(cfg) - - account := &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} - _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - }) - require.ErrorIs(t, err, errOpenAIWSConnQueueFull) -} - -func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) { - conn := newOpenAIWSConn("timeout", 1, 
&openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil) - lease := &openAIWSConnLease{conn: conn} - - _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond) - require.Error(t, err) - require.ErrorIs(t, err, context.DeadlineExceeded) - - payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond) - require.NoError(t, err) - require.Contains(t, string(payload), "response.completed") - - parentCtx, cancel := context.WithCancel(context.Background()) - cancel() - _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond) - require.Error(t, err) - require.ErrorIs(t, err, context.Canceled) -} - -func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) { - conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil) - lease := &openAIWSConnLease{conn: conn} - - parentCtx, cancel := context.WithCancel(context.Background()) - go func() { - time.Sleep(20 * time.Millisecond) - cancel() - }() - - start := time.Now() - err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute) - elapsed := time.Since(start) - - require.Error(t, err) - require.ErrorIs(t, err, context.Canceled) - require.Less(t, elapsed, 200*time.Millisecond) -} - -func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) { - conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil) - lease := &openAIWSConnLease{conn: conn} - require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) - - var nilLease *openAIWSConnLease - err := nilLease.PingWithTimeout(50 * time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) -} - -func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) { - conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil) - - readDone := make(chan error, 1) - go func() { - _, err := 
conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond) - readDone <- err - }() - - // 让读取先占用 readMu。 - time.Sleep(20 * time.Millisecond) - - start := time.Now() - err := conn.pingWithTimeout(50 * time.Millisecond) - elapsed := time.Since(start) - - require.NoError(t, err) - require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞") - require.NoError(t, <-readDone) -} - -func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - pool := newOpenAIWSConnPool(cfg) - - accountID := int64(301) - ap := pool.getOrCreateAccountPool(accountID) - conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil) - ap.mu.Lock() - ap.conns[conn.id] = conn - ap.mu.Unlock() - - pool.runBackgroundPingSweep() - - ap.mu.Lock() - _, exists := ap.conns[conn.id] - ap.mu.Unlock() - require.False(t, exists, "后台 ping 失败的空闲连接应被回收") -} - -func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 - pool := newOpenAIWSConnPool(cfg) - - accountID := int64(302) - ap := pool.getOrCreateAccountPool(accountID) - stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil) - stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) - stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) - ap.mu.Lock() - ap.conns[stale.id] = stale - ap.mu.Unlock() - - pool.runBackgroundCleanupSweep(time.Now()) - - ap.mu.Lock() - _, exists := ap.conns[stale.id] - ap.mu.Unlock() - require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接") -} - -func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) { - var nilPool *openAIWSConnPool - require.NotPanics(t, func() { - nilPool.startBackgroundWorkers() - nilPool.runBackgroundPingWorker() - nilPool.runBackgroundPingSweep() - _ = 
nilPool.snapshotIdleConnsForPing() - nilPool.runBackgroundCleanupWorker() - nilPool.runBackgroundCleanupSweep(time.Now()) - }) - - poolNoStop := &openAIWSConnPool{} - require.NotPanics(t, func() { - poolNoStop.startBackgroundWorkers() - }) - - poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})} - pingDone := make(chan struct{}) - go func() { - poolStopPing.runBackgroundPingWorker() - close(pingDone) - }() - close(poolStopPing.workerStopCh) - select { - case <-pingDone: - case <-time.After(500 * time.Millisecond): - t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出") - } - - poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})} - cleanupDone := make(chan struct{}) - go func() { - poolStopCleanup.runBackgroundCleanupWorker() - close(cleanupDone) - }() - close(poolStopCleanup.workerStopCh) - select { - case <-cleanupDone: - case <-time.After(500 * time.Millisecond): - t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出") - } -} - -func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) { - pool := &openAIWSConnPool{} - pool.accounts.Store("invalid-key", &openAIWSAccountPool{}) - pool.accounts.Store(int64(123), "invalid-value") - - accountID := int64(123) - ap := &openAIWSAccountPool{ - conns: make(map[string]*openAIWSConn), - } - ap.conns["nil_conn"] = nil - - leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) - require.True(t, leased.tryAcquire()) - ap.conns[leased.id] = leased - - waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil) - waiting.waiters.Store(1) - ap.conns[waiting.id] = waiting - - idle := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil) - ap.conns[idle.id] = idle - - pool.accounts.Store(accountID, ap) - candidates := pool.snapshotIdleConnsForPing() - require.Len(t, candidates, 1) - require.Equal(t, idle.id, candidates[0].conn.id) -} - -func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) { - cfg 
:= &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 - cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true - - pool := &openAIWSConnPool{cfg: cfg} - pool.accounts.Store("bad-key", "bad-value") - - accountID := int64(2026) - ap := &openAIWSAccountPool{ - conns: make(map[string]*openAIWSConn), - } - ap.conns["nil_conn"] = nil - stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil) - stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) - stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) - ap.conns[stale.id] = stale - ap.lastAcquire = &openAIWSAcquireRequest{ - Account: &Account{ - ID: accountID, - Platform: PlatformOpenAI, - Type: AccountTypeAPIKey, - Concurrency: 1, - }, - } - pool.accounts.Store(accountID, ap) - - now := time.Now() - require.NotPanics(t, func() { - pool.runBackgroundCleanupSweep(now) - }) - - ap.mu.Lock() - _, nilConnExists := ap.conns["nil_conn"] - _, exists := ap.conns[stale.id] - lastCleanupAt := ap.lastCleanupAt - ap.mu.Unlock() - - require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目") - require.False(t, exists, "后台清理应清理过期连接") - require.Equal(t, now, lastCleanupAt) -} - -func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) { - var nilPool *openAIWSConnPool - require.Equal(t, 256, nilPool.queueLimitPerConn()) - - pool := &openAIWSConnPool{cfg: &config.Config{}} - require.Equal(t, 256, pool.queueLimitPerConn()) - - pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9 - require.Equal(t, 9, pool.queueLimitPerConn()) -} - -func TestOpenAIWSConnPool_Close(t *testing.T) { - cfg := &config.Config{} - pool := newOpenAIWSConnPool(cfg) - - // Close 应该可以安全调用 - pool.Close() - - // workerStopCh 应已关闭 - select { - case <-pool.workerStopCh: - // 预期:channel 已关闭 - default: - t.Fatal("Close 后 workerStopCh 应已关闭") - } - - // 多次调用 Close 不应 panic - pool.Close() - - // nil pool 调用 Close 不应 panic - var nilPool *openAIWSConnPool - nilPool.Close() -} - -func 
TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) { - baseErr := errors.New("boom") - dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr} - require.Contains(t, dialErr.Error(), "status=502") - require.ErrorIs(t, dialErr.Unwrap(), baseErr) - - noStatus := &openAIWSDialError{Err: baseErr} - require.Contains(t, noStatus.Error(), "boom") - - var nilDialErr *openAIWSDialError - require.Equal(t, "", nilDialErr.Error()) - require.NoError(t, nilDialErr.Unwrap()) -} - -func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) { - conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{ - "X-Test": []string{" value "}, - }) - lease := &openAIWSConnLease{conn: conn} - - require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"})) - payload, err := lease.ReadMessage(100 * time.Millisecond) - require.NoError(t, err) - require.Contains(t, string(payload), "response.completed") - - payload, err = lease.ReadMessageContext(context.Background()) - require.NoError(t, err) - require.Contains(t, string(payload), "response.completed") - - payload, err = conn.readMessageWithTimeout(100 * time.Millisecond) - require.NoError(t, err) - require.Contains(t, string(payload), "response.completed") - - require.Equal(t, "value", conn.handshakeHeader(" X-Test ")) - require.NotZero(t, conn.createdAt()) - require.NotZero(t, conn.lastUsedAt()) - require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0)) - require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0)) - require.False(t, conn.isLeased()) - - // 覆盖空上下文路径 - _, err = conn.readMessage(context.Background()) - require.NoError(t, err) - - // 覆盖 nil 保护分支 - var nilConn *openAIWSConn - require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed) - _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - _, err = 
nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) -} - -func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) { - pool := &openAIWSConnPool{} - accountID := int64(404) - ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} - - idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil) - idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) - idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil) - idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) - leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) - require.True(t, leased.tryAcquire()) - leased.waiters.Store(2) - - ap.conns[idleOld.id] = idleOld - ap.conns[idleNew.id] = idleNew - ap.conns[leased.id] = leased - - oldest := pool.pickOldestIdleConnLocked(ap) - require.NotNil(t, oldest) - require.Equal(t, idleOld.id, oldest.id) - - inflight, waiters := accountPoolLoadLocked(ap) - require.Equal(t, 1, inflight) - require.Equal(t, 2, waiters) - - pool.accounts.Store(accountID, ap) - loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID) - require.Equal(t, 1, loadInflight) - require.Equal(t, 2, loadWaiters) - require.Equal(t, 3, conns) - - zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0) - require.Equal(t, 0, zeroInflight) - require.Equal(t, 0, zeroWaiters) - require.Equal(t, 0, zeroConns) -} - -func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) { - pool := &openAIWSConnPool{} - release := make(chan struct{}) - pool.workerWg.Add(1) - go func() { - defer pool.workerWg.Done() - <-release - }() - - closed := make(chan struct{}) - go func() { - pool.Close() - close(closed) - }() - - select { - case <-closed: - t.Fatal("Close 不应在 WaitGroup 未完成时提前返回") - case <-time.After(30 * time.Millisecond): - } - - close(release) - select { - case <-closed: - case 
<-time.After(time.Second): - t.Fatal("Close 未等待 workerWg 完成") - } -} - -func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) { - pool := &openAIWSConnPool{ - workerStopCh: make(chan struct{}), - } - - accountID := int64(606) - ap := &openAIWSAccountPool{ - conns: map[string]*openAIWSConn{}, - } - idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) - leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil) - require.True(t, leased.tryAcquire()) - - ap.conns[idle.id] = idle - ap.conns[leased.id] = leased - pool.accounts.Store(accountID, ap) - pool.accounts.Store("invalid-key", "invalid-value") - - pool.Close() - - select { - case <-idle.closedCh: - // idle should be closed - default: - t.Fatal("空闲连接应在 Close 时被关闭") - } - - select { - case <-leased.closedCh: - t.Fatal("已租赁连接不应在 Close 时被关闭") - default: - } - - leased.release() - pool.Close() -} - -func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) { - cfg := &config.Config{} - pool := newOpenAIWSConnPool(cfg) - accountID := int64(505) - ap := pool.getOrCreateAccountPool(accountID) - - var current atomic.Int32 - var maxConcurrent atomic.Int32 - release := make(chan struct{}) - for i := 0; i < 25; i++ { - conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{ - current: ¤t, - maxConcurrent: &maxConcurrent, - release: release, - }, nil) - ap.mu.Lock() - ap.conns[conn.id] = conn - ap.mu.Unlock() - } - - done := make(chan struct{}) - go func() { - pool.runBackgroundPingSweep() - close(done) - }() - - require.Eventually(t, func() bool { - return maxConcurrent.Load() >= 10 - }, time.Second, 10*time.Millisecond) - - close(release) - select { - case <-done: - case <-time.After(2 * time.Second): - t.Fatal("runBackgroundPingSweep 未在释放后完成") - } - - require.LessOrEqual(t, maxConcurrent.Load(), int32(10)) -} - -func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) { - var nilLease *openAIWSConnLease - 
require.Equal(t, "", nilLease.ConnID()) - require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration()) - require.Equal(t, time.Duration(0), nilLease.ConnPickDuration()) - require.False(t, nilLease.Reused()) - require.Equal(t, "", nilLease.HandshakeHeader("x-test")) - require.False(t, nilLease.IsPrewarmed()) - nilLease.MarkPrewarmed() - nilLease.Release() - - conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}}) - lease := &openAIWSConnLease{ - conn: conn, - queueWait: 3 * time.Millisecond, - connPick: 4 * time.Millisecond, - reused: true, - } - require.Equal(t, "getter_conn", lease.ConnID()) - require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration()) - require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration()) - require.True(t, lease.Reused()) - require.Equal(t, "ok", lease.HandshakeHeader("x-test")) - require.False(t, lease.IsPrewarmed()) - lease.MarkPrewarmed() - require.True(t, lease.IsPrewarmed()) - lease.Release() -} - -func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) { - var nilPool *openAIWSConnPool - require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics()) - require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics()) - - pool := &openAIWSConnPool{cfg: &config.Config{}} - pool.metrics.acquireTotal.Store(7) - pool.metrics.acquireReuseTotal.Store(3) - metrics := pool.SnapshotMetrics() - require.Equal(t, int64(7), metrics.AcquireTotal) - require.Equal(t, int64(3), metrics.AcquireReuseTotal) - - // 非 transport metrics dialer 路径 - pool.clientDialer = &openAIWSFakeDialer{} - require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics()) - pool.setClientDialerForTest(nil) - require.NotNil(t, pool.clientDialer) - - require.Equal(t, 8, nilPool.maxConnsHardCap()) - require.False(t, nilPool.dynamicMaxConnsEnabled()) - require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil)) - require.Equal(t, 0, nilPool.minIdlePerAccount()) - 
require.Equal(t, 4, nilPool.maxIdlePerAccount()) - require.Equal(t, 256, nilPool.queueLimitPerConn()) - require.Equal(t, 0.7, nilPool.targetUtilization()) - require.Equal(t, time.Duration(0), nilPool.prewarmCooldown()) - require.Equal(t, 10*time.Second, nilPool.dialTimeout()) - - // shouldSuppressPrewarmLocked 覆盖 3 条分支 - now := time.Now() - apNilFail := &openAIWSAccountPool{prewarmFails: 1} - require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now)) - apZeroTime := &openAIWSAccountPool{prewarmFails: 2} - require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now)) - require.Equal(t, 0, apZeroTime.prewarmFails) - apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)} - require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now)) - apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now} - require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now)) - - // recordConnPickDuration 的保护分支 - nilPool.recordConnPickDuration(10 * time.Millisecond) - pool.recordConnPickDuration(-10 * time.Millisecond) - require.Equal(t, int64(1), pool.metrics.connPickTotal.Load()) - - // account pool 读写分支 - require.Nil(t, nilPool.getOrCreateAccountPool(1)) - require.Nil(t, pool.getOrCreateAccountPool(0)) - pool.accounts.Store(int64(7), "invalid") - ap := pool.getOrCreateAccountPool(7) - require.NotNil(t, ap) - _, ok := pool.getAccountPool(0) - require.False(t, ok) - _, ok = pool.getAccountPool(12345) - require.False(t, ok) - pool.accounts.Store(int64(8), "bad-type") - _, ok = pool.getAccountPool(8) - require.False(t, ok) - - // health check 条件 - require.False(t, pool.shouldHealthCheckConn(nil)) - conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil) - conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano()) - require.True(t, pool.shouldHealthCheckConn(conn)) -} - -func 
TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) { - var nilConn *openAIWSConn - nilConn.touch() - require.Equal(t, time.Time{}, nilConn.createdAt()) - require.Equal(t, time.Time{}, nilConn.lastUsedAt()) - require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now())) - require.Equal(t, time.Duration(0), nilConn.age(time.Now())) - require.False(t, nilConn.isLeased()) - require.False(t, nilConn.isPrewarmed()) - nilConn.markPrewarmed() - - conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil) - require.True(t, conn.tryAcquire()) - require.True(t, conn.isLeased()) - conn.release() - require.False(t, conn.isLeased()) - conn.close() - require.False(t, conn.tryAcquire()) - - ctx, cancel := context.WithCancel(context.Background()) - cancel() - err := conn.acquire(ctx) - require.Error(t, err) -} - -func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) { - lease := &openAIWSConnLease{} - require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) - require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) - _, err := lease.ReadMessage(10 * time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - _, err = lease.ReadMessageContext(context.Background()) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) -} - -func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) { - conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil) - lease := &openAIWSConnLease{conn: conn} - - require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) - - lease.Release() - lease.Release() // idempotent - - require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) - require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": 
"v"}), errOpenAIWSConnClosed) - require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) - - _, err := lease.ReadMessage(10 * time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - _, err = lease.ReadMessageContext(context.Background()) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - - require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed) -} - -func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) { - conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil) - ap := &openAIWSAccountPool{ - conns: map[string]*openAIWSConn{ - conn.id: conn, - }, - } - pool := &openAIWSConnPool{} - pool.accounts.Store(int64(7), ap) - - lease := &openAIWSConnLease{ - pool: pool, - accountID: 7, - conn: conn, - } - - lease.Release() - lease.MarkBroken() - - ap.mu.Lock() - _, exists := ap.conns[conn.id] - ap.mu.Unlock() - require.True(t, exists, "released lease should not evict active pool connection") -} - -func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) { - var nilConn *openAIWSConn - require.False(t, nilConn.tryAcquire()) - require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed) - nilConn.release() - nilConn.close() - require.Equal(t, "", nilConn.handshakeHeader("x-test")) - - connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil) - require.True(t, connBusy.tryAcquire()) - ctx, cancel := context.WithCancel(context.Background()) - cancel() - require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled) - connBusy.release() - - connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil) - connClosed.close() - require.ErrorIs( - t, - connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, 
time.Second), - errOpenAIWSConnClosed, - ) - _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed) - - connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil) - require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed) - _, err = connNoWS.readMessage(context.Background()) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed) - require.Equal(t, "", connNoWS.handshakeHeader("x-test")) - - connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil) - require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil)) - _, err = connOK.readMessageWithContextTimeout(context.Background(), 0) - require.NoError(t, err) - require.NoError(t, connOK.pingWithTimeout(0)) - - connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil) - connZero.createdAtNano.Store(0) - connZero.lastUsedNano.Store(0) - require.True(t, connZero.createdAt().IsZero()) - require.True(t, connZero.lastUsedAt().IsZero()) - require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now())) - require.Equal(t, time.Duration(0), connZero.age(time.Now())) - - require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil)) - copied := cloneHeader(http.Header{ - "X-Empty": []string{}, - "X-Test": []string{"v1"}, - }) - require.Contains(t, copied, "X-Empty") - require.Nil(t, copied["X-Empty"]) - require.Equal(t, "v1", copied.Get("X-Test")) - - closeOpenAIWSConns([]*openAIWSConn{nil, connOK}) -} - -func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) { - pool := newOpenAIWSConnPool(&config.Config{}) - accountID := int64(5001) - conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil) - ap := pool.getOrCreateAccountPool(accountID) - ap.mu.Lock() - ap.conns[conn.id] = conn - ap.mu.Unlock() - - lease := 
&openAIWSConnLease{ - pool: pool, - accountID: accountID, - conn: conn, - } - lease.MarkBroken() - - ap.mu.Lock() - _, exists := ap.conns[conn.id] - ap.mu.Unlock() - require.False(t, exists) - require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭") -} - -func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - pool := newOpenAIWSConnPool(cfg) - - require.Equal(t, 0, pool.targetConnCountLocked(nil, 1)) - ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} - require.Equal(t, 0, pool.targetConnCountLocked(ap, 0)) - - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3 - require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断") - - // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支 - cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 - cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9 - busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil) - require.True(t, busy.tryAcquire()) - busy.waiters.Store(1) - ap.conns[busy.id] = busy - target := pool.targetConnCountLocked(ap, 4) - require.GreaterOrEqual(t, target, len(ap.conns)+1) - - // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回 - req := openAIWSAcquireRequest{ - Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}, - WSURL: "wss://example.com/v1/responses", - } - pool.prewarmConns(999, req, 1) - - // prewarm: 拨号失败分支(prewarmFails 累加) - accountID := int64(1000) - failPool := newOpenAIWSConnPool(cfg) - failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{}) - apFail := failPool.getOrCreateAccountPool(accountID) - apFail.mu.Lock() - apFail.creating = 1 - apFail.mu.Unlock() - req.Account.ID = accountID - failPool.prewarmConns(accountID, req, 1) - apFail.mu.Lock() - require.GreaterOrEqual(t, apFail.prewarmFails, 1) - apFail.mu.Unlock() -} - -func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) { - var nilPool *openAIWSConnPool - _, err := 
nilPool.Acquire(context.Background(), openAIWSAcquireRequest{}) - require.Error(t, err) - - pool := newOpenAIWSConnPool(&config.Config{}) - _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: &Account{ID: 1}, - WSURL: " ", - }) - require.Error(t, err) - require.Contains(t, err.Error(), "ws url is empty") - - // target=nil 分支:池满且仅有 nil 连接 - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 - cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 - fullPool := newOpenAIWSConnPool(cfg) - account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap := fullPool.getOrCreateAccountPool(account.ID) - ap.mu.Lock() - ap.conns["nil"] = nil - ap.lastCleanupAt = time.Now() - ap.mu.Unlock() - _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - }) - require.ErrorIs(t, err, errOpenAIWSConnClosed) - - // queue full 分支:waiters 达上限 - account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - ap2 := fullPool.getOrCreateAccountPool(account2.ID) - conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil) - require.True(t, conn.tryAcquire()) - conn.waiters.Store(1) - ap2.mu.Lock() - ap2.conns[conn.id] = conn - ap2.lastCleanupAt = time.Now() - ap2.mu.Unlock() - _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account2, - WSURL: "wss://example.com/v1/responses", - }) - require.ErrorIs(t, err, errOpenAIWSConnQueueFull) -} - -type openAIWSFakeDialer struct{} - -func (d *openAIWSFakeDialer) Dial( - ctx context.Context, - wsURL string, - headers http.Header, - proxyURL string, -) (openAIWSClientConn, int, http.Header, error) { - _ = ctx - _ = wsURL - _ = headers - _ = proxyURL - return &openAIWSFakeConn{}, 0, nil, nil -} - -type openAIWSCountingDialer struct { - mu sync.Mutex - dialCount int -} - -type openAIWSAlwaysFailDialer struct { - mu sync.Mutex - dialCount int -} - 
-type openAIWSPingBlockingConn struct { - current *atomic.Int32 - maxConcurrent *atomic.Int32 - release <-chan struct{} -} - -func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error { - return nil -} - -func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) { - return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil -} - -func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error { - if c.current == nil || c.maxConcurrent == nil { - return nil - } - - now := c.current.Add(1) - for { - prev := c.maxConcurrent.Load() - if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) { - break - } - } - defer c.current.Add(-1) - - select { - case <-ctx.Done(): - return ctx.Err() - case <-c.release: - return nil - } -} - -func (c *openAIWSPingBlockingConn) Close() error { - return nil -} - -func (d *openAIWSCountingDialer) Dial( - ctx context.Context, - wsURL string, - headers http.Header, - proxyURL string, -) (openAIWSClientConn, int, http.Header, error) { - _ = ctx - _ = wsURL - _ = headers - _ = proxyURL - d.mu.Lock() - d.dialCount++ - d.mu.Unlock() - return &openAIWSFakeConn{}, 0, nil, nil -} - -func (d *openAIWSCountingDialer) DialCount() int { - d.mu.Lock() - defer d.mu.Unlock() - return d.dialCount -} - -func (d *openAIWSAlwaysFailDialer) Dial( - ctx context.Context, - wsURL string, - headers http.Header, - proxyURL string, -) (openAIWSClientConn, int, http.Header, error) { - _ = ctx - _ = wsURL - _ = headers - _ = proxyURL - d.mu.Lock() - d.dialCount++ - d.mu.Unlock() - return nil, 503, nil, errors.New("dial failed") -} - -func (d *openAIWSAlwaysFailDialer) DialCount() int { - d.mu.Lock() - defer d.mu.Unlock() - return d.dialCount -} - -type openAIWSFakeConn struct { - mu sync.Mutex - closed bool - payload [][]byte -} - -func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error { - _ = ctx - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return 
errors.New("closed") - } - c.payload = append(c.payload, []byte("ok")) - _ = value - return nil -} - -func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) { - _ = ctx - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return nil, errors.New("closed") - } - return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil -} - -func (c *openAIWSFakeConn) Ping(ctx context.Context) error { - _ = ctx - return nil -} - -func (c *openAIWSFakeConn) Close() error { - c.mu.Lock() - defer c.mu.Unlock() - c.closed = true - return nil -} - -type openAIWSBlockingConn struct { - readDelay time.Duration -} - -func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error { - _ = ctx - _ = value - return nil -} - -func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) { - delay := c.readDelay - if delay <= 0 { - delay = 10 * time.Millisecond - } - timer := time.NewTimer(delay) - defer timer.Stop() - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-timer.C: - return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil - } -} - -func (c *openAIWSBlockingConn) Ping(ctx context.Context) error { - _ = ctx - return nil -} - -func (c *openAIWSBlockingConn) Close() error { - return nil -} - -type openAIWSWriteBlockingConn struct{} - -func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error { - <-ctx.Done() - return ctx.Err() -} - -func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) { - return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil -} - -func (c *openAIWSWriteBlockingConn) Ping(context.Context) error { - return nil -} - -func (c *openAIWSWriteBlockingConn) Close() error { - return nil -} - -type openAIWSPingFailConn struct{} - -func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error { - return nil -} - -func (c *openAIWSPingFailConn) 
ReadMessage(context.Context) ([]byte, error) { - return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil -} - -func (c *openAIWSPingFailConn) Ping(context.Context) error { - return errors.New("ping failed") -} - -func (c *openAIWSPingFailConn) Close() error { - return nil -} - -type openAIWSContextProbeConn struct { - lastWriteCtx context.Context -} - -func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error { - c.lastWriteCtx = ctx - return nil -} - -func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) { - return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil -} - -func (c *openAIWSContextProbeConn) Ping(context.Context) error { - return nil -} - -func (c *openAIWSContextProbeConn) Close() error { - return nil -} - -type openAIWSNilConnDialer struct{} - -func (d *openAIWSNilConnDialer) Dial( - ctx context.Context, - wsURL string, - headers http.Header, - proxyURL string, -) (openAIWSClientConn, int, http.Header, error) { - _ = ctx - _ = wsURL - _ = headers - _ = proxyURL - return nil, 200, nil, nil -} - -func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) { - cfg := &config.Config{} - cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 - cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 - - pool := newOpenAIWSConnPool(cfg) - pool.setClientDialerForTest(&openAIWSNilConnDialer{}) - account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} - - _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ - Account: account, - WSURL: "wss://example.com/v1/responses", - }) - require.Error(t, err) - require.Contains(t, err.Error(), "nil connection") -} - -func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) { - cfg := &config.Config{} - pool := newOpenAIWSConnPool(cfg) - - dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer) - require.True(t, ok) - - _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080") - 
require.NoError(t, err) - _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080") - require.NoError(t, err) - _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081") - require.NoError(t, err) - - snapshot := pool.SnapshotTransportMetrics() - require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) - require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) - require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) -} diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index df4d4871c..69aaeaec0 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -30,6 +30,7 @@ func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -201,7 +202,9 @@ func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t * func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { gin.SetMode(gin.TestMode) + var wsAttempts atomic.Int32 ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) w.WriteHeader(http.StatusUpgradeRequired) _, _ = w.Write([]byte(`upgrade required`)) })) @@ -211,6 +214,7 @@ func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -260,6 +264,7 @@ func 
TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") require.Equal(t, http.StatusUpgradeRequired, rec.Code) require.Contains(t, rec.Body.String(), "426") + require.Equal(t, int32(1), wsAttempts.Load(), "426 upgrade_required 应快速失败,不应进行 WS 重试") } func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { @@ -273,6 +278,7 @@ func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -332,6 +338,7 @@ func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing. c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -419,6 +426,7 @@ func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenRetu c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) c.String(http.StatusAccepted, "already-written") upstream := &httpUpstreamRecorder{ @@ -508,6 +516,7 @@ func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testin c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -590,6 +599,7 @@ func 
TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *test c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -672,6 +682,7 @@ func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *tes c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -759,6 +770,7 @@ func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbac c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -866,6 +878,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDrop c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -966,6 +979,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryF c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -1064,6 +1078,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryW c, _ := 
gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ @@ -1161,6 +1176,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn c, _ := gin.CreateTestContext(rec) c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportWS) upstream := &httpUpstreamRecorder{ resp: &http.Response{ diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go index 368643bea..d4fcb472c 100644 --- a/backend/internal/service/openai_ws_protocol_resolver.go +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -69,7 +69,7 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt switch mode { case OpenAIWSIngressModeOff: return openAIWSHTTPDecision("account_mode_off") - case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough: // continue default: return openAIWSHTTPDecision("account_mode_off") diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go index 5be76e28f..da23f0eb1 100644 --- a/backend/internal/service/openai_ws_protocol_resolver_test.go +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -143,21 +143,49 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { cfg.Gateway.OpenAIWS.APIKeyEnabled = true cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true - cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.IngressModeDefault = 
OpenAIWSIngressModeOff - account := &Account{ - Platform: PlatformOpenAI, - Type: AccountTypeOAuth, - Concurrency: 1, - Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, - }, - } - - t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + t.Run("dedicated mode maps to ctx_pool and routes to ws v2", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) + // dedicated is now mapped to ctx_pool for backward compatibility + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) + }) + + t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) { + ctxPoolAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(ctxPoolAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) + }) + + t.Run("passthrough mode routes to ws v2", func(t *testing.T) { + passthroughAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + require.Equal(t, "ws_v2_mode_passthrough", decision.Reason) }) t.Run("off mode routes to http", func(t *testing.T) { @@ -174,7 +202,7 @@ 
func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { require.Equal(t, "account_mode_off", decision.Reason) }) - t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) { legacyAccount := &Account{ Platform: PlatformOpenAI, Type: AccountTypeAPIKey, @@ -185,7 +213,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) - require.Equal(t, "ws_v2_mode_shared", decision.Reason) + require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason) }) t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { @@ -193,7 +221,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{ - "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool, }, } decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) diff --git a/backend/internal/service/openai_ws_recovery.go b/backend/internal/service/openai_ws_recovery.go new file mode 100644 index 000000000..c8528de0f --- /dev/null +++ b/backend/internal/service/openai_ws_recovery.go @@ -0,0 +1,760 @@ +package service + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" + + coderws "github.com/coder/websocket" +) + +// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 +type openAIWSFallbackError struct { + Reason string + Err error +} + +func (e *openAIWSFallbackError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) + } + return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) +} + +func (e 
*openAIWSFallbackError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func wrapOpenAIWSFallback(reason string, err error) error { + return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} +} + +// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 +type OpenAIWSClientCloseError struct { + statusCode coderws.StatusCode + reason string + err error +} + +type openAIWSIngressTurnError struct { + stage string + cause error + wroteDownstream bool + partialResult *OpenAIForwardResult +} + +type openAIWSIngressUpstreamLease interface { + ConnID() string + QueueWaitDuration() time.Duration + ConnPickDuration() time.Duration + Reused() bool + ScheduleLayer() string + StickinessLevel() string + MigrationUsed() bool + HandshakeHeader(name string) string + IsPrewarmed() bool + MarkPrewarmed() + WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error + ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) + PingWithTimeout(timeout time.Duration) error + MarkBroken() + Yield() + Release() +} + +func (e *openAIWSIngressTurnError) Error() string { + if e == nil { + return "" + } + if e.cause == nil { + return strings.TrimSpace(e.stage) + } + return e.cause.Error() +} + +func (e *openAIWSIngressTurnError) Unwrap() error { + if e == nil { + return nil + } + return e.cause +} + +func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { + return wrapOpenAIWSIngressTurnErrorWithPartial(stage, cause, wroteDownstream, nil) +} + +func cloneOpenAIForwardResult(result *OpenAIForwardResult) *OpenAIForwardResult { + if result == nil { + return nil + } + cloned := *result + if result.PendingFunctionCallIDs != nil { + cloned.PendingFunctionCallIDs = make([]string, len(result.PendingFunctionCallIDs)) + copy(cloned.PendingFunctionCallIDs, result.PendingFunctionCallIDs) + } + return &cloned +} + +func 
wrapOpenAIWSIngressTurnErrorWithPartial(stage string, cause error, wroteDownstream bool, partialResult *OpenAIForwardResult) error { + if cause == nil { + return nil + } + return &openAIWSIngressTurnError{ + stage: strings.TrimSpace(stage), + cause: cause, + wroteDownstream: wroteDownstream, + partialResult: cloneOpenAIForwardResult(partialResult), + } +} + +// OpenAIWSIngressTurnPartialResult returns usage-bearing partial turn result +// when WS ingress turn aborts after receiving upstream events. +func OpenAIWSIngressTurnPartialResult(err error) (*OpenAIForwardResult, bool) { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil || turnErr.partialResult == nil { + return nil, false + } + return cloneOpenAIForwardResult(turnErr.partialResult), true +} + +func isOpenAIWSIngressTurnRetryable(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { + return false + } + if turnErr.wroteDownstream { + return false + } + switch turnErr.stage { + case "write_upstream", "read_upstream": + return true + default: + return false + } +} + +func openAIWSIngressTurnRetryReason(err error) string { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return "unknown" + } + if turnErr.stage == "" { + return "unknown" + } + return turnErr.stage +} + +func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { + return false + } + return !turnErr.wroteDownstream +} + +func isOpenAIWSIngressToolOutputNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if 
strings.TrimSpace(turnErr.stage) != openAIWSIngressStageToolOutputNotFound { + return false + } + return !turnErr.wroteDownstream +} + +// openAIWSIngressTurnWroteDownstream 返回本次 turn 是否已向客户端写入过数据。 +// 用于 ContinueTurn abort 时判断是否需要补发 error 事件。 +func openAIWSIngressTurnWroteDownstream(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + return turnErr.wroteDownstream +} + +func isOpenAIWSIngressUpstreamErrorEvent(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + return strings.TrimSpace(turnErr.stage) == "upstream_error_event" +} + +func isOpenAIWSContinuationUnavailableCloseError(err error) bool { + var closeErr *OpenAIWSClientCloseError + if !errors.As(err, &closeErr) || closeErr == nil { + return false + } + if closeErr.StatusCode() != coderws.StatusPolicyViolation { + return false + } + return strings.Contains(closeErr.Reason(), openAIWSContinuationUnavailableReason) +} + +// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 +func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { + return &OpenAIWSClientCloseError{ + statusCode: statusCode, + reason: strings.TrimSpace(reason), + err: err, + } +} + +func (e *OpenAIWSClientCloseError) Error() string { + if e == nil { + return "" + } + if e.err == nil { + return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) + } + return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) +} + +func (e *OpenAIWSClientCloseError) Unwrap() error { + if e == nil { + return nil + } + return e.err +} + +func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { + if e == nil { + return coderws.StatusInternalError + } + return e.statusCode +} + +func (e *OpenAIWSClientCloseError) Reason() string { + if e == nil { + return "" + } + return 
strings.TrimSpace(e.reason) +} + +func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reasonText := strings.TrimSpace(closeErr.Reason) + if reasonText != "" { + closeReason = normalizeOpenAIWSLogValue(reasonText) + } + } + return normalizeOpenAIWSLogValue(closeStatus), closeReason +} + +func unwrapOpenAIWSDialBaseError(err error) error { + if err == nil { + return nil + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { + return dialErr.Err + } + return err +} + +func openAIWSDialRespHeaderForLog(err error, key string) string { + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { + return "-" + } + return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) +} + +func classifyOpenAIWSDialError(err error) string { + if err == nil { + return "-" + } + baseErr := unwrapOpenAIWSDialBaseError(err) + if baseErr == nil { + return "-" + } + if errors.Is(baseErr, context.DeadlineExceeded) { + return "ctx_deadline_exceeded" + } + if errors.Is(baseErr, context.Canceled) { + return "ctx_canceled" + } + var netErr net.Error + if errors.As(baseErr, &netErr) && netErr.Timeout() { + return "net_timeout" + } + if status := coderws.CloseStatus(baseErr); status != -1 { + return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) + } + message := strings.ToLower(strings.TrimSpace(baseErr.Error())) + switch { + case strings.Contains(message, "handshake not finished"): + return "handshake_not_finished" + case strings.Contains(message, "bad handshake"): + return "bad_handshake" + case 
strings.Contains(message, "connection refused"): + return "connection_refused" + case strings.Contains(message, "no such host"): + return "dns_not_found" + case strings.Contains(message, "tls"): + return "tls_error" + case strings.Contains(message, "i/o timeout"): + return "io_timeout" + case strings.Contains(message, "context deadline exceeded"): + return "ctx_deadline_exceeded" + default: + return "dial_error" + } +} + +func summarizeOpenAIWSDialError(err error) ( + statusCode int, + dialClass string, + closeStatus string, + closeReason string, + respServer string, + respVia string, + respCFRay string, + respRequestID string, +) { + dialClass = "-" + closeStatus = "-" + closeReason = "-" + respServer = "-" + respVia = "-" + respCFRay = "-" + respRequestID = "-" + if err == nil { + return + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil { + statusCode = dialErr.StatusCode + respServer = openAIWSDialRespHeaderForLog(err, "server") + respVia = openAIWSDialRespHeaderForLog(err, "via") + respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") + respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") + } + dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) + closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) + return +} + +func isOpenAIWSClientDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed 
network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func classifyOpenAIWSIngressReadErrorClass(err error) string { + if err == nil { + return "unknown" + } + if errors.Is(err, context.Canceled) { + return "context_canceled" + } + if errors.Is(err, context.DeadlineExceeded) { + return "deadline_exceeded" + } + switch coderws.CloseStatus(err) { + case coderws.StatusServiceRestart: + return "service_restart" + case coderws.StatusTryAgainLater: + return "try_again_later" + } + if isOpenAIWSClientDisconnectError(err) { + return "upstream_closed" + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return "eof" + } + return "unknown" +} + +func isOpenAIWSStreamWriteDisconnectError(err error, reqCtx context.Context) bool { + if err == nil { + return false + } + if reqCtx != nil && reqCtx.Err() != nil { + return true + } + return isOpenAIWSClientDisconnectError(err) +} + +func openAIWSIngressResolveDrainReadTimeout( + baseTimeout time.Duration, + disconnectDeadline time.Time, + now time.Time, +) (time.Duration, bool) { + if disconnectDeadline.IsZero() { + return baseTimeout, false + } + remaining := disconnectDeadline.Sub(now) + if remaining <= 0 { + return 0, true + } + if baseTimeout <= 0 || remaining < baseTimeout { + return remaining, false + } + return baseTimeout, false +} + +func openAIWSIngressClientDisconnectedDrainTimeoutError(timeout time.Duration) error { + if timeout <= 0 { + timeout = openAIWSIngressClientDisconnectDrainTimeout + } + return fmt.Errorf("client disconnected before upstream terminal event (drain timeout=%s): %w", timeout, context.Canceled) +} + +func openAIWSIngressPumpClosedTurnError( + clientDisconnected bool, + wroteDownstream bool, + partialResult *OpenAIForwardResult, +) error { + if clientDisconnected { + return wrapOpenAIWSIngressTurnErrorWithPartial( + "client_disconnected_drain_timeout", + 
openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout), + wroteDownstream, + partialResult, + ) + } + return wrapOpenAIWSIngressTurnErrorWithPartial( + "read_upstream", + errors.New("upstream event pump closed unexpectedly"), + wroteDownstream, + partialResult, + ) +} + +func shouldFlushOpenAIWSBufferedEventsOnError(reqStream bool, wroteDownstream bool, clientDisconnected bool) bool { + return reqStream && wroteDownstream && !clientDisconnected +} + +// errOpenAIWSClientPreempted 表示客户端在当前 turn 尚未完成时发送了新的 response.create 请求。 +var errOpenAIWSClientPreempted = errors.New("client preempted current turn with new request") + +var errOpenAIWSAdvanceClientReadUnavailable = errors.New("client reader channels unavailable") + +func openAIWSAdvanceConsumePendingClientReadErr(pendingErr *error) error { + if pendingErr == nil || *pendingErr == nil { + return nil + } + readErr := *pendingErr + *pendingErr = nil + return fmt.Errorf("read client websocket request: %w", readErr) +} + +func openAIWSAdvanceClientReadUnavailable(clientMsgCh <-chan []byte, clientReadErrCh <-chan error) bool { + return clientMsgCh == nil && clientReadErrCh == nil +} + +// isOpenAIWSUpstreamRestartCloseError 检测上游是否因服务重启/维护关闭了连接。 +// 1012=ServiceRestart, 1013=TryAgainLater,都是临时性上游维护,proxy 应视为可恢复错误。 +func isOpenAIWSUpstreamRestartCloseError(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if turnErr.stage != "read_upstream" { + return false + } + status := coderws.CloseStatus(turnErr.cause) + return status == 1012 || status == 1013 // ServiceRestart, TryAgainLater +} + +func classifyOpenAIWSIngressTurnAbortReason(err error) (openAIWSIngressTurnAbortReason, bool) { + if err == nil { + return openAIWSIngressTurnAbortReasonUnknown, false + } + if isOpenAIWSIngressPreviousResponseNotFound(err) { + return openAIWSIngressTurnAbortReasonPreviousResponse, true + } + if 
isOpenAIWSIngressToolOutputNotFound(err) { + return openAIWSIngressTurnAbortReasonToolOutput, true + } + if isOpenAIWSIngressUpstreamErrorEvent(err) { + return openAIWSIngressTurnAbortReasonUpstreamError, true + } + if isOpenAIWSContinuationUnavailableCloseError(err) { + return openAIWSIngressTurnAbortReasonContinuationUnavailable, true + } + if errors.Is(err, errOpenAIWSClientPreempted) { + return openAIWSIngressTurnAbortReasonClientPreempted, true + } + if errors.Is(err, context.Canceled) { + return openAIWSIngressTurnAbortReasonContextCanceled, true + } + if errors.Is(err, context.DeadlineExceeded) { + return openAIWSIngressTurnAbortReasonContextDeadline, false + } + if isOpenAIWSClientDisconnectError(err) { + return openAIWSIngressTurnAbortReasonClientClosed, true + } + // 上游 ServiceRestart/TryAgainLater:必须在 stage-based 分类之前检测, + // 否则会被 "read_upstream" 分支兜底为 FailRequest。 + if isOpenAIWSUpstreamRestartCloseError(err) { + return openAIWSIngressTurnAbortReasonUpstreamRestart, true + } + + var turnErr *openAIWSIngressTurnError + if errors.As(err, &turnErr) && turnErr != nil { + switch strings.TrimSpace(turnErr.stage) { + case "idle_timeout": + return openAIWSIngressTurnAbortReasonContextDeadline, false + case "write_upstream": + return openAIWSIngressTurnAbortReasonWriteUpstream, false + case "read_upstream": + return openAIWSIngressTurnAbortReasonReadUpstream, false + case "write_client": + return openAIWSIngressTurnAbortReasonWriteClient, false + } + } + return openAIWSIngressTurnAbortReasonUnknown, false +} + +func openAIWSIngressTurnAbortDispositionForReason(reason openAIWSIngressTurnAbortReason) openAIWSIngressTurnAbortDisposition { + switch reason { + case openAIWSIngressTurnAbortReasonPreviousResponse, + openAIWSIngressTurnAbortReasonToolOutput, + openAIWSIngressTurnAbortReasonUpstreamError, + openAIWSIngressTurnAbortReasonClientPreempted, + openAIWSIngressTurnAbortReasonUpstreamRestart: + return openAIWSIngressTurnAbortDispositionContinueTurn + case 
openAIWSIngressTurnAbortReasonContextCanceled, + openAIWSIngressTurnAbortReasonClientClosed: + return openAIWSIngressTurnAbortDispositionCloseGracefully + default: + return openAIWSIngressTurnAbortDispositionFailRequest + } +} + +func classifyOpenAIWSReadFallbackReason(err error) string { + if err == nil { + return "read_event" + } + switch coderws.CloseStatus(err) { + case coderws.StatusServiceRestart: + return "service_restart" + case coderws.StatusTryAgainLater: + return "try_again_later" + case coderws.StatusPolicyViolation: + return "policy_violation" + case coderws.StatusMessageTooBig: + return "message_too_big" + default: + return "read_event" + } +} + +func classifyOpenAIWSAcquireError(err error) string { + if err == nil { + return "acquire_conn" + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) { + switch dialErr.StatusCode { + case 426: + return "upgrade_required" + case 401, 403: + return "auth_failed" + case 429: + return "upstream_rate_limited" + } + if dialErr.StatusCode >= 500 { + return "upstream_5xx" + } + return "dial_failed" + } + if errors.Is(err, errOpenAIWSConnQueueFull) { + return "conn_queue_full" + } + if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { + return "preferred_conn_unavailable" + } + if errors.Is(err, context.DeadlineExceeded) { + return "acquire_timeout" + } + return "acquire_conn" +} + +func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + switch code { + case "upgrade_required": + return "upgrade_required", true + case "websocket_not_supported", "websocket_unsupported": + return "ws_unsupported", true + case "websocket_connection_limit_reached": + return "ws_connection_limit_reached", true + case "previous_response_not_found": + return "previous_response_not_found", true + } + if 
strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { + return "upgrade_required", true + } + if strings.Contains(errType, "upgrade") { + return "upgrade_required", true + } + if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") { + return "ws_unsupported", true + } + if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { + return "ws_connection_limit_reached", true + } + if strings.Contains(msg, "previous_response_not_found") || + (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { + return "previous_response_not_found", true + } + // "No tool output found for function call " / "No tool call found for function call output..." + // 表示 previous_response_id 指向的 response 包含未完成的 function_call(例如用户在 Codex CLI + // 按 ESC 取消 function_call 后重新发送消息)。此时 previous_response_id 本身就是问题,需要移除后重放。 + if strings.Contains(msg, "no tool output found") || + strings.Contains(msg, "no tool call found for function call output") || + (strings.Contains(msg, "no tool call found") && strings.Contains(msg, "function call output")) { + return openAIWSIngressStageToolOutputNotFound, true + } + if strings.Contains(msg, "without its required following item") || + strings.Contains(msg, "without its required preceding item") { + return openAIWSIngressStageToolOutputNotFound, true + } + if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { + return "upstream_error_event", true + } + return "event_error", false +} + +func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { + if len(message) == 0 { + return "event_error", false + } + return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + switch { + case strings.Contains(errType, 
"invalid_request"), + strings.Contains(code, "invalid_request"), + strings.Contains(code, "bad_request"), + code == "previous_response_not_found": + return http.StatusBadRequest + case strings.Contains(errType, "authentication"), + strings.Contains(code, "invalid_api_key"), + strings.Contains(code, "unauthorized"): + return http.StatusUnauthorized + case strings.Contains(errType, "permission"), + strings.Contains(code, "forbidden"): + return http.StatusForbidden + case strings.Contains(errType, "rate_limit"), + strings.Contains(code, "rate_limit"), + strings.Contains(code, "insufficient_quota"): + return http.StatusTooManyRequests + default: + return http.StatusBadGateway + } +} + +func openAIWSErrorHTTPStatus(message []byte) int { + if len(message) == 0 { + return http.StatusBadGateway + } + codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) + return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) +} + +func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Second + } + seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { + if s == nil || accountID <= 0 { + return false + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return false + } + rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) + if !ok || rawUntil == nil { + return false + } + until, ok := rawUntil.(time.Time) + if !ok || until.IsZero() { + s.openaiWSFallbackUntil.Delete(accountID) + return false + } + if time.Now().Before(until) { + return true + } + s.openaiWSFallbackUntil.Delete(accountID) + return false +} + +func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { + if s == nil || accountID <= 0 { + return + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return + } + 
s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) +} + +func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiWSFallbackUntil.Delete(accountID) +} diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go index b606baa1a..cff5ab6cd 100644 --- a/backend/internal/service/openai_ws_state_store.go +++ b/backend/internal/service/openai_ws_state_store.go @@ -4,11 +4,13 @@ import ( "context" "crypto/sha256" "encoding/hex" - "fmt" + "strconv" "strings" "sync" "sync/atomic" "time" + + "github.com/cespare/xxhash/v2" ) const ( @@ -17,6 +19,7 @@ const ( openAIWSStateStoreCleanupMaxPerMap = 512 openAIWSStateStoreMaxEntriesPerMap = 65536 openAIWSStateStoreRedisTimeout = 3 * time.Second + openAIWSStateStoreHotCacheTTL = time.Minute ) type openAIWSAccountBinding struct { @@ -29,6 +32,11 @@ type openAIWSConnBinding struct { expiresAt time.Time } +type openAIWSResponsePendingToolCallsBinding struct { + callIDs []string + expiresAt time.Time +} + type openAIWSTurnStateBinding struct { turnState string expiresAt time.Time @@ -39,6 +47,23 @@ type openAIWSSessionConnBinding struct { expiresAt time.Time } +type openAIWSSessionLastResponseBinding struct { + responseID string + expiresAt time.Time +} + +type openAIWSStateStoreSessionLastResponseCache interface { + SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error + GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error) + DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error +} + +type openAIWSStateStoreResponsePendingToolCallsCache interface { + SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error + GetOpenAIWSResponsePendingToolCalls(ctx 
context.Context, groupID int64, responseID string) ([]string, error) + DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error +} + // OpenAIWSStateStore 管理 WSv2 的粘连状态。 // - response_id -> account_id 用于续链路由 // - response_id -> conn_id 用于连接内上下文复用 @@ -53,44 +78,111 @@ type OpenAIWSStateStore interface { BindResponseConn(responseID, connID string, ttl time.Duration) GetResponseConn(responseID string) (string, bool) DeleteResponseConn(responseID string) + BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration) + GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool) + DeleteResponsePendingToolCalls(groupID int64, responseID string) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) DeleteSessionTurnState(groupID int64, sessionHash string) + BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration) + GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool) + DeleteSessionLastResponseID(groupID int64, sessionHash string) + BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) GetSessionConn(groupID int64, sessionHash string) (string, bool) DeleteSessionConn(groupID int64, sessionHash string) } +const openAIWSStateStoreConnShards = 16 + +type openAIWSConnBindingShard struct { + mu sync.RWMutex + m map[string]openAIWSConnBinding +} + type defaultOpenAIWSStateStore struct { cache GatewayCache - responseToAccountMu sync.RWMutex - responseToAccount map[string]openAIWSAccountBinding - responseToConnMu sync.RWMutex - responseToConn map[string]openAIWSConnBinding - sessionToTurnStateMu sync.RWMutex - sessionToTurnState map[string]openAIWSTurnStateBinding - sessionToConnMu sync.RWMutex - sessionToConn map[string]openAIWSSessionConnBinding + responseToAccountMu sync.RWMutex + responseToAccount 
map[string]openAIWSAccountBinding + responseToConnShards [openAIWSStateStoreConnShards]openAIWSConnBindingShard + responsePendingToolMu sync.RWMutex + responsePendingTool map[string]openAIWSResponsePendingToolCallsBinding + sessionToTurnStateMu sync.RWMutex + sessionToTurnState map[string]openAIWSTurnStateBinding + sessionToLastRespMu sync.RWMutex + sessionToLastResp map[string]openAIWSSessionLastResponseBinding + sessionToConnMu sync.RWMutex + sessionToConn map[string]openAIWSSessionConnBinding lastCleanupUnixNano atomic.Int64 + stopCh chan struct{} + stopOnce sync.Once + workerWg sync.WaitGroup +} + +func (s *defaultOpenAIWSStateStore) connShard(key string) *openAIWSConnBindingShard { + h := xxhash.Sum64String(key) + return &s.responseToConnShards[h%openAIWSStateStoreConnShards] } // NewOpenAIWSStateStore 创建默认 WS 状态存储。 func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore { + return newOpenAIWSStateStore(cache, openAIWSStateStoreCleanupInterval) +} + +func newOpenAIWSStateStore(cache GatewayCache, cleanupInterval time.Duration) *defaultOpenAIWSStateStore { store := &defaultOpenAIWSStateStore{ - cache: cache, - responseToAccount: make(map[string]openAIWSAccountBinding, 256), - responseToConn: make(map[string]openAIWSConnBinding, 256), - sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256), - sessionToConn: make(map[string]openAIWSSessionConnBinding, 256), + cache: cache, + responseToAccount: make(map[string]openAIWSAccountBinding, 256), + responsePendingTool: make(map[string]openAIWSResponsePendingToolCallsBinding, 256), + sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256), + sessionToLastResp: make(map[string]openAIWSSessionLastResponseBinding, 256), + sessionToConn: make(map[string]openAIWSSessionConnBinding, 256), + stopCh: make(chan struct{}), + } + for i := range store.responseToConnShards { + store.responseToConnShards[i].m = make(map[string]openAIWSConnBinding, 16) } 
store.lastCleanupUnixNano.Store(time.Now().UnixNano()) + store.startCleanupWorker(cleanupInterval) return store } +func (s *defaultOpenAIWSStateStore) startCleanupWorker(interval time.Duration) { + if s == nil || interval <= 0 { + return + } + s.workerWg.Add(1) + go func() { + defer s.workerWg.Done() + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-s.stopCh: + return + case <-ticker.C: + s.maybeCleanup() + } + } + }() +} + +func (s *defaultOpenAIWSStateStore) Close() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.stopCh != nil { + close(s.stopCh) + } + }) + s.workerWg.Wait() +} + func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error { id := normalizeOpenAIWSResponseID(responseID) if id == "" || accountID <= 0 { @@ -99,19 +191,31 @@ func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, gro ttl = normalizeOpenAIWSTTL(ttl) s.maybeCleanup() - expiresAt := time.Now().Add(ttl) + var redisErr error + if s.cache != nil { + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + redisErr = s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl) + cancel() + if redisErr != nil { + logOpenAIWSModeInfo( + "state_store_bind_response_account_redis_fail group_id=%d response_id=%s account_id=%d cause=%s", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), accountID, truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen), + ) + } + } + + // 无论 Redis 是否写成功,都写入本地缓存作为降级保障。 + localTTL := openAIWSStateStoreLocalHotTTL(ttl) s.responseToAccountMu.Lock() ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap) - s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt} + s.responseToAccount[id] = openAIWSAccountBinding{ + accountID: accountID, + expiresAt: 
time.Now().Add(localTTL), + } s.responseToAccountMu.Unlock() - if s.cache == nil { - return nil - } - cacheKey := openAIWSResponseAccountCacheKey(id) - cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) - defer cancel() - return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl) + return redisErr } func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) { @@ -119,7 +223,6 @@ func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, grou if id == "" { return 0, nil } - s.maybeCleanup() now := time.Now() s.responseToAccountMu.RLock() @@ -138,13 +241,33 @@ func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, grou cacheKey := openAIWSResponseAccountCacheKey(id) cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) - defer cancel() accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey) - if err != nil || accountID <= 0 { + cancel() + if err == nil && accountID > 0 { + return accountID, nil + } + + // Compatibility fallback for pre-v2 cache keys. + legacyCacheKey := openAIWSResponseAccountLegacyCacheKey(id) + legacyCtx, legacyCancel := withOpenAIWSStateStoreRedisTimeout(ctx) + legacyAccountID, legacyErr := s.cache.GetSessionAccountID(legacyCtx, groupID, legacyCacheKey) + legacyCancel() + if legacyErr != nil || legacyAccountID <= 0 { // 缓存读取失败不阻断主流程,按未命中降级。 return 0, nil } - return accountID, nil + + logOpenAIWSModeInfo( + "state_store_get_response_account_legacy_fallback group_id=%d response_id=%s account_id=%d", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), legacyAccountID, + ) + + // Best effort: backfill v2 key so subsequent reads avoid legacy fallback. 
+ backfillCtx, backfillCancel := withOpenAIWSStateStoreRedisTimeout(ctx) + _ = s.cache.SetSessionAccountID(backfillCtx, groupID, cacheKey, legacyAccountID, openAIWSStateStoreHotCacheTTL) + backfillCancel() + + return legacyAccountID, nil } func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error { @@ -161,7 +284,15 @@ func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, g } cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) defer cancel() - return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id)) + primaryKey := openAIWSResponseAccountCacheKey(id) + if err := s.cache.DeleteSessionAccountID(cacheCtx, groupID, primaryKey); err != nil { + return err + } + legacyKey := openAIWSResponseAccountLegacyCacheKey(id) + if legacyKey == "" || legacyKey == primaryKey { + return nil + } + return s.cache.DeleteSessionAccountID(cacheCtx, groupID, legacyKey) } func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) { @@ -173,13 +304,14 @@ func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl = normalizeOpenAIWSTTL(ttl) s.maybeCleanup() - s.responseToConnMu.Lock() - ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap) - s.responseToConn[id] = openAIWSConnBinding{ + shard := s.connShard(id) + shard.mu.Lock() + ensureBindingCapacity(shard.m, id, openAIWSStateStoreMaxEntriesPerMap/openAIWSStateStoreConnShards) + shard.m[id] = openAIWSConnBinding{ connID: conn, expiresAt: time.Now().Add(ttl), } - s.responseToConnMu.Unlock() + shard.mu.Unlock() } func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) { @@ -187,13 +319,13 @@ func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, if id == "" { return "", false } - s.maybeCleanup() now := time.Now() - s.responseToConnMu.RLock() - binding, ok := 
s.responseToConn[id] - s.responseToConnMu.RUnlock() - if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + shard := s.connShard(id) + shard.mu.RLock() + binding, ok := shard.m[id] + shard.mu.RUnlock() + if !ok || now.After(binding.expiresAt) || binding.connID == "" { return "", false } return binding.connID, true @@ -204,9 +336,115 @@ func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) { if id == "" { return } - s.responseToConnMu.Lock() - delete(s.responseToConn, id) - s.responseToConnMu.Unlock() + shard := s.connShard(id) + shard.mu.Lock() + delete(shard.m, id) + shard.mu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration) { + id := normalizeOpenAIWSResponseID(responseID) + normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs) + if id == "" || len(normalizedCallIDs) == 0 { + return + } + key := openAIWSResponsePendingToolCallsBindingKey(groupID, id) + if key == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.responsePendingToolMu.Lock() + ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap) + s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{ + callIDs: append([]string(nil), normalizedCallIDs...), + expiresAt: time.Now().Add(ttl), + } + s.responsePendingToolMu.Unlock() + + if cache := s.responsePendingToolCallsCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + if redisErr := cache.SetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id, normalizedCallIDs, ttl); redisErr != nil { + logOpenAIWSModeInfo( + "state_store_bind_response_pending_tool_calls_redis_fail group_id=%d response_id=%s call_count=%d cause=%s", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs), truncateOpenAIWSLogValue(redisErr.Error(), 
openAIWSLogValueMaxLen), + ) + } + } +} + +func (s *defaultOpenAIWSStateStore) GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return nil, false + } + key := openAIWSResponsePendingToolCallsBindingKey(groupID, id) + if key == "" { + return nil, false + } + + now := time.Now() + s.responsePendingToolMu.RLock() + binding, ok := s.responsePendingTool[key] + s.responsePendingToolMu.RUnlock() + if !ok || now.After(binding.expiresAt) || len(binding.callIDs) == 0 { + cache := s.responsePendingToolCallsCache() + if cache == nil { + return nil, false + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id) + normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs) + if err != nil || len(normalizedCallIDs) == 0 { + if err != nil { + logOpenAIWSModeInfo( + "state_store_get_response_pending_tool_calls_redis_fail group_id=%d response_id=%s cause=%s", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + } + return nil, false + } + + logOpenAIWSModeInfo( + "state_store_get_response_pending_tool_calls_redis_hit group_id=%d response_id=%s call_count=%d", + groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs), + ) + + // Redis 命中后回填本地热缓存,降低后续访问开销。 + s.responsePendingToolMu.Lock() + ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap) + s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{ + callIDs: append([]string(nil), normalizedCallIDs...), + expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL), + } + s.responsePendingToolMu.Unlock() + return normalizedCallIDs, true + } + // binding.callIDs was already copied at bind time; return directly (callers are read-only). 
+ return binding.callIDs, true +} + +func (s *defaultOpenAIWSStateStore) DeleteResponsePendingToolCalls(groupID int64, responseID string) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return + } + key := openAIWSResponsePendingToolCallsBindingKey(groupID, id) + if key == "" { + return + } + s.responsePendingToolMu.Lock() + delete(s.responsePendingTool, key) + s.responsePendingToolMu.Unlock() + + if cache := s.responsePendingToolCallsCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + _ = cache.DeleteOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id) + } } func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) { @@ -232,7 +470,6 @@ func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHa if key == "" { return "", false } - s.maybeCleanup() now := time.Now() s.sessionToTurnStateMu.RLock() @@ -254,6 +491,99 @@ func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessio s.sessionToTurnStateMu.Unlock() } +func (s *defaultOpenAIWSStateStore) BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + id := normalizeOpenAIWSResponseID(responseID) + if key == "" || id == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToLastRespMu.Lock() + ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{ + responseID: id, + expiresAt: time.Now().Add(ttl), + } + s.sessionToLastRespMu.Unlock() + + if cache := s.sessionLastResponseCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + if redisErr := cache.SetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash), id, 
ttl); redisErr != nil { + logOpenAIWSModeInfo( + "state_store_bind_session_last_response_redis_fail group_id=%d session_hash=%s response_id=%s cause=%s", + groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen), + ) + } + } +} + +func (s *defaultOpenAIWSStateStore) GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + + now := time.Now() + s.sessionToLastRespMu.RLock() + binding, ok := s.sessionToLastResp[key] + s.sessionToLastRespMu.RUnlock() + if ok && now.Before(binding.expiresAt) && strings.TrimSpace(binding.responseID) != "" { + return binding.responseID, true + } + + cache := s.sessionLastResponseCache() + if cache == nil { + return "", false + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + responseID, err := cache.GetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash)) + responseID = normalizeOpenAIWSResponseID(responseID) + if err != nil || responseID == "" { + if err != nil { + logOpenAIWSModeInfo( + "state_store_get_session_last_response_redis_fail group_id=%d session_hash=%s cause=%s", + groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + } + return "", false + } + + logOpenAIWSModeInfo( + "state_store_get_session_last_response_redis_hit group_id=%d session_hash=%s response_id=%s", + groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + ) + + // Redis 命中后回填本地热缓存,降低后续访问开销。 + s.sessionToLastRespMu.Lock() + ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{ + 
responseID: responseID, + expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL), + } + s.sessionToLastRespMu.Unlock() + return responseID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionLastResponseID(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToLastRespMu.Lock() + delete(s.sessionToLastResp, key) + s.sessionToLastRespMu.Unlock() + + if cache := s.sessionLastResponseCache(); cache != nil { + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + _ = cache.DeleteOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash)) + } +} + func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) { key := openAIWSSessionTurnStateKey(groupID, sessionHash) conn := strings.TrimSpace(connID) @@ -277,7 +607,6 @@ func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash st if key == "" { return "", false } - s.maybeCleanup() now := time.Now() s.sessionToConnMu.RLock() @@ -317,14 +646,29 @@ func (s *defaultOpenAIWSStateStore) maybeCleanup() { cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap) s.responseToAccountMu.Unlock() - s.responseToConnMu.Lock() - cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap) - s.responseToConnMu.Unlock() + perShardLimit := openAIWSStateStoreCleanupMaxPerMap / openAIWSStateStoreConnShards + if perShardLimit < 32 { + perShardLimit = 32 + } + for i := range s.responseToConnShards { + shard := &s.responseToConnShards[i] + shard.mu.Lock() + cleanupExpiredConnBindings(shard.m, now, perShardLimit) + shard.mu.Unlock() + } + + s.responsePendingToolMu.Lock() + cleanupExpiredResponsePendingToolCallsBindings(s.responsePendingTool, now, openAIWSStateStoreCleanupMaxPerMap) + s.responsePendingToolMu.Unlock() s.sessionToTurnStateMu.Lock() 
cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap) s.sessionToTurnStateMu.Unlock() + s.sessionToLastRespMu.Lock() + cleanupExpiredSessionLastResponseBindings(s.sessionToLastResp, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToLastRespMu.Unlock() + s.sessionToConnMu.Lock() cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap) s.sessionToConnMu.Unlock() @@ -362,6 +706,22 @@ func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now tim } } +func cleanupExpiredResponsePendingToolCallsBindings(bindings map[string]openAIWSResponsePendingToolCallsBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) { if len(bindings) == 0 || maxScan <= 0 { return @@ -378,6 +738,22 @@ func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBindin } } +func cleanupExpiredSessionLastResponseBindings(bindings map[string]openAIWSSessionLastResponseBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) { if len(bindings) == 0 || maxScan <= 0 { return @@ -394,17 +770,56 @@ func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBi } } -func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) { +type expiringBinding interface { + getExpiresAt() time.Time +} + +func (b 
openAIWSAccountBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSConnBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSResponsePendingToolCallsBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSTurnStateBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSSessionConnBinding) getExpiresAt() time.Time { return b.expiresAt } +func (b openAIWSSessionLastResponseBinding) getExpiresAt() time.Time { return b.expiresAt } + +func ensureBindingCapacity[T expiringBinding](bindings map[string]T, incomingKey string, maxEntries int) { if len(bindings) < maxEntries || maxEntries <= 0 { return } if _, exists := bindings[incomingKey]; exists { return } - // 固定上限保护:淘汰任意一项,优先保证内存有界。 - for key := range bindings { - delete(bindings, key) - return + // 优先驱逐已过期条目;若不存在过期项,则按 expiresAt 最早驱逐,避免随机删除活跃绑定。 + now := time.Now() + for key, val := range bindings { + if !val.getExpiresAt().IsZero() && now.After(val.getExpiresAt()) { + delete(bindings, key) + return + } + } + var ( + evictKey string + evictExpireAt time.Time + hasCandidate bool + ) + for key, val := range bindings { + expiresAt := val.getExpiresAt() + if !hasCandidate { + evictKey = key + evictExpireAt = expiresAt + hasCandidate = true + continue + } + switch { + case expiresAt.IsZero() && !evictExpireAt.IsZero(): + evictKey = key + evictExpireAt = expiresAt + case !expiresAt.IsZero() && !evictExpireAt.IsZero() && expiresAt.Before(evictExpireAt): + evictKey = key + evictExpireAt = expiresAt + } + } + if hasCandidate { + delete(bindings, evictKey) } } @@ -412,7 +827,46 @@ func normalizeOpenAIWSResponseID(responseID string) string { return strings.TrimSpace(responseID) } +func normalizeOpenAIWSPendingToolCallIDs(callIDs []string) []string { + if len(callIDs) == 0 { + return nil + } + seen := make(map[string]struct{}, len(callIDs)) + normalized := make([]string, 0, len(callIDs)) + for _, callID := range callIDs { + id := 
strings.TrimSpace(callID) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + normalized = append(normalized, id) + } + return normalized +} + +func openAIWSResponsePendingToolCallsBindingKey(groupID int64, responseID string) string { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return "" + } + return strconv.FormatInt(groupID, 10) + ":" + id +} + func openAIWSResponseAccountCacheKey(responseID string) string { + h := xxhash.Sum64String(responseID) + // Pad to 16 hex chars for consistent key length. + hex := strconv.FormatUint(h, 16) + const pad = "0000000000000000" + if len(hex) < 16 { + hex = pad[:16-len(hex)] + hex + } + return openAIWSResponseAccountCachePrefix + "v2:" + hex +} + +func openAIWSResponseAccountLegacyCacheKey(responseID string) string { sum := sha256.Sum256([]byte(responseID)) return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:]) } @@ -424,12 +878,42 @@ func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration { return ttl } +func openAIWSStateStoreLocalHotTTL(ttl time.Duration) time.Duration { + ttl = normalizeOpenAIWSTTL(ttl) + if ttl > openAIWSStateStoreHotCacheTTL { + return openAIWSStateStoreHotCacheTTL + } + return ttl +} + +func (s *defaultOpenAIWSStateStore) sessionLastResponseCache() openAIWSStateStoreSessionLastResponseCache { + if s == nil || s.cache == nil { + return nil + } + cache, ok := s.cache.(openAIWSStateStoreSessionLastResponseCache) + if !ok { + return nil + } + return cache +} + +func (s *defaultOpenAIWSStateStore) responsePendingToolCallsCache() openAIWSStateStoreResponsePendingToolCallsCache { + if s == nil || s.cache == nil { + return nil + } + cache, ok := s.cache.(openAIWSStateStoreResponsePendingToolCallsCache) + if !ok { + return nil + } + return cache +} + func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string { hash := strings.TrimSpace(sessionHash) if hash == "" { return "" } - return fmt.Sprintf("%d:%s", 
groupID, hash) + return strconv.FormatInt(groupID, 10) + ":" + hash } func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) { diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go index 235d42331..f11a05ec9 100644 --- a/backend/internal/service/openai_ws_state_store_test.go +++ b/backend/internal/service/openai_ws_state_store_test.go @@ -41,6 +41,27 @@ func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) { require.False(t, ok) } +func TestOpenAIWSStateStore_ResponsePendingToolCallsTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + groupID := int64(9) + store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_1", []string{"call_1", "call_2", "call_1", " "}, 30*time.Millisecond) + + callIDs, ok := store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1") + require.True(t, ok) + require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs) + _, ok = store.GetResponsePendingToolCalls(groupID+1, "resp_pending_tool_1") + require.False(t, ok, "pending tool calls should be group-isolated") + + store.DeleteResponsePendingToolCalls(groupID, "resp_pending_tool_1") + _, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1") + require.False(t, ok) + + store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_2", []string{"call_3"}, 30*time.Millisecond) + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_2") + require.False(t, ok) +} + func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) { store := NewOpenAIWSStateStore(nil) store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond) @@ -75,6 +96,178 @@ func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) { require.False(t, ok) } +func TestOpenAIWSStateStore_SessionLastResponseIDTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionLastResponseID(9, 
"session_hash_resp_1", "resp_1", 30*time.Millisecond) + + responseID, ok := store.GetSessionLastResponseID(9, "session_hash_resp_1") + require.True(t, ok) + require.Equal(t, "resp_1", responseID) + + // group 隔离 + _, ok = store.GetSessionLastResponseID(10, "session_hash_resp_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionLastResponseID(9, "session_hash_resp_1") + require.False(t, ok) +} + +type openAIWSSessionLastResponseProbeCache struct { + sessionData map[string]string + setCalled bool + getCalled bool + delCalled bool +} + +func (c *openAIWSSessionLastResponseProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) { + return 0, nil +} + +func (c *openAIWSSessionLastResponseProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error { + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) DeleteSessionAccountID(context.Context, int64, string) error { + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) SetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash, responseID string, _ time.Duration) error { + if c.sessionData == nil { + c.sessionData = make(map[string]string) + } + c.setCalled = true + c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)] = responseID + return nil +} + +func (c *openAIWSSessionLastResponseProbeCache) GetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) (string, error) { + c.getCalled = true + return c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)], nil +} + +func (c *openAIWSSessionLastResponseProbeCache) DeleteOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) error { + c.delCalled = true + delete(c.sessionData, fmt.Sprintf("%d:%s", groupID, sessionHash)) + return nil +} + 
+func TestOpenAIWSStateStore_SessionLastResponseID_UsesOptionalCacheFallback(t *testing.T) { + probe := &openAIWSSessionLastResponseProbeCache{sessionData: make(map[string]string)} + raw := NewOpenAIWSStateStore(probe) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + groupID := int64(9) + sessionHash := "session_hash_resp_cache_1" + responseID := "resp_cache_1" + store.BindSessionLastResponseID(groupID, sessionHash, responseID, time.Minute) + require.True(t, probe.setCalled, "绑定 session last_response_id 时应写入可选缓存") + + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + store.sessionToLastRespMu.Lock() + delete(store.sessionToLastResp, key) + store.sessionToLastRespMu.Unlock() + + gotResponseID, found := store.GetSessionLastResponseID(groupID, sessionHash) + require.True(t, found, "本地缓存缺失时应降级读取可选缓存") + require.Equal(t, responseID, gotResponseID) + require.True(t, probe.getCalled) + + store.DeleteSessionLastResponseID(groupID, sessionHash) + require.True(t, probe.delCalled, "删除 session last_response_id 时应同步删除可选缓存") + _, found = store.GetSessionLastResponseID(groupID, sessionHash) + require.False(t, found) +} + +type openAIWSResponsePendingToolCallsProbeCache struct { + pendingData map[string][]string + setCalls int + getCalls int + delCalls int +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) { + return 0, nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error { + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteSessionAccountID(context.Context, int64, string) error { + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) SetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID 
int64, responseID string, callIDs []string, _ time.Duration) error { + if c.pendingData == nil { + c.pendingData = make(map[string][]string) + } + key := fmt.Sprintf("%d:%s", groupID, responseID) + normalized := normalizeOpenAIWSPendingToolCallIDs(callIDs) + if len(normalized) == 0 { + delete(c.pendingData, key) + } else { + c.pendingData[key] = append([]string(nil), normalized...) + } + c.setCalls++ + return nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) GetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) ([]string, error) { + c.getCalls++ + callIDs := c.pendingData[fmt.Sprintf("%d:%s", groupID, responseID)] + return append([]string(nil), callIDs...), nil +} + +func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) error { + c.delCalls++ + delete(c.pendingData, fmt.Sprintf("%d:%s", groupID, responseID)) + return nil +} + +func TestOpenAIWSStateStore_ResponsePendingToolCalls_UsesOptionalCacheFallback(t *testing.T) { + probe := &openAIWSResponsePendingToolCallsProbeCache{pendingData: make(map[string][]string)} + raw := NewOpenAIWSStateStore(probe) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + groupID := int64(11) + responseID := "resp_pending_tool_cache_1" + store.BindResponsePendingToolCalls(groupID, responseID, []string{"call_1", "call_2", "call_1"}, time.Minute) + require.Equal(t, 1, probe.setCalls, "绑定 pending_tool_calls 时应写入可选缓存") + + store.responsePendingToolMu.Lock() + delete(store.responsePendingTool, openAIWSResponsePendingToolCallsBindingKey(groupID, responseID)) + store.responsePendingToolMu.Unlock() + + callIDs, found := store.GetResponsePendingToolCalls(groupID, responseID) + require.True(t, found, "本地缓存缺失时应降级读取可选缓存") + require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs) + require.Equal(t, 1, probe.getCalls) + + // 回填后再次读取应命中本地缓存,不再触发 Redis 回源。 + callIDs, found = 
store.GetResponsePendingToolCalls(groupID, responseID) + require.True(t, found) + require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs) + require.Equal(t, 1, probe.getCalls) + _, found = store.GetResponsePendingToolCalls(groupID+1, responseID) + require.False(t, found, "optional cache fallback should remain group-isolated") + + store.DeleteResponsePendingToolCalls(groupID, responseID) + require.Equal(t, 1, probe.delCalls, "删除 pending_tool_calls 时应同步删除可选缓存") + _, found = store.GetResponsePendingToolCalls(groupID, responseID) + require.False(t, found) +} + func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) { cache := &stubGatewayCache{sessionBindings: map[string]int64{}} store := NewOpenAIWSStateStore(cache) @@ -94,6 +287,42 @@ func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing. require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射") } +func TestOpenAIWSStateStore_GetResponseAccount_LegacyKeyFallback(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(18) + responseID := "resp_cache_legacy_fallback" + + legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID) + v2Key := openAIWSResponseAccountCacheKey(responseID) + cache.sessionBindings[legacyKey] = 601 + + accountID, err := store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Equal(t, int64(601), accountID, "应支持 legacy cache key 回读") + require.Equal(t, int64(601), cache.sessionBindings[v2Key], "legacy 回读后应回填 v2 cache key") +} + +func TestOpenAIWSStateStore_DeleteResponseAccount_DeletesLegacyAndV2Keys(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(19) + responseID := "resp_cache_delete_both_keys" + + legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID) + v2Key := 
openAIWSResponseAccountCacheKey(responseID) + cache.sessionBindings[legacyKey] = 701 + cache.sessionBindings[v2Key] = 701 + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, responseID)) + _, legacyExists := cache.sessionBindings[legacyKey] + _, v2Exists := cache.sessionBindings[v2Key] + require.False(t, legacyExists, "删除 response account 绑定时应清理 legacy key") + require.False(t, v2Exists, "删除 response account 绑定时应清理 v2 key") +} + func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) { raw := NewOpenAIWSStateStore(nil) store, ok := raw.(*defaultOpenAIWSStateStore) @@ -101,21 +330,27 @@ func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T expiredAt := time.Now().Add(-time.Minute) total := 2048 - store.responseToConnMu.Lock() for i := 0; i < total; i++ { - store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{ + key := fmt.Sprintf("resp_%d", i) + shard := store.connShard(key) + shard.mu.Lock() + shard.m[key] = openAIWSConnBinding{ connID: "conn_incremental", expiresAt: expiredAt, } + shard.mu.Unlock() } - store.responseToConnMu.Unlock() store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) store.maybeCleanup() - store.responseToConnMu.RLock() - remainingAfterFirst := len(store.responseToConn) - store.responseToConnMu.RUnlock() + remainingAfterFirst := 0 + for i := range store.responseToConnShards { + shard := &store.responseToConnShards[i] + shard.mu.RLock() + remainingAfterFirst += len(shard.m) + shard.mu.RUnlock() + } require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展") require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键") @@ -124,36 +359,99 @@ func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T store.maybeCleanup() } - store.responseToConnMu.RLock() - remaining := len(store.responseToConn) - store.responseToConnMu.RUnlock() + remaining := 0 + for i := range store.responseToConnShards { + 
shard := &store.responseToConnShards[i] + shard.mu.RLock() + remaining += len(shard.m) + shard.mu.RUnlock() + } require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键") } +func TestOpenAIWSStateStore_BackgroundCleanupRemovesExpiredWithoutNewWrites(t *testing.T) { + store := newOpenAIWSStateStore(nil, 20*time.Millisecond) + defer store.Close() + + expiredAt := time.Now().Add(-time.Minute) + store.responseToAccountMu.Lock() + for i := 0; i < 64; i++ { + key := fmt.Sprintf("bg_cleanup_resp_%d", i) + store.responseToAccount[key] = openAIWSAccountBinding{ + accountID: int64(i + 1), + expiresAt: expiredAt, + } + } + store.responseToAccountMu.Unlock() + + // Backdate cleanup watermark so the worker can run immediately on next tick. + store.lastCleanupUnixNano.Store(time.Now().Add(-time.Minute).UnixNano()) + + require.Eventually(t, func() bool { + store.responseToAccountMu.RLock() + remaining := len(store.responseToAccount) + store.responseToAccountMu.RUnlock() + return remaining == 0 + }, 600*time.Millisecond, 10*time.Millisecond, "后台 cleanup 应在无新写入时清理过期项") +} + func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) { - bindings := map[string]int{ - "a": 1, - "b": 2, + bindings := map[string]openAIWSAccountBinding{ + "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)}, + "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)}, } ensureBindingCapacity(bindings, "c", 2) - bindings["c"] = 3 + bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)} require.Len(t, bindings, 2) - require.Equal(t, 3, bindings["c"]) + require.Equal(t, int64(3), bindings["c"].accountID) } func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) { - bindings := map[string]int{ - "a": 1, - "b": 2, + bindings := map[string]openAIWSAccountBinding{ + "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)}, + "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)}, } ensureBindingCapacity(bindings, "a", 2) - bindings["a"] = 
9 + bindings["a"] = openAIWSAccountBinding{accountID: 9, expiresAt: time.Now().Add(time.Hour)} + + require.Len(t, bindings, 2) + require.Equal(t, int64(9), bindings["a"].accountID) +} + +func TestEnsureBindingCapacity_PrefersExpiredEntry(t *testing.T) { + bindings := map[string]openAIWSAccountBinding{ + "expired": {accountID: 1, expiresAt: time.Now().Add(-time.Hour)}, + "active": {accountID: 2, expiresAt: time.Now().Add(time.Hour)}, + } + + ensureBindingCapacity(bindings, "c", 2) + bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)} require.Len(t, bindings, 2) - require.Equal(t, 9, bindings["a"]) + _, hasExpired := bindings["expired"] + require.False(t, hasExpired, "expired entry should have been evicted") + require.Equal(t, int64(2), bindings["active"].accountID) + require.Equal(t, int64(3), bindings["c"].accountID) +} + +func TestEnsureBindingCapacity_EvictsEarliestExpiryWhenNoExpired(t *testing.T) { + now := time.Now() + bindings := map[string]openAIWSAccountBinding{ + "soon": {accountID: 1, expiresAt: now.Add(30 * time.Second)}, + "later": {accountID: 2, expiresAt: now.Add(5 * time.Minute)}, + } + + ensureBindingCapacity(bindings, "new", 2) + bindings["new"] = openAIWSAccountBinding{accountID: 3, expiresAt: now.Add(10 * time.Minute)} + + require.Len(t, bindings, 2) + _, hasSoon := bindings["soon"] + require.False(t, hasSoon, "entry with earliest expiresAt should be evicted") + require.Equal(t, int64(2), bindings["later"].accountID) + require.Equal(t, int64(3), bindings["new"].accountID) } type openAIWSStateStoreTimeoutProbeCache struct { @@ -204,13 +502,14 @@ func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) { accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe") require.NoError(t, getErr) - require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号") + require.Equal(t, int64(11), accountID, "Redis Set 失败时本地缓存仍应保留作为降级保障") require.NoError(t, store.DeleteResponseAccount(ctx, groupID, 
"resp_timeout_probe")) require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文") require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文") - require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取") + // 本地缓存作为降级保障保留,Get 直接命中本地缓存不会穿透到 Redis + require.False(t, probe.getHasDeadline, "本地缓存命中时不应穿透到 Redis 读取") require.Greater(t, probe.setDeadlineDelta, 2*time.Second) require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second) require.Greater(t, probe.delDeadlineDelta, 2*time.Second) @@ -226,6 +525,26 @@ func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) { require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second) } +func TestOpenAIWSStateStore_BindResponseAccount_UsesShortLocalHotTTL(t *testing.T) { + cache := &stubGatewayCache{} + raw := NewOpenAIWSStateStore(cache) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + groupID := int64(23) + responseID := "resp_local_hot_ttl" + require.NoError(t, store.BindResponseAccount(context.Background(), groupID, responseID, 902, 24*time.Hour)) + + id := normalizeOpenAIWSResponseID(responseID) + require.NotEmpty(t, id) + store.responseToAccountMu.RLock() + binding, exists := store.responseToAccount[id] + store.responseToAccountMu.RUnlock() + require.True(t, exists) + require.Equal(t, int64(902), binding.accountID) + require.WithinDuration(t, time.Now().Add(openAIWSStateStoreHotCacheTTL), binding.expiresAt, 1500*time.Millisecond) +} + func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) { ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) defer cancel() diff --git a/backend/internal/service/openai_ws_test_helpers_test.go b/backend/internal/service/openai_ws_test_helpers_test.go new file mode 100644 index 000000000..d719c0eaf --- /dev/null +++ b/backend/internal/service/openai_ws_test_helpers_test.go @@ -0,0 +1,277 @@ +package service + +import ( + "context" + "encoding/json" + "errors" 
+ "io" + "net/http" + "sync" + "time" +) + +type openAIWSQueueDialer struct { + mu sync.Mutex + conns []openAIWSClientConn + dialCount int +} + +func (d *openAIWSQueueDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + defer d.mu.Unlock() + d.dialCount++ + if len(d.conns) == 0 { + return nil, 503, nil, errors.New("no test conn") + } + conn := d.conns[0] + if len(d.conns) > 1 { + d.conns = d.conns[1:] + } + return conn, 0, nil, nil +} + +func (d *openAIWSQueueDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSCaptureConn struct { + mu sync.Mutex + readDelays []time.Duration + events [][]byte + writes []map[string]any + closed bool +} + +func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errOpenAIWSConnClosed + } + switch payload := value.(type) { + case map[string]any: + c.writes = append(c.writes, cloneMapStringAny(payload)) + case json.RawMessage: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + case []byte: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + } + return nil +} + +func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errOpenAIWSConnClosed + } + if len(c.events) == 0 { + c.mu.Unlock() + return nil, io.EOF + } + delay := time.Duration(0) + if len(c.readDelays) > 0 { + delay = c.readDelays[0] + c.readDelays = c.readDelays[1:] + } + event := c.events[0] + c.events = c.events[1:] + c.mu.Unlock() + if delay > 0 { + 
timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + } + return event, nil +} + +func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSCaptureConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *openAIWSCaptureConn) Closed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +type openAIWSAlwaysFailDialer struct { + mu sync.Mutex + dialCount int +} + +func (d *openAIWSAlwaysFailDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return nil, 503, nil, errors.New("dial failed") +} + +func (d *openAIWSAlwaysFailDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSFakeConn struct { + mu sync.Mutex + closed bool + payload [][]byte +} + +func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + _ = value + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errors.New("closed") + } + c.payload = append(c.payload, []byte("ok")) + return nil +} + +func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil, errors.New("closed") + } + return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil +} + +func (c *openAIWSFakeConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSFakeConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil 
+} + +type openAIWSPingFailConn struct{} + +func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil +} + +func (c *openAIWSPingFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSPingFailConn) Close() error { + return nil +} + +// openAIWSDelayedPingFailConn 是带可控延迟的 Ping 失败连接, +// 用于模拟"Ping 执行期间连接被重建"的竞态场景。 +type openAIWSDelayedPingFailConn struct { + delay time.Duration + pingDone chan struct{} // Ping 开始执行时关闭,通知测试可以进行下一步 + mu sync.Mutex + closed bool +} + +func newOpenAIWSDelayedPingFailConn(delay time.Duration) *openAIWSDelayedPingFailConn { + return &openAIWSDelayedPingFailConn{ + delay: delay, + pingDone: make(chan struct{}), + } +} + +func (c *openAIWSDelayedPingFailConn) WriteJSON(context.Context, any) error { return nil } +func (c *openAIWSDelayedPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_delayed_ping"}}`), nil +} + +func (c *openAIWSDelayedPingFailConn) Ping(ctx context.Context) error { + // 通知测试 Ping 已开始 + select { + case <-c.pingDone: + default: + close(c.pingDone) + } + // 等待延迟或上下文取消 + timer := time.NewTimer(c.delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } + return errors.New("ping failed after delay") +} + +func (c *openAIWSDelayedPingFailConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *openAIWSDelayedPingFailConn) Closed() bool { + c.mu.Lock() + defer c.mu.Unlock() + return c.closed +} diff --git a/backend/internal/service/openai_ws_turn.go b/backend/internal/service/openai_ws_turn.go new file mode 100644 index 000000000..d1854b095 --- /dev/null +++ b/backend/internal/service/openai_ws_turn.go @@ -0,0 +1,1601 @@ 
+package service + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "net/http" + "sort" + "strings" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 +type OpenAIWSIngressHooks struct { + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) +} + +const ( + openAIWSCtxFirstModelKey = "openai_ws_first_model" + openAIWSCtxFirstPreviousResponseIDKey = "openai_ws_first_previous_response_id" + openAIWSCtxFirstPreviousResponseIDKindKey = "openai_ws_first_previous_response_id_kind" +) + +func SetOpenAIWSFirstMessageMeta(c *gin.Context, model, previousResponseID, previousResponseIDKind string) { + if c == nil { + return + } + c.Set(openAIWSCtxFirstModelKey, strings.TrimSpace(model)) + c.Set(openAIWSCtxFirstPreviousResponseIDKey, strings.TrimSpace(previousResponseID)) + c.Set(openAIWSCtxFirstPreviousResponseIDKindKey, strings.TrimSpace(previousResponseIDKind)) +} + +func ResolveOpenAIWSFirstMessageMeta( + c *gin.Context, + firstClientMessage []byte, +) (model string, previousResponseID string, previousResponseIDKind string) { + if c != nil { + if v, ok := c.Get(openAIWSCtxFirstModelKey); ok { + if text, okText := v.(string); okText { + model = strings.TrimSpace(text) + } + } + if v, ok := c.Get(openAIWSCtxFirstPreviousResponseIDKey); ok { + if text, okText := v.(string); okText { + previousResponseID = strings.TrimSpace(text) + } + } + if v, ok := c.Get(openAIWSCtxFirstPreviousResponseIDKindKey); ok { + if text, okText := v.(string); okText { + previousResponseIDKind = strings.TrimSpace(text) + } + } + } + if model == "" || previousResponseID == "" { + values := gjson.GetManyBytes(firstClientMessage, "model", "previous_response_id") + if model == "" { + model = strings.TrimSpace(values[0].String()) + } + if previousResponseID == "" { + 
previousResponseID = strings.TrimSpace(values[1].String()) + } + } + if previousResponseIDKind == "" { + previousResponseIDKind = ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + } + return model, previousResponseID, previousResponseIDKind +} + +func normalizeOpenAIWSLogValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "-" + } + return openAIWSLogValueReplacer.Replace(trimmed) +} + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 { + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." +} + +func openAIWSHeaderValueForLog(headers http.Header, key string) string { + if headers == nil { + return "-" + } + return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) +} + +func hasOpenAIWSHeader(headers http.Header, key string) bool { + if headers == nil { + return false + } + return strings.TrimSpace(headers.Get(key)) != "" +} + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: "none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + 
} + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func openAIWSIngressSessionScopeFromContext(c *gin.Context) string { + if c == nil { + return "" + } + value, exists := c.Get("api_key") + if !exists || value == nil { + return "" + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil { + return "" + } + userID := apiKey.UserID + if userID <= 0 && apiKey.User != nil { + userID = apiKey.User.ID + } + apiKeyID := apiKey.ID + if userID <= 0 && apiKeyID <= 0 { + return "" + } + return fmt.Sprintf("u%d:k%d", userID, apiKeyID) +} + +func openAIWSApplySessionScope(sessionHash, scope string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + return "" + } + scope = strings.TrimSpace(scope) + if scope == "" { + return hash + } + return scope + "|" + hash +} + +func openAIWSHostPathForLogFromURL(rawURL string) (host string, path string) { + trimmed := strings.TrimSpace(rawURL) + if trimmed == "" { + return "-", "-" + } + + withoutScheme := trimmed + if schemeIdx := strings.Index(withoutScheme, "://"); schemeIdx >= 0 { + withoutScheme = withoutScheme[schemeIdx+3:] + } + withoutScheme = strings.TrimPrefix(withoutScheme, "//") + + rawHost := withoutScheme + rawPath := "" + if slashIdx := strings.IndexByte(withoutScheme, '/'); slashIdx >= 0 { + rawHost = withoutScheme[:slashIdx] + rawPath = withoutScheme[slashIdx:] + } + if queryIdx := strings.IndexByte(rawPath, '?'); queryIdx >= 0 { + rawPath = rawPath[:queryIdx] + } + + host = normalizeOpenAIWSLogValue(rawHost) + path = normalizeOpenAIWSLogValue(rawPath) + return host, path +} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || 
isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", "response.completed", "response.done": + return true + default: + return false + } +} + +// openAIWSEventShouldParseUsage 判断是否应解析 usage。 +// 调用方需确保 eventType 已经过 TrimSpace(如 parseOpenAIWSEventType 的返回值)。 +func openAIWSEventShouldParseUsage(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +// parseOpenAIWSEventType extracts only the event type and response ID from a WS message. +// Use this lightweight version on hot paths where the full response body is not needed. 
+func parseOpenAIWSEventType(message []byte) (eventType string, responseID string) { + if len(message) == 0 { + return "", "" + } + values := gjson.GetManyBytes(message, "type", "response.id", "id") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID +} + +func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { + if len(message) == 0 { + return "", "", gjson.Result{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID, values[3] +} + +func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { + if len(message) == 0 { + return false + } + return bytes.Contains(message, []byte(`"tool_calls"`)) || + bytes.Contains(message, []byte(`"tool_call"`)) || + bytes.Contains(message, []byte(`"function_call"`)) +} + +func openAIWSCollectPendingFunctionCallIDsFromJSONResult(result gjson.Result, callIDSet map[string]struct{}, depth int) { + if !result.Exists() || callIDSet == nil || depth > 8 || result.Type != gjson.JSON { + return + } + itemType := strings.TrimSpace(result.Get("type").String()) + if itemType == "function_call" || itemType == "tool_call" { + callID := strings.TrimSpace(result.Get("call_id").String()) + if callID == "" { + fallbackID := strings.TrimSpace(result.Get("id").String()) + if strings.HasPrefix(fallbackID, "call_") { + callID = fallbackID + } + } + if callID != "" { + callIDSet[callID] = struct{}{} + } + } + result.ForEach(func(_, child gjson.Result) bool { + openAIWSCollectPendingFunctionCallIDsFromJSONResult(child, callIDSet, depth+1) + return true + }) +} + 
+func openAIWSExtractPendingFunctionCallIDsFromEvent(message []byte) []string { + if len(message) == 0 { + return nil + } + callIDSet := make(map[string]struct{}, 4) + openAIWSCollectPendingFunctionCallIDsFromJSONResult(gjson.ParseBytes(message), callIDSet, 0) + if len(callIDSet) == 0 { + return nil + } + callIDs := make([]string, 0, len(callIDSet)) + for callID := range callIDSet { + callIDs = append(callIDs, callID) + } + sort.Strings(callIDs) + return callIDs +} + +func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { + if usage == nil || len(message) == 0 { + return + } + values := gjson.GetManyBytes( + message, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "", "", "" + } + values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") + return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) +} + +func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { + code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) + errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) + errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) + return code, errType, errMessage +} + +func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "-", "-", "-" + } + return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func summarizeOpenAIWSPayloadKeySizes(payload 
map[string]any, topN int) string { + if len(payload) == 0 { + return "-" + } + type keySize struct { + Key string + Size int + } + sizes := make([]keySize, 0, len(payload)) + for key, value := range payload { + size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) + sizes = append(sizes, keySize{Key: key, Size: size}) + } + sort.Slice(sizes, func(i, j int) bool { + if sizes[i].Size == sizes[j].Size { + return sizes[i].Key < sizes[j].Key + } + return sizes[i].Size > sizes[j].Size + }) + + if topN <= 0 || topN > len(sizes) { + topN = len(sizes) + } + parts := make([]string, 0, topN) + for idx := 0; idx < topN; idx++ { + item := sizes[idx] + parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) + } + return strings.Join(parts, ",") +} + +func estimateOpenAIWSPayloadValueSize(value any, depth int) int { + if depth <= 0 { + return -1 + } + switch v := value.(type) { + case nil: + return 0 + case string: + return len(v) + case []byte: + return len(v) + case bool: + return 1 + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return 8 + case float32, float64: + return 8 + case map[string]any: + if len(v) == 0 { + return 2 + } + total := 2 + count := 0 + for key, item := range v { + count++ + if count > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) + if itemSize < 0 { + return -1 + } + total += len(key) + itemSize + 3 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + case []any: + if len(v) == 0 { + return 2 + } + total := 2 + limit := len(v) + if limit > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + for i := 0; i < limit; i++ { + itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) + if itemSize < 0 { + return -1 + } + total += itemSize + 1 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + default: + raw, err := json.Marshal(v) + if err != nil { + 
return -1 + } + if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + return len(raw) + } +} + +func openAIWSPayloadString(payload map[string]any, key string) string { + if len(payload) == 0 { + return "" + } + raw, ok := payload[key] + if !ok { + return "" + } + switch v := raw.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func openAIWSPayloadStringFromRaw(payload []byte, key string) string { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, key).String()) +} + +func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return defaultValue + } + value := gjson.GetBytes(payload, key) + if !value.Exists() { + return defaultValue + } + if value.Type != gjson.True && value.Type != gjson.False { + return defaultValue + } + return value.Bool() +} + +func openAIWSSessionHashesFromID(sessionID string) (string, string) { + return deriveOpenAISessionHashes(sessionID) +} + +func extractOpenAIWSImageURL(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if raw, ok := v["url"].(string); ok { + return strings.TrimSpace(raw) + } + } + return "" +} + +func summarizeOpenAIWSInput(input any) string { + items, ok := input.([]any) + if !ok || len(items) == 0 { + return "-" + } + + itemCount := len(items) + textChars := 0 + imageDataURLs := 0 + imageDataURLChars := 0 + imageRemoteURLs := 0 + + handleContentItem := func(contentItem map[string]any) { + contentType, _ := contentItem["type"].(string) + switch strings.TrimSpace(contentType) { + case "input_text", "output_text", "text": + if text, ok := contentItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := 
extractOpenAIWSImageURL(contentItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + handleInputItem := func(inputItem map[string]any) { + if content, ok := inputItem["content"].([]any); ok { + for _, rawContent := range content { + contentItem, ok := rawContent.(map[string]any) + if !ok { + continue + } + handleContentItem(contentItem) + } + return + } + + itemType, _ := inputItem["type"].(string) + switch strings.TrimSpace(itemType) { + case "input_text", "output_text", "text": + if text, ok := inputItem["text"].(string); ok { + textChars += len(text) + } + case "input_image": + imageURL := extractOpenAIWSImageURL(inputItem["image_url"]) + if imageURL == "" { + return + } + if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") { + imageDataURLs++ + imageDataURLChars += len(imageURL) + return + } + imageRemoteURLs++ + } + } + + for _, rawItem := range items { + inputItem, ok := rawItem.(map[string]any) + if !ok { + continue + } + handleInputItem(inputItem) + } + + return fmt.Sprintf( + "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d", + itemCount, + textChars, + imageDataURLs, + imageDataURLChars, + imageRemoteURLs, + ) +} + +func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return + } + if _, exists := payload[key]; !exists { + return + } + delete(payload, key) + *removed = append(*removed, key) +} + +// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段, +// 避免重试成功却改变原始请求语义。 +// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。 +func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) { + if len(payload) == 0 { + return "empty", nil + } + if attempt <= 1 { + return "full", nil + } + + removed 
:= make([]string, 0, 2) + if attempt >= 2 { + dropOpenAIWSPayloadKey(payload, "include", &removed) + } + + if len(removed) == 0 { + return "full", nil + } + sort.Strings(removed) + return "trim_optional_fields", removed +} + +func logOpenAIWSModeInfo(format string, args ...any) { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func isOpenAIWSModeDebugEnabled() bool { + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func logOpenAIWSModeDebug(format string, args ...any) { + if !isOpenAIWSModeDebugEnabled() { + return + } + logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...) +} + +func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { + if err == nil { + return + } + logger.L().Warn( + "openai.ws_bind_response_account_failed", + zap.Int64("group_id", groupID), + zap.Int64("account_id", accountID), + zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), + zap.Error(err), + ) +} + +func logOpenAIWSIngressTurnAbort( + accountID int64, + turn int, + connID string, + reason openAIWSIngressTurnAbortReason, + expected bool, + cause error, +) { + causeValue := "-" + if cause != nil { + causeValue = truncateOpenAIWSLogValue(cause.Error(), openAIWSLogValueMaxLen) + } + logOpenAIWSModeInfo( + "ingress_ws_turn_aborted account_id=%d turn=%d conn_id=%s reason=%s expected=%v cause=%s", + accountID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(string(reason)), + expected, + causeValue, + ) +} + +func sortedKeys(m map[string]any) []string { + if len(m) == 0 { + return nil + } + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) { + return 
dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes) +} + +func dropPreviousResponseIDFromRawPayloadWithDeleteFn( + payload []byte, + deleteFn func([]byte, string) ([]byte, error), +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + if !gjson.GetBytes(payload, "previous_response_id").Exists() { + return payload, false, nil + } + if deleteFn == nil { + deleteFn = sjson.DeleteBytes + } + + updated := payload + for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses && + gjson.GetBytes(updated, "previous_response_id").Exists(); i++ { + next, err := deleteFn(updated, "previous_response_id") + if err != nil { + return payload, false, err + } + updated = next + } + return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil +} + +func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) { + normalizedPrevID := strings.TrimSpace(previousResponseID) + if len(payload) == 0 || normalizedPrevID == "" { + return payload, nil + } + if current := openAIWSPayloadStringFromRaw(payload, "previous_response_id"); current == normalizedPrevID { + return payload, nil + } + updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID) + if err == nil { + return updated, nil + } + + var reqBody map[string]any + if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil { + return nil, err + } + reqBody["previous_response_id"] = normalizedPrevID + rebuilt, marshalErr := json.Marshal(reqBody) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil +} + +func shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled bool, + turn int, + hasFunctionCallOutput bool, + currentPreviousResponseID string, + expectedPreviousResponseID string, +) bool { + if !storeDisabled || turn <= 0 || !hasFunctionCallOutput { + return false + } + if strings.TrimSpace(currentPreviousResponseID) != "" { + return false + } + return 
strings.TrimSpace(expectedPreviousResponseID) != "" +} + +func alignStoreDisabledPreviousResponseID( + payload []byte, + expectedPreviousResponseID string, +) ([]byte, bool, error) { + if len(payload) == 0 { + return payload, false, nil + } + expected := strings.TrimSpace(expectedPreviousResponseID) + if expected == "" { + return payload, false, nil + } + current := openAIWSPayloadStringFromRaw(payload, "previous_response_id") + if current == "" || current == expected { + return payload, false, nil + } + + // 常见路径(无重复 key)直接 set,避免先 delete 再 set 的双遍处理。 + // 仅在检测到重复 key 时走 drop+set 慢路径,确保最终语义一致。 + if bytes.Count(payload, []byte(`"previous_response_id"`)) <= 1 { + updated, setErr := setPreviousResponseIDToRawPayload(payload, expected) + if setErr != nil { + return payload, false, setErr + } + return updated, !bytes.Equal(updated, payload), nil + } + + withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload) + if dropErr != nil { + return payload, false, dropErr + } + if !removed { + return payload, false, nil + } + updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected) + if setErr != nil { + return payload, false, setErr + } + return updated, true, nil +} + +func cloneOpenAIWSPayloadBytes(payload []byte) []byte { + if len(payload) == 0 { + return nil + } + cloned := make([]byte, len(payload)) + copy(cloned, payload) + return cloned +} + +func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage { + if items == nil { + return nil + } + cloned := make([]json.RawMessage, 0, len(items)) + for idx := range items { + cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx]))) + } + return cloned +} + +func cloneOpenAIWSJSONRawString(raw string) []byte { + if strings.TrimSpace(raw) == "" { + return nil + } + cloned := make([]byte, len(raw)) + copy(cloned, raw) + return cloned +} + +func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) { + trimmed := bytes.TrimSpace(raw) + if 
len(trimmed) == 0 { + return nil, errors.New("json is empty") + } + var decoded any + if err := json.Unmarshal(trimmed, &decoded); err != nil { + return nil, err + } + return json.Marshal(decoded) +} + +func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte { + normalized, err := normalizeOpenAIWSJSONForCompare(raw) + if err != nil { + return bytes.TrimSpace(raw) + } + return normalized +} + +func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) { + if len(payload) == 0 { + return nil, errors.New("payload is empty") + } + var decoded map[string]any + if err := json.Unmarshal(payload, &decoded); err != nil { + return nil, err + } + delete(decoded, "input") + delete(decoded, "previous_response_id") + return json.Marshal(decoded) +} + +func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) { + if len(payload) == 0 { + return nil, false, nil + } + inputValue := gjson.GetBytes(payload, "input") + if !inputValue.Exists() { + return nil, false, nil + } + if inputValue.Type == gjson.JSON { + raw := strings.TrimSpace(inputValue.Raw) + if strings.HasPrefix(raw, "[") { + var items []json.RawMessage + if err := json.Unmarshal([]byte(raw), &items); err != nil { + return nil, true, err + } + return items, true, nil + } + return []json.RawMessage{json.RawMessage(raw)}, true, nil + } + if inputValue.Type == gjson.String { + encoded, _ := json.Marshal(inputValue.String()) + return []json.RawMessage{encoded}, true, nil + } + return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil +} + +func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) { + previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload) + if prevErr != nil { + return false, prevErr + } + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return false, currentErr + } + if 
!previousExists && !currentExists { + return true, nil + } + if !previousExists { + return len(currentItems) == 0, nil + } + if !currentExists { + return len(previousItems) == 0, nil + } + if len(currentItems) < len(previousItems) { + return false, nil + } + + for idx := range previousItems { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false, nil + } + } + return true, nil +} + +func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool { + if len(prefix) == 0 { + return true + } + if len(items) < len(prefix) { + return false + } + for idx := range prefix { + previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx]) + currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx]) + if !bytes.Equal(previousNormalized, currentNormalized) { + return false + } + } + return true +} + +func limitOpenAIWSReplayInputSequenceByBytes(items []json.RawMessage, maxBytes int) []json.RawMessage { + if len(items) == 0 { + return nil + } + if maxBytes <= 0 { + return cloneOpenAIWSRawMessages(items) + } + + start := len(items) + total := 2 // "[]" + for idx := len(items) - 1; idx >= 0; idx-- { + itemBytes := len(items[idx]) + if start != len(items) { + itemBytes++ // comma + } + if total+itemBytes > maxBytes { + // Keep at least the newest item to avoid creating an empty replay input. 
+ if start == len(items) { + start = idx + } + break + } + total += itemBytes + start = idx + } + if start < 0 || start > len(items) { + start = len(items) - 1 + } + return cloneOpenAIWSRawMessages(items[start:]) +} + +func buildOpenAIWSReplayInputSequence( + previousFullInput []json.RawMessage, + previousFullInputExists bool, + currentPayload []byte, + hasPreviousResponseID bool, +) ([]json.RawMessage, bool, error) { + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return nil, false, currentErr + } + candidate := []json.RawMessage(nil) + exists := false + if !hasPreviousResponseID { + candidate = cloneOpenAIWSRawMessages(currentItems) + exists = currentExists + if !exists { + return candidate, false, nil + } + return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil + } + if !previousFullInputExists { + candidate = cloneOpenAIWSRawMessages(currentItems) + exists = currentExists + if !exists { + return candidate, false, nil + } + return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil + } + if !currentExists || len(currentItems) == 0 { + candidate = cloneOpenAIWSRawMessages(previousFullInput) + exists = true + return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil + } + if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { + candidate = cloneOpenAIWSRawMessages(currentItems) + exists = true + return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil + } + merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems)) + merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...) + merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...) 
+ candidate = merged + exists = true + return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil +} + +func openAIWSInputAppearsEditedFromPreviousFullInput( + previousFullInput []json.RawMessage, + previousFullInputExists bool, + currentPayload []byte, + hasPreviousResponseID bool, +) (bool, error) { + if !hasPreviousResponseID || !previousFullInputExists { + return false, nil + } + currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload) + if currentErr != nil { + return false, currentErr + } + if !currentExists || len(currentItems) == 0 { + return false, nil + } + if len(previousFullInput) < 2 { + // Single-item turns are ambiguous (could be a normal incremental replace), avoid false positives. + return false, nil + } + if len(currentItems) < len(previousFullInput) { + // Most delta appends only send the latest one/few items. + return false, nil + } + if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) { + // Full snapshot append or unchanged snapshot. + return false, nil + } + return true, nil +} + +func setOpenAIWSPayloadInputSequence( + payload []byte, + fullInput []json.RawMessage, + fullInputExists bool, +) ([]byte, error) { + if !fullInputExists { + return payload, nil + } + // Preserve [] vs null semantics when input exists but is empty. 
+ inputForMarshal := fullInput + if inputForMarshal == nil { + inputForMarshal = []json.RawMessage{} + } + inputRaw, marshalErr := json.Marshal(inputForMarshal) + if marshalErr != nil { + return nil, marshalErr + } + return sjson.SetRawBytes(payload, "input", inputRaw) +} + +func openAIWSNormalizeCallIDs(callIDs []string) []string { + if len(callIDs) == 0 { + return nil + } + seen := make(map[string]struct{}, len(callIDs)) + normalized := make([]string, 0, len(callIDs)) + for _, callID := range callIDs { + id := strings.TrimSpace(callID) + if id == "" { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + normalized = append(normalized, id) + } + sort.Strings(normalized) + return normalized +} + +func openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload []byte) []string { + if len(payload) == 0 { + return nil + } + input := gjson.GetBytes(payload, "input") + if !input.Exists() { + return nil + } + callIDSet := make(map[string]struct{}, 4) + collect := func(item gjson.Result) { + if item.Type != gjson.JSON { + return + } + if strings.TrimSpace(item.Get("type").String()) != "function_call_output" { + return + } + callID := strings.TrimSpace(item.Get("call_id").String()) + if callID == "" { + return + } + callIDSet[callID] = struct{}{} + } + if input.IsArray() { + input.ForEach(func(_, item gjson.Result) bool { + collect(item) + return true + }) + } else { + collect(input) + } + if len(callIDSet) == 0 { + return nil + } + callIDs := make([]string, 0, len(callIDSet)) + for callID := range callIDSet { + callIDs = append(callIDs, callID) + } + sort.Strings(callIDs) + return callIDs +} + +func openAIWSHasToolCallContextInPayload(payload []byte) bool { + if len(payload) == 0 { + return false + } + input := gjson.GetBytes(payload, "input") + if !input.Exists() { + return false + } + + hasContext := false + collect := func(item gjson.Result) { + if hasContext || item.Type != gjson.JSON { + return + } + itemType := 
strings.TrimSpace(item.Get("type").String()) + if itemType != "tool_call" && itemType != "function_call" { + return + } + if strings.TrimSpace(item.Get("call_id").String()) == "" { + return + } + hasContext = true + } + if input.IsArray() { + input.ForEach(func(_, item gjson.Result) bool { + collect(item) + return !hasContext + }) + return hasContext + } + collect(input) + return hasContext +} + +func openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload []byte, functionCallOutputCallIDs []string) bool { + requiredCallIDs := openAIWSNormalizeCallIDs(functionCallOutputCallIDs) + if len(payload) == 0 || len(requiredCallIDs) == 0 { + return false + } + input := gjson.GetBytes(payload, "input") + if !input.Exists() { + return false + } + + referenceIDSet := make(map[string]struct{}, len(requiredCallIDs)) + collect := func(item gjson.Result) { + if item.Type != gjson.JSON { + return + } + if strings.TrimSpace(item.Get("type").String()) != "item_reference" { + return + } + referenceID := strings.TrimSpace(item.Get("id").String()) + if referenceID == "" { + return + } + referenceIDSet[referenceID] = struct{}{} + } + if input.IsArray() { + input.ForEach(func(_, item gjson.Result) bool { + collect(item) + return true + }) + } else { + collect(input) + } + + if len(referenceIDSet) == 0 { + return false + } + for _, callID := range requiredCallIDs { + if _, ok := referenceIDSet[callID]; !ok { + return false + } + } + return true +} + +func shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID( + storeDisabled bool, + hasFunctionCallOutput bool, + previousResponseID string, + hasToolOutputContext bool, +) bool { + if !storeDisabled || !hasFunctionCallOutput { + return false + } + if strings.TrimSpace(previousResponseID) != "" { + return false + } + return !hasToolOutputContext +} + +func openAIWSFindMissingCallIDs(requiredCallIDs []string, actualCallIDs []string) []string { + required := openAIWSNormalizeCallIDs(requiredCallIDs) + if len(required) == 0 
{ + return nil + } + actualSet := make(map[string]struct{}, len(actualCallIDs)) + for _, callID := range actualCallIDs { + id := strings.TrimSpace(callID) + if id == "" { + continue + } + actualSet[id] = struct{}{} + } + missing := make([]string, 0, len(required)) + for _, callID := range required { + if _, ok := actualSet[callID]; ok { + continue + } + missing = append(missing, callID) + } + return missing +} + +func openAIWSInjectFunctionCallOutputItems(payload []byte, callIDs []string, outputValue string) ([]byte, int, error) { + normalizedCallIDs := openAIWSNormalizeCallIDs(callIDs) + if len(normalizedCallIDs) == 0 { + return payload, 0, nil + } + inputItems, inputExists, inputErr := openAIWSExtractNormalizedInputSequence(payload) + if inputErr != nil { + return nil, 0, inputErr + } + if !inputExists { + inputItems = []json.RawMessage{} + } + updatedInput := make([]json.RawMessage, 0, len(inputItems)+len(normalizedCallIDs)) + updatedInput = append(updatedInput, cloneOpenAIWSRawMessages(inputItems)...) 
+ for _, callID := range normalizedCallIDs { + rawItem, marshalErr := json.Marshal(map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": outputValue, + }) + if marshalErr != nil { + return nil, 0, marshalErr + } + updatedInput = append(updatedInput, json.RawMessage(rawItem)) + } + updatedPayload, setErr := setOpenAIWSPayloadInputSequence(payload, updatedInput, true) + if setErr != nil { + return nil, 0, setErr + } + return updatedPayload, len(normalizedCallIDs), nil +} + +func shouldKeepIngressPreviousResponseID( + previousPayload []byte, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, + expectedPendingCallIDs []string, + functionCallOutputCallIDs []string, +) (bool, string, error) { + if hasFunctionCallOutput { + if len(expectedPendingCallIDs) == 0 { + return true, "has_function_call_output", nil + } + if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 { + return false, "function_call_output_call_id_mismatch", nil + } + return true, "function_call_output_call_id_match", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if len(previousPayload) == 0 { + return false, "missing_previous_turn_payload", nil + } + + previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload) + if previousComparableErr != nil { + return false, "non_input_compare_error", previousComparableErr + } + currentComparable, currentComparableErr := 
normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +type openAIWSIngressPreviousTurnStrictState struct { + nonInputComparable []byte +} + +func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { + if len(payload) == 0 { + return nil, nil + } + nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) + if nonInputErr != nil { + return nil, nonInputErr + } + return &openAIWSIngressPreviousTurnStrictState{ + nonInputComparable: nonInputComparable, + }, nil +} + +func shouldKeepIngressPreviousResponseIDWithStrictState( + previousState *openAIWSIngressPreviousTurnStrictState, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, + expectedPendingCallIDs []string, + functionCallOutputCallIDs []string, +) (bool, string, error) { + if hasFunctionCallOutput { + if len(expectedPendingCallIDs) == 0 { + return true, "has_function_call_output", nil + } + if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 { + return false, "function_call_output_call_id_mismatch", nil + } + return true, "function_call_output_call_id_match", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if 
previousState == nil { + return false, "missing_previous_turn_payload", nil + } + + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousState.nonInputComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +func payloadAsJSON(payload map[string]any) string { + return string(payloadAsJSONBytes(payload)) +} + +func normalizeOpenAIWSPreferredConnID(connID string) (string, bool) { + trimmed := strings.TrimSpace(connID) + if trimmed == "" { + return "", false + } + if strings.HasPrefix(trimmed, openAIWSConnIDPrefixCtx) { + return trimmed, true + } + if strings.HasPrefix(trimmed, openAIWSConnIDPrefixLegacy) { + return trimmed, true + } + return "", false +} + +func openAIWSPreferredConnIDFromResponse(stateStore OpenAIWSStateStore, responseID string) string { + if stateStore == nil { + return "" + } + normalizedResponseID := strings.TrimSpace(responseID) + if normalizedResponseID == "" { + return "" + } + connID, ok := stateStore.GetResponseConn(normalizedResponseID) + if !ok { + return "" + } + normalizedConnID, ok := normalizeOpenAIWSPreferredConnID(connID) + if !ok { + return "" + } + return normalizedConnID +} + +func payloadAsJSONBytes(payload map[string]any) []byte { + if len(payload) == 0 { + return []byte("{}") + } + body, err := json.Marshal(payload) + if err != nil { + return []byte("{}") + } + return body +} + +func isOpenAIWSTerminalEvent(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func shouldPersistOpenAIWSLastResponseID(terminalEventType string) bool { + switch terminalEventType { + case "response.completed", 
"response.done": + return true + default: + return false + } +} + +func isOpenAIWSTokenEvent(eventType string) bool { + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { + if len(message) == 0 { + return message + } + if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { + return message + } + if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { + return message + } + modelValues := gjson.GetManyBytes(message, "model", "response.model") + replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel + replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel + if !replaceModel && !replaceResponseModel { + return message + } + updated := message + if replaceModel { + if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { + updated = next + } + } + if replaceResponseModel { + if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { + updated = next + } + } + return updated +} + +func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { + if usage == nil || len(body) == 0 { + return + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func 
getOpenAIGroupIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + value, exists := c.Get("api_key") + if !exists { + return 0 + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil || apiKey.GroupID == nil { + return 0 + } + return *apiKey.GroupID +} + +func openAIWSIngressFallbackSessionSeedFromContext(c *gin.Context) string { + if c == nil { + return "" + } + value, exists := c.Get("api_key") + if !exists { + return "" + } + apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil { + return "" + } + gid := int64(0) + if apiKey.GroupID != nil { + gid = *apiKey.GroupID + } + userID := int64(0) + if apiKey.User != nil { + userID = apiKey.User.ID + } + return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKey.ID) +} diff --git a/backend/internal/service/openai_ws_upstream_pump_test.go b/backend/internal/service/openai_ws_upstream_pump_test.go new file mode 100644 index 000000000..26a7439dc --- /dev/null +++ b/backend/internal/service/openai_ws_upstream_pump_test.go @@ -0,0 +1,1894 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// 辅助:构造测试用上游事件 JSON +// --------------------------------------------------------------------------- + +func pumpTestEvent(eventType string) []byte { + m := map[string]any{"type": eventType} + b, _ := json.Marshal(m) + return b +} + +func pumpTestEventWithResponseID(eventType, responseID string) []byte { + m := map[string]any{"type": eventType, "response": map[string]any{"id": responseID}} + b, _ := json.Marshal(m) + return b +} + +// --------------------------------------------------------------------------- +// 辅助:模拟上游连接(支持按序返回事件、延迟、错误注入) +// --------------------------------------------------------------------------- + +type pumpTestConn struct { + mu sync.Mutex + 
events []pumpTestConnEvent + readCount int + closed bool + closedCh chan struct{} + ignoreCtx bool + pingErr error + writeErr error + writeCount int +} + +type pumpTestConnEvent struct { + data []byte + err error + delay time.Duration +} + +func newPumpTestConn(events ...pumpTestConnEvent) *pumpTestConn { + return &pumpTestConn{ + events: events, + closedCh: make(chan struct{}), + } +} + +func (c *pumpTestConn) WriteJSON(_ context.Context, _ any) error { + c.mu.Lock() + defer c.mu.Unlock() + c.writeCount++ + return c.writeErr +} + +func (c *pumpTestConn) ReadMessage(ctx context.Context) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errOpenAIWSConnClosed + } + if len(c.events) == 0 { + c.mu.Unlock() + if c.ignoreCtx { + <-c.closedCh + return nil, io.EOF + } + // 阻塞直到上下文取消,模拟上游无更多事件 + <-ctx.Done() + return nil, ctx.Err() + } + evt := c.events[0] + c.events = c.events[1:] + c.readCount++ + c.mu.Unlock() + + if evt.delay > 0 { + timer := time.NewTimer(evt.delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + } + return evt.data, evt.err +} + +func (c *pumpTestConn) Ping(_ context.Context) error { return c.pingErr } + +func (c *pumpTestConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + if !c.closed { + c.closed = true + close(c.closedCh) + } + return nil +} + +// --------------------------------------------------------------------------- +// 辅助:模拟 lease 接口(仅泵测试所需的读写方法) +// --------------------------------------------------------------------------- + +type pumpTestLease struct { + conn *pumpTestConn + broken atomic.Bool +} + +func (l *pumpTestLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { + readCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + return l.conn.ReadMessage(readCtx) +} + +func (l *pumpTestLease) MarkBroken() { + l.broken.Store(true) + if l.conn != 
nil { + _ = l.conn.Close() + } +} + +func (l *pumpTestLease) IsBroken() bool { return l.broken.Load() } + +// --------------------------------------------------------------------------- +// 辅助:运行泵 goroutine 并收集所有产出的事件 +// --------------------------------------------------------------------------- + +// startPump 模拟 sendAndRelay 中的泵 goroutine,返回事件 channel 和取消函数。 +func startPump(ctx context.Context, lease *pumpTestLease, readTimeout time.Duration) (chan openAIWSUpstreamPumpEvent, context.CancelFunc) { + pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize) + pumpCtx, pumpCancel := context.WithCancel(ctx) + go func() { + defer close(pumpEventCh) + for { + msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, readTimeout) + select { + case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}: + case <-pumpCtx.Done(): + return + } + if readErr != nil { + return + } + evtType, _ := parseOpenAIWSEventType(msg) + if isOpenAIWSTerminalEvent(evtType) || evtType == "error" { + return + } + } + }() + return pumpEventCh, pumpCancel +} + +// collectAll 从 channel 读取所有事件直到关闭。 +func collectAll(ch chan openAIWSUpstreamPumpEvent) []openAIWSUpstreamPumpEvent { + var result []openAIWSUpstreamPumpEvent + for evt := range ch { + result = append(result, evt) + } + return result +} + +// --------------------------------------------------------------------------- +// 测试:openAIWSUpstreamPumpEvent 结构体 +// --------------------------------------------------------------------------- + +func TestOpenAIWSUpstreamPumpEvent_Fields(t *testing.T) { + t.Parallel() + + t.Run("message_only", func(t *testing.T) { + evt := openAIWSUpstreamPumpEvent{message: []byte("hello")} + assert.Equal(t, []byte("hello"), evt.message) + assert.NoError(t, evt.err) + }) + + t.Run("error_only", func(t *testing.T) { + evt := openAIWSUpstreamPumpEvent{err: io.EOF} + assert.Nil(t, evt.message) + assert.ErrorIs(t, evt.err, io.EOF) + }) + + t.Run("both_fields", func(t 
*testing.T) { + evt := openAIWSUpstreamPumpEvent{message: []byte("partial"), err: io.ErrUnexpectedEOF} + assert.Equal(t, []byte("partial"), evt.message) + assert.ErrorIs(t, evt.err, io.ErrUnexpectedEOF) + }) +} + +func TestOpenAIWSUpstreamPumpBufferSize(t *testing.T) { + t.Parallel() + assert.Equal(t, 16, openAIWSUpstreamPumpBufferSize, "缓冲大小应为 16") +} + +// --------------------------------------------------------------------------- +// 测试:泵 goroutine 正常事件流 +// --------------------------------------------------------------------------- + +func TestPump_NormalEventFlow(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + + require.Len(t, events, 4) + for _, evt := range events { + assert.NoError(t, evt.err) + assert.NotEmpty(t, evt.message) + } + // 验证最后一个是终端事件 + lastType, _ := parseOpenAIWSEventType(events[3].message) + assert.True(t, isOpenAIWSTerminalEvent(lastType)) +} + +func TestPump_TerminalEventStopsPump(t *testing.T) { + t.Parallel() + terminalTypes := []string{ + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled", + } + for _, tt := range terminalTypes { + tt := tt + t.Run(tt, func(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent(tt)}, + // 以下事件不应该被读取 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer 
cancel() + + events := collectAll(ch) + require.Len(t, events, 2, "终端事件 %s 后泵应停止", tt) + assert.NoError(t, events[0].err) + assert.NoError(t, events[1].err) + }) + } +} + +func TestPump_ErrorEventStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("error")}, + // 不应被读取 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2, "error 事件后泵应停止") + evtType, _ := parseOpenAIWSEventType(events[1].message) + assert.Equal(t, "error", evtType) +} + +// --------------------------------------------------------------------------- +// 测试:泵 goroutine 读取错误传播 +// --------------------------------------------------------------------------- + +func TestPump_ReadErrorPropagated(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: io.ErrUnexpectedEOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + assert.NoError(t, events[0].err) + assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF) +} + +func TestPump_ReadErrorOnFirstEvent(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{err: errors.New("connection refused")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 1) + assert.Error(t, events[0].err) + assert.Contains(t, events[0].err.Error(), "connection refused") +} + +func TestPump_EOFError(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: 
pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{err: io.EOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3) + assert.ErrorIs(t, events[2].err, io.EOF) +} + +// --------------------------------------------------------------------------- +// 测试:上下文取消终止泵 +// --------------------------------------------------------------------------- + +func TestPump_ContextCancellationStopsPump(t *testing.T) { + t.Parallel() + // 连接永远阻塞在第二次读取 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 无更多事件,ReadMessage 将阻塞直到 ctx 取消 + ) + lease := &pumpTestLease{conn: conn} + ctx, ctxCancel := context.WithCancel(context.Background()) + ch, pumpCancel := startPump(ctx, lease, 30*time.Second) + defer pumpCancel() + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 取消上下文 + ctxCancel() + + // 泵应该退出,channel 应该关闭 + events := collectAll(ch) + // 可能收到一个 context.Canceled 错误事件 + for _, e := range events { + assert.Error(t, e.err) + } +} + +func TestPump_PumpCancelStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + evt := <-ch + assert.NoError(t, evt.err) + + // 调用 pumpCancel 应终止泵 + pumpCancel() + + // channel 应被关闭 + events := collectAll(ch) + for _, e := range events { + assert.Error(t, e.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:缓冲行为 +// --------------------------------------------------------------------------- + +func TestPump_BufferAllowsConcurrentReadWrite(t *testing.T) { + t.Parallel() + // 生成超过缓冲大小的事件,验证不会死锁 + numEvents := openAIWSUpstreamPumpBufferSize + 5 + connEvents := 
make([]pumpTestConnEvent, 0, numEvents) + for i := 0; i < numEvents-1; i++ { + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}) + } + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")}) + + conn := newPumpTestConn(connEvents...) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, numEvents) + for _, evt := range events { + assert.NoError(t, evt.err) + } +} + +func TestPump_SlowConsumerDoesNotBlock(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 模拟慢消费者 + var events []openAIWSUpstreamPumpEvent + for evt := range ch { + events = append(events, evt) + time.Sleep(10 * time.Millisecond) // 慢消费 + } + require.Len(t, events, 4) +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器机制 +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerCancelsPump(t *testing.T) { + t.Parallel() + // 模拟:客户端断连后,排水定时器到期取消泵 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 第二次读取会阻塞(模拟上游仍在生成但还没发出事件) + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 模拟排水定时器:50ms 后取消泵(正式代码中是 5 秒) + drainTimer := time.AfterFunc(50*time.Millisecond, pumpCancel) + defer drainTimer.Stop() + + // 等待 channel 关闭 + start := 
time.Now() + remaining := collectAll(ch) + elapsed := time.Since(start) + + // 应在 50ms 附近退出,而非 30 秒 + assert.Less(t, elapsed, 2*time.Second, "排水定时器应在约 50ms 后终止泵") + + // 可能收到 context.Canceled 错误事件 + for _, e := range remaining { + assert.Error(t, e.err) + } +} + +func TestPump_DrainDeadlineCheckInMainLoop(t *testing.T) { + t.Parallel() + // 模拟主循环中的排水超时检查逻辑 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 加延迟模拟上游慢响应 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 80 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainDeadline := time.Time{} + var eventsBeforeDrain []openAIWSUpstreamPumpEvent + drainTriggered := false + + for evt := range ch { + // 检查排水超时 + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainTriggered = true + break + } + if evt.err != nil { + break + } + eventsBeforeDrain = append(eventsBeforeDrain, evt) + + // 模拟:第一个事件后客户端断连,设置极短的排水截止时间 + if !clientDisconnected && len(eventsBeforeDrain) == 1 { + clientDisconnected = true + drainDeadline = time.Now().Add(30 * time.Millisecond) + } + } + + // 排水截止时间为 30ms,第二个事件延迟 80ms,所以应该触发排水超时 + assert.True(t, drainTriggered, "排水超时应被触发") + assert.Len(t, eventsBeforeDrain, 1, "排水前应只有 1 个事件") +} + +// --------------------------------------------------------------------------- +// 测试:与上游事件延迟的并发行为 +// --------------------------------------------------------------------------- + +func TestPump_ReadDelayDoesNotBlockPreviousEvents(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 100 * time.Millisecond}, + pumpTestConnEvent{data: 
pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 第一个事件应该立即可用 + start := time.Now() + evt := <-ch + assert.NoError(t, evt.err) + assert.Less(t, time.Since(start), 50*time.Millisecond, "第一个事件应立即到达") + + events := collectAll(ch) + require.Len(t, events, 2) +} + +// --------------------------------------------------------------------------- +// 测试:空事件流 +// --------------------------------------------------------------------------- + +func TestPump_EmptyStreamContextCancel(t *testing.T) { + t.Parallel() + // 没有任何事件,连接阻塞,靠 context 取消 + conn := newPumpTestConn() // 无事件 + lease := &pumpTestLease{conn: conn} + ctx, ctxCancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer ctxCancel() + ch, pumpCancel := startPump(ctx, lease, 30*time.Second) + defer pumpCancel() + + events := collectAll(ch) + // context 取消后,泵的 select 可能选择 pumpCtx.Done() 分支直接退出(0 个事件), + // 也可能先将错误事件发送到 channel 后退出(1 个事件),两种行为都正确。 + assert.LessOrEqual(t, len(events), 1, "最多应收到 1 个事件") + for _, evt := range events { + assert.Error(t, evt.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:非终端/非错误事件不终止泵 +// --------------------------------------------------------------------------- + +func TestPump_NonTerminalEventsDoNotStopPump(t *testing.T) { + t.Parallel() + nonTerminalTypes := []string{ + "response.created", + "response.in_progress", + "response.output_text.delta", + "response.content_part.added", + "response.output_item.added", + "response.reasoning_summary_text.delta", + } + connEvents := make([]pumpTestConnEvent, 0, len(nonTerminalTypes)+1) + for _, et := range nonTerminalTypes { + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent(et)}) + } + connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")}) + + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, len(nonTerminalTypes)+1, "所有非终端事件 + 终端事件都应被传递") +} + +// --------------------------------------------------------------------------- +// 测试:多次 pumpCancel 调用安全(幂等) +// --------------------------------------------------------------------------- + +func TestPump_MultipleCancelSafe(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + events := collectAll(ch) + require.Len(t, events, 1) + + // 多次调用 pumpCancel 不应 panic + assert.NotPanics(t, func() { + pumpCancel() + pumpCancel() + pumpCancel() + }) +} + +// --------------------------------------------------------------------------- +// 测试:泵与主循环集成——模拟完整的 relay 消费模式 +// --------------------------------------------------------------------------- + +func TestPump_IntegrationRelayPattern(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_abc123")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.completed", "resp_abc123")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + // 模拟主循环处理 + var responseID string + eventCount := 0 + tokenEventCount := 0 + var terminalEventType string + clientWriteCount := 0 + + for evt := range ch { + if evt.err != nil { + t.Fatalf("unexpected error: %v", evt.err) + } + eventType, evtRespID := parseOpenAIWSEventType(evt.message) + if 
responseID == "" && evtRespID != "" { + responseID = evtRespID + } + eventCount++ + if isOpenAIWSTokenEvent(eventType) { + tokenEventCount++ + } + // 模拟写客户端 + clientWriteCount++ + + if isOpenAIWSTerminalEvent(eventType) { + terminalEventType = eventType + break + } + } + + assert.Equal(t, "resp_abc123", responseID) + assert.Equal(t, 5, eventCount) + assert.GreaterOrEqual(t, tokenEventCount, 3, "至少 3 个 delta 事件应被计为 token 事件") + assert.Equal(t, 5, clientWriteCount) + assert.Equal(t, "response.completed", terminalEventType) +} + +// --------------------------------------------------------------------------- +// 测试:泵 goroutine 在 channel 满时 + context 取消的行为 +// --------------------------------------------------------------------------- + +func TestPump_ChannelFullThenCancel(t *testing.T) { + t.Parallel() + // 生成大量事件但不消费,验证 pumpCancel 仍然能终止泵 + numEvents := openAIWSUpstreamPumpBufferSize * 3 + connEvents := make([]pumpTestConnEvent, numEvents) + for i := range connEvents { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")} + } + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 等待缓冲区被填满 + time.Sleep(50 * time.Millisecond) + + // 取消泵 + pumpCancel() + + // 清空 channel + events := collectAll(ch) + // 应收到 bufferSize 到 bufferSize+1 个事件(泵在 channel 满时可能阻塞在 select) + assert.LessOrEqual(t, len(events), numEvents, "不应收到超过总事件数的事件") + assert.GreaterOrEqual(t, len(events), 1, "至少应收到一些事件") +} + +// --------------------------------------------------------------------------- +// 测试:读取超时机制 +// --------------------------------------------------------------------------- + +func TestPump_ReadTimeoutTriggersError(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 第二次读取延迟超过超时 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 500 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + // 读取超时设为 50ms,远小于 500ms 延迟 + ch, cancel := startPump(context.Background(), lease, 50*time.Millisecond) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + assert.NoError(t, events[0].err) + assert.Error(t, events[1].err, "第二次读取应超时") +} + +// --------------------------------------------------------------------------- +// 测试:泵在 response.done 事件后停止(另一种终端事件) +// --------------------------------------------------------------------------- + +func TestPump_ResponseDoneStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.done")}, + pumpTestConnEvent{data: pumpTestEvent("should_not_reach")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3) + lastType, _ := parseOpenAIWSEventType(events[2].message) + 
assert.Equal(t, "response.done", lastType) +} + +// --------------------------------------------------------------------------- +// 测试:泵在读取到 error event 后不继续读取更多事件 +// --------------------------------------------------------------------------- + +func TestPump_ErrorEventStopsReading(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("error")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, // 不应被读取 + ) + // 重写以追踪读取次数 + origEvents := conn.events + conn.events = nil + var wrappedConn pumpTestConn + wrappedConn.closedCh = make(chan struct{}) + wrappedConn.events = origEvents + wrappedLease := &pumpTestLease{conn: &wrappedConn} + + ch, cancel := startPump(context.Background(), wrappedLease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 1, "error 事件后不应再读取更多事件") + evtType, _ := parseOpenAIWSEventType(events[0].message) + assert.Equal(t, "error", evtType) +} + +// --------------------------------------------------------------------------- +// 测试:验证事件顺序保持不变 +// --------------------------------------------------------------------------- + +func TestPump_EventOrderPreserved(t *testing.T) { + t.Parallel() + expectedTypes := []string{ + "response.created", + "response.in_progress", + "response.output_item.added", + "response.content_part.added", + "response.output_text.delta", + "response.output_text.delta", + "response.output_text.delta", + "response.output_text.done", + "response.content_part.done", + "response.output_item.done", + "response.completed", + } + connEvents := make([]pumpTestConnEvent, len(expectedTypes)) + for i, et := range expectedTypes { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent(et)} + } + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, len(expectedTypes)) + for i, evt := range events { + evtType, _ := parseOpenAIWSEventType(evt.message) + assert.Equal(t, expectedTypes[i], evtType, "事件 %d 类型不匹配", i) + } +} + +// --------------------------------------------------------------------------- +// 测试:无效 JSON 消息不影响泵运行 +// --------------------------------------------------------------------------- + +func TestPump_InvalidJSONDoesNotStopPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: []byte("not json")}, + pumpTestConnEvent{data: []byte("{invalid")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3, "无效 JSON 不应终止泵") +} + +// --------------------------------------------------------------------------- +// 测试:并发安全——多个消费者不会 panic +// --------------------------------------------------------------------------- + +func TestPump_ConcurrentConsumeAndCancel(t *testing.T) { + t.Parallel() + connEvents := make([]pumpTestConnEvent, 100) + for i := range connEvents { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: time.Millisecond} + } + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 同时消费和取消,不应 panic + var wg sync.WaitGroup + wg.Add(2) + go func() { + defer wg.Done() + for range ch { + // 消费 + } + }() + go func() { + defer wg.Done() + time.Sleep(20 * time.Millisecond) + pumpCancel() + }() + + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // 成功 + case <-time.After(5 * time.Second): + t.Fatal("超时:并发消费和取消场景死锁") + } +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器与正常终端事件的竞争 +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerRaceWithTerminalEvent(t *testing.T) { + t.Parallel() + // 终端事件在排水定时器到期前到达 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 设置较长的排水定时器(200ms),终端事件应在 10ms 后到达 + drainTimer := time.AfterFunc(200*time.Millisecond, pumpCancel) + defer drainTimer.Stop() + + events := collectAll(ch) + // 终端事件应先到达 + require.Len(t, events, 2) + lastType, _ := parseOpenAIWSEventType(events[1].message) + assert.Equal(t, "response.completed", lastType) + assert.NoError(t, events[1].err) + + pumpCancel() // 清理 +} + +// --------------------------------------------------------------------------- +// 测试:大量事件的吞吐量(确保泵不引入异常开销) +// --------------------------------------------------------------------------- + +func TestPump_HighThroughput(t *testing.T) { + t.Parallel() + numEvents := 1000 + connEvents := make([]pumpTestConnEvent, numEvents) + for i := range connEvents { + if i == numEvents-1 { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.completed")} + } else { + connEvents[i] = pumpTestConnEvent{data: 
pumpTestEvent("response.output_text.delta")} + } + } + conn := newPumpTestConn(connEvents...) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + start := time.Now() + events := collectAll(ch) + elapsed := time.Since(start) + + require.Len(t, events, numEvents) + assert.Less(t, elapsed, 2*time.Second, "1000 个事件应在 2 秒内完成") + for _, evt := range events { + assert.NoError(t, evt.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:空消息(零字节)不终止泵 +// --------------------------------------------------------------------------- + +func TestPump_EmptyMessageDoesNotStopPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: []byte{}}, + pumpTestConnEvent{data: nil}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 3, "空消息不应终止泵") +} + +// =========================================================================== +// 以下为消息泵模式新增代码路径的补充测试 +// =========================================================================== + +// --------------------------------------------------------------------------- +// 测试:泵 channel 关闭但无终端事件(上游异常断连) +// --------------------------------------------------------------------------- + +func TestPump_UnexpectedCloseDetectedByConsumer(t *testing.T) { + t.Parallel() + // 模拟:上游只发了非终端事件就断连(ReadMessage 返回 EOF) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{err: io.EOF}, // 上游断连 + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 模拟主循环消费:检查是否收到了终端事件 + receivedTerminal := false + var lastErr error 
+ for evt := range ch { + if evt.err != nil { + lastErr = evt.err + break + } + evtType, _ := parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + break + } + } + // 未收到终端事件,但收到了 EOF 错误——消费者应识别为上游异常断连 + assert.False(t, receivedTerminal, "不应收到终端事件") + assert.ErrorIs(t, lastErr, io.EOF, "应收到 EOF 错误标识上游断连") +} + +func TestPump_ChannelCloseWithoutTerminalOrError(t *testing.T) { + t.Parallel() + // 极端情况:泵被外部取消(pumpCancel),channel 关闭但既无终端事件也无错误事件。 + // 模拟中间事件后泵被取消。 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + // 无更多事件,ReadMessage 将阻塞 + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + // 消费前两个事件 + evt1 := <-ch + assert.NoError(t, evt1.err) + evt2 := <-ch + assert.NoError(t, evt2.err) + + // 外部取消泵 + pumpCancel() + + // for-range 应退出,模拟 "泵 channel 关闭但未收到终端事件" 场景 + receivedTerminal := false + for evt := range ch { + if evt.err == nil { + evtType, _ := parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + } + } + } + assert.False(t, receivedTerminal, "泵被取消后不应再收到终端事件") +} + +// --------------------------------------------------------------------------- +// 测试:lease.MarkBroken 场景验证 +// --------------------------------------------------------------------------- + +func TestPump_LeaseMarkedBrokenOnUnexpectedClose(t *testing.T) { + t.Parallel() + // 模拟主循环:泵关闭但无终端事件时应标记 lease 为 broken + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: io.ErrUnexpectedEOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + receivedTerminal := false + for evt := range ch { + if evt.err != nil { + // 模拟正式代码中的错误处理路径 + lease.MarkBroken() + break + } + evtType, _ := 
parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + break + } + } + // 如果 for-range 正常退出且未收到终端事件,也标记 broken + if !receivedTerminal { + lease.MarkBroken() + } + + assert.True(t, lease.IsBroken(), "上游异常断连应标记 lease 为 broken") +} + +func TestPump_LeaseNotBrokenOnNormalTerminal(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + for evt := range ch { + if evt.err != nil { + lease.MarkBroken() + break + } + } + + assert.False(t, lease.IsBroken(), "正常终端事件不应标记 lease 为 broken") +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器只创建一次 +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerCreatedOnlyOnce(t *testing.T) { + t.Parallel() + // 模拟多次"客户端断连"信号,验证排水定时器只创建一次 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + drainTimerCount := 0 + clientDisconnected := false + drainDeadline := time.Time{} + var drainTimer *time.Timer + + for evt := range ch { + if evt.err != nil { + break + } + // 每个事件后都"检测到客户端断连" + if !clientDisconnected { + clientDisconnected = true + } + // 排水定时器只在第一次断连时创建 + if clientDisconnected && drainDeadline.IsZero() { + drainDeadline = time.Now().Add(500 * time.Millisecond) + 
drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel) + drainTimerCount++ + } + } + if drainTimer != nil { + drainTimer.Stop() + } + + assert.Equal(t, 1, drainTimerCount, "排水定时器应只创建一次") +} + +// --------------------------------------------------------------------------- +// 测试:排水定时器在正常完成前被 Stop +// --------------------------------------------------------------------------- + +func TestPump_DrainTimerStoppedOnNormalCompletion(t *testing.T) { + t.Parallel() + // 终端事件在排水定时器到期前到达,验证定时器被正确停止 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 创建长时间排水定时器 + drainTimer := time.AfterFunc(10*time.Second, pumpCancel) + + var events []openAIWSUpstreamPumpEvent + for evt := range ch { + events = append(events, evt) + } + + // 正常完成后停止排水定时器(模拟 defer drainTimer.Stop()) + stopped := drainTimer.Stop() + pumpCancel() // 清理 + + assert.True(t, stopped, "定时器应尚未触发,Stop() 返回 true") + require.Len(t, events, 2) +} + +// --------------------------------------------------------------------------- +// 测试:排水期间读取错误处理 +// --------------------------------------------------------------------------- + +func TestPump_ReadErrorDuringDrainTreatedAsDrainTimeout(t *testing.T) { + t.Parallel() + // 新代码:客户端已断连时任何读取错误都按排水超时处理(不仅限 DeadlineExceeded) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: io.ErrUnexpectedEOF, delay: 20 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + var drainError error + + for evt := range ch { + if !clientDisconnected { + // 第一个事件后模拟客户端断连 + clientDisconnected = true + continue + } + if evt.err != nil && 
clientDisconnected { + // 新代码路径:排水期间收到读取错误 + drainError = evt.err + break + } + } + + assert.Error(t, drainError, "排水期间应收到读取错误") + assert.ErrorIs(t, drainError, io.ErrUnexpectedEOF) +} + +func TestPump_ReadErrorDuringDrain_EOF(t *testing.T) { + t.Parallel() + // EOF 在排水期间等同于上游关闭 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{err: io.EOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainErrorCount := 0 + + for evt := range ch { + if evt.err != nil { + if clientDisconnected { + drainErrorCount++ + } + break + } + // 第一个事件后模拟客户端断连 + if !clientDisconnected { + clientDisconnected = true + } + } + + assert.Equal(t, 1, drainErrorCount, "排水期间 EOF 应被计为一次排水错误") +} + +// --------------------------------------------------------------------------- +// 测试:排水截止时间检查——在事件间隙中过期 +// --------------------------------------------------------------------------- + +func TestPump_DrainDeadlineExpiresBetweenEvents(t *testing.T) { + t.Parallel() + // 排水截止时间在两个上游事件之间到期 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 60 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 60 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainDeadline := time.Time{} + drainExpired := false + eventsProcessed := 0 + + for evt := range ch { + // 排水超时检查(在处理事件前,模拟正式代码) + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainExpired = true + break + } + if evt.err != nil { + break + } + eventsProcessed++ + + 
// 第一个事件后断连,排水截止时间设为 30ms + if !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(30 * time.Millisecond) + } + } + + // 第二个事件延迟 60ms > 排水截止 30ms,应触发排水超时 + assert.True(t, drainExpired, "排水截止时间应在事件间隙中过期") + assert.Equal(t, 1, eventsProcessed, "过期前应只处理了 1 个事件") +} + +func TestPump_DrainDeadlineNotYetExpiredAllowsProcessing(t *testing.T) { + t.Parallel() + // 排水截止时间足够长,允许处理所有事件 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + defer pumpCancel() + + clientDisconnected := false + drainDeadline := time.Time{} + drainExpired := false + eventsProcessed := 0 + + for evt := range ch { + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainExpired = true + break + } + if evt.err != nil { + break + } + eventsProcessed++ + if !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(500 * time.Millisecond) // 足够长 + } + } + + assert.False(t, drainExpired, "排水截止时间未过期,不应触发排水超时") + assert.Equal(t, 3, eventsProcessed, "所有事件都应被处理") +} + +// --------------------------------------------------------------------------- +// 测试:goroutine 清理和资源释放 +// --------------------------------------------------------------------------- + +func TestPump_DeferPumpCancelAndDrainTimerCleanup(t *testing.T) { + t.Parallel() + // 模拟正式代码的完整 defer 清理路径 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize) + pumpCtx, pumpCancel := 
context.WithCancel(context.Background()) + // 模拟 defer pumpCancel() + defer pumpCancel() + + go func() { + defer close(pumpEventCh) + for { + msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, 5*time.Second) + select { + case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}: + case <-pumpCtx.Done(): + return + } + if readErr != nil { + return + } + evtType, _ := parseOpenAIWSEventType(msg) + if isOpenAIWSTerminalEvent(evtType) || evtType == "error" { + return + } + } + }() + + // 模拟排水定时器 + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + drainTimer = time.AfterFunc(10*time.Second, pumpCancel) + + events := collectAll(pumpEventCh) + require.Len(t, events, 2) + + // defer 清理后不应 panic + assert.NotPanics(t, func() { + pumpCancel() + if drainTimer != nil { + drainTimer.Stop() + } + }) +} + +// --------------------------------------------------------------------------- +// 测试:连接在泵运行期间被关闭 +// --------------------------------------------------------------------------- + +func TestPump_ConnectionClosedDuringPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 后续事件阻塞 + ) + lease := &pumpTestLease{conn: conn} + // 使用较短的读超时,因为 conn.Close() 不会解除阻塞的 ReadMessage(它等待 <-ctx.Done()) + ch, pumpCancel := startPump(context.Background(), lease, 100*time.Millisecond) + defer pumpCancel() + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 关闭连接——注意:ReadMessage 仍在等待 ctx.Done(), + // 但读超时为 100ms 会触发 context.DeadlineExceeded。 + // 下次 ReadMessage 调用时会检测到 closed 状态。 + _ = conn.Close() + + // 泵应在读超时后检测到连接关闭 + events := collectAll(ch) + require.GreaterOrEqual(t, len(events), 1, "应收到错误") + // 至少有一个事件包含错误 + hasError := false + for _, e := range events { + if e.err != nil { + hasError = true + } + } + assert.True(t, hasError, "应收到连接关闭或超时错误") +} + +// --------------------------------------------------------------------------- +// 
测试:大消息(KB 级别 JSON)不影响泵传递 +// --------------------------------------------------------------------------- + +func TestPump_LargeMessages(t *testing.T) { + t.Parallel() + // 构造 ~10KB 的消息 + largeContent := make([]byte, 10*1024) + for i := range largeContent { + largeContent[i] = 'x' + } + largeMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(largeContent) + `"}`) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 4) + // 验证大消息完整传递 + assert.Len(t, events[1].message, len(largeMsg)) + assert.Len(t, events[2].message, len(largeMsg)) +} + +// --------------------------------------------------------------------------- +// 测试:多轮泵会话(同一 lease 上依次创建多个泵) +// --------------------------------------------------------------------------- + +func TestPump_SequentialSessions(t *testing.T) { + t.Parallel() + // 第一轮 + conn1 := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease1 := &pumpTestLease{conn: conn1} + ch1, cancel1 := startPump(context.Background(), lease1, 5*time.Second) + events1 := collectAll(ch1) + cancel1() + require.Len(t, events1, 2) + + // 第二轮(新连接、新泵) + conn2 := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease2 := &pumpTestLease{conn: conn2} + ch2, cancel2 := startPump(context.Background(), lease2, 5*time.Second) + events2 := collectAll(ch2) + cancel2() + require.Len(t, events2, 3) + + // 两轮之间互不影响 + assert.False(t, 
lease1.IsBroken()) + assert.False(t, lease2.IsBroken()) +} + +// --------------------------------------------------------------------------- +// 测试:完整 relay 模式集成——包含客户端断连和排水 +// --------------------------------------------------------------------------- + +func TestPump_IntegrationRelayWithClientDisconnectAndDrain(t *testing.T) { + t.Parallel() + // 模拟完整场景:上游慢速响应,客户端在中途断连,排水定时器到期终止 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain1")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + // 上游后续事件延迟大于排水超时 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 200 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 200 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + clientDisconnected := false + drainDeadline := time.Time{} + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + + eventsProcessed := 0 + drainTriggered := false + + for evt := range ch { + // 排水超时检查 + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainTriggered = true + lease.MarkBroken() + break + } + if evt.err != nil { + if clientDisconnected { + // 排水期间读取错误(pumpCancel 导致 context.Canceled) + drainTriggered = true + lease.MarkBroken() + } + break + } + eventsProcessed++ + + // 模拟:第 2 个事件后客户端断连 + if eventsProcessed == 2 && !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(50 * time.Millisecond) // 50ms 排水超时 + drainTimer = time.AfterFunc(50*time.Millisecond, pumpCancel) + } + } + // for-range 退出后,如果 channel 因 pumpCancel 关闭且排水截止已过期, + // 也视为排水超时触发(泵的 select 可能选择 pumpCtx.Done() 而不发送错误事件)。 + if 
!drainTriggered && clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + drainTriggered = true + lease.MarkBroken() + } + + pumpCancel() // 最终清理 + + // 排水超时应触发(50ms 排水 vs 200ms 后续事件延迟) + assert.True(t, drainTriggered, "排水超时应被触发") + assert.GreaterOrEqual(t, eventsProcessed, 2, "至少应处理 2 个事件") + assert.LessOrEqual(t, eventsProcessed, 4, "不应处理所有 5 个事件") +} + +func TestPump_IntegrationRelayWithSuccessfulDrain(t *testing.T) { + t.Parallel() + // 客户端断连后上游快速完成,在排水超时前正常结束 + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain2")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond}, + pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + clientDisconnected := false + drainDeadline := time.Time{} + var drainTimer *time.Timer + defer func() { + if drainTimer != nil { + drainTimer.Stop() + } + }() + + eventsProcessed := 0 + receivedTerminal := false + drainTriggered := false + + for evt := range ch { + if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) { + pumpCancel() + drainTriggered = true + break + } + if evt.err != nil { + break + } + eventsProcessed++ + + evtType, _ := parseOpenAIWSEventType(evt.message) + if isOpenAIWSTerminalEvent(evtType) { + receivedTerminal = true + break + } + + // 第一个事件后客户端断连 + if eventsProcessed == 1 && !clientDisconnected { + clientDisconnected = true + drainDeadline = time.Now().Add(500 * time.Millisecond) // 足够长的排水超时 + drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel) + } + } + + pumpCancel() + + assert.False(t, drainTriggered, "排水超时不应触发") + assert.True(t, receivedTerminal, "应正常收到终端事件") + assert.Equal(t, 4, eventsProcessed, "所有 4 个事件都应被处理") 
+} + +// --------------------------------------------------------------------------- +// 测试:泵事件错误携带部分消息数据 +// --------------------------------------------------------------------------- + +func TestPump_ErrorEventWithPartialMessage(t *testing.T) { + t.Parallel() + // 模拟上游返回部分数据和错误 + partialData := []byte(`{"type":"response.output_text`) + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: partialData, err: io.ErrUnexpectedEOF}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + // 第二个事件同时携带 message 和 error + assert.Equal(t, partialData, events[1].message) + assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF) +} + +// --------------------------------------------------------------------------- +// 测试:零超时读取 +// --------------------------------------------------------------------------- + +func TestPump_ZeroReadTimeout(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 第二次读取需要时间,但超时为 0 + pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond}, + ) + lease := &pumpTestLease{conn: conn} + // 使用极短超时(1 纳秒 ~ 立即超时) + ch, cancel := startPump(context.Background(), lease, time.Nanosecond) + defer cancel() + + events := collectAll(ch) + // 至少第一个事件成功读取(无延迟),第二个大概率超时 + require.GreaterOrEqual(t, len(events), 1) + // 查找是否有超时错误 + hasTimeout := false + for _, evt := range events { + if evt.err != nil && errors.Is(evt.err, context.DeadlineExceeded) { + hasTimeout = true + } + } + assert.True(t, hasTimeout, "极短超时应产生 DeadlineExceeded 错误") +} + +// --------------------------------------------------------------------------- +// 测试:并发多个排水定时器取消(防止重复调用 pumpCancel) +// --------------------------------------------------------------------------- + +func 
TestPump_ConcurrentDrainTimerAndExternalCancel(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + // 阻塞 + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + + // 读取第一个事件 + evt := <-ch + assert.NoError(t, evt.err) + + // 同时设置排水定时器和外部取消 + done := make(chan struct{}) + drainTimer := time.AfterFunc(30*time.Millisecond, pumpCancel) + defer drainTimer.Stop() + + go func() { + time.Sleep(20 * time.Millisecond) + pumpCancel() // 外部取消稍早于定时器 + close(done) + }() + + // 不应死锁或 panic + events := collectAll(ch) + <-done + + // 验证泵已终止 + for _, e := range events { + if e.err != nil { + assert.Error(t, e.err) + } + } +} + +func TestPump_DrainTimerMarkBrokenUnblocksIgnoreContextRead(t *testing.T) { + t.Parallel() + + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + ) + conn.ignoreCtx = true + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second) + defer pumpCancel() + + first := <-ch + require.NoError(t, first.err) + require.NotEmpty(t, first.message) + + done := make(chan struct{}) + drainTimer := time.AfterFunc(30*time.Millisecond, func() { + lease.MarkBroken() + pumpCancel() + }) + defer drainTimer.Stop() + + go func() { + _ = collectAll(ch) + close(done) + }() + + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("pump should stop quickly when drain timer marks lease broken") + } + + assert.True(t, lease.IsBroken(), "lease should be marked broken by drain timer") +} + +// --------------------------------------------------------------------------- +// 测试:快速连续事件(突发模式) +// --------------------------------------------------------------------------- + +func TestPump_BurstEvents(t *testing.T) { + t.Parallel() + // 50 个事件无延迟突发 + numBurst := 50 + connEvents := make([]pumpTestConnEvent, numBurst+1) + for i := 0; i < numBurst; i++ 
{ + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")} + } + connEvents[numBurst] = pumpTestConnEvent{data: pumpTestEvent("response.completed")} + + conn := newPumpTestConn(connEvents...) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, numBurst+1, "突发事件应全部被传递") + + // 验证所有事件无错误 + for i, evt := range events { + assert.NoError(t, evt.err, "事件 %d 不应有错误", i) + } + lastType, _ := parseOpenAIWSEventType(events[numBurst].message) + assert.True(t, isOpenAIWSTerminalEvent(lastType)) +} + +// --------------------------------------------------------------------------- +// 测试:事件类型解析边界情况 +// --------------------------------------------------------------------------- + +func TestPump_EventTypeParsingEdgeCases(t *testing.T) { + t.Parallel() + // 各种边缘 JSON 格式 + conn := newPumpTestConn( + pumpTestConnEvent{data: []byte(`{"type": " response.created "}`)}, // 带空格 + pumpTestConnEvent{data: []byte(`{"type":"response.output_text.delta"}`)}, // 无空格 + pumpTestConnEvent{data: []byte(`{"type":"","other":"field"}`)}, // 空类型 + pumpTestConnEvent{data: []byte(`{"no_type_field": true}`)}, // 无 type 字段 + pumpTestConnEvent{data: []byte(`{"type":"response.completed"}`)}, // 终端 + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 5, "所有格式的事件都应被传递") + for _, evt := range events { + assert.NoError(t, evt.err) + } +} + +// --------------------------------------------------------------------------- +// 测试:function_call_output 等非标准事件类型不终止泵 +// --------------------------------------------------------------------------- + +func TestPump_FunctionCallOutputDoesNotStopPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: 
pumpTestEvent("response.function_call_arguments.delta")}, + pumpTestConnEvent{data: pumpTestEvent("response.function_call_arguments.done")}, + pumpTestConnEvent{data: pumpTestEvent("response.output_item.done")}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 5, "function_call 相关事件不应终止泵") +} + +// --------------------------------------------------------------------------- +// 测试:pumpCancel 在 channel 已关闭后调用不 panic +// --------------------------------------------------------------------------- + +func TestPump_CancelAfterChannelClosed(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second) + + // 等待 channel 关闭 + events := collectAll(ch) + require.Len(t, events, 1) + + // channel 已关闭后再取消不应 panic + assert.NotPanics(t, func() { + pumpCancel() + }) + + // 再次从已关闭 channel 读取应返回零值 + evt, ok := <-ch + assert.False(t, ok, "channel 应已关闭") + assert.Nil(t, evt.message) + assert.NoError(t, evt.err) +} + +// --------------------------------------------------------------------------- +// 测试:混合事件大小(小消息和大消息交替) +// --------------------------------------------------------------------------- + +func TestPump_MixedMessageSizes(t *testing.T) { + t.Parallel() + smallMsg := pumpTestEvent("response.output_text.delta") + largeContent := make([]byte, 64*1024) // 64KB + for i := range largeContent { + largeContent[i] = 'A' + } + largeMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(largeContent) + `"}`) + + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{data: smallMsg}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: 
smallMsg}, + pumpTestConnEvent{data: largeMsg}, + pumpTestConnEvent{data: pumpTestEvent("response.completed")}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 6) + assert.Len(t, events[2].message, len(largeMsg), "大消息应完整传递") + assert.Len(t, events[4].message, len(largeMsg), "大消息应完整传递") +} + +// --------------------------------------------------------------------------- +// 测试:泵在 errOpenAIWSConnClosed 错误后停止 +// --------------------------------------------------------------------------- + +func TestPump_ConnClosedErrorStopsPump(t *testing.T) { + t.Parallel() + conn := newPumpTestConn( + pumpTestConnEvent{data: pumpTestEvent("response.created")}, + pumpTestConnEvent{err: errOpenAIWSConnClosed}, + ) + lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + events := collectAll(ch) + require.Len(t, events, 2) + assert.ErrorIs(t, events[1].err, errOpenAIWSConnClosed) +} + +// --------------------------------------------------------------------------- +// 测试:同时读取和写入——验证读写解耦 +// --------------------------------------------------------------------------- + +func TestPump_ReadWriteDecoupling(t *testing.T) { + t.Parallel() + // 模拟:上游事件到达时,客户端写入有延迟(通过 channel 消费延迟模拟) + numEvents := 10 + connEvents := make([]pumpTestConnEvent, numEvents) + for i := 0; i < numEvents-1; i++ { + connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")} + } + connEvents[numEvents-1] = pumpTestConnEvent{data: pumpTestEvent("response.completed")} + + conn := newPumpTestConn(connEvents...) 
+ lease := &pumpTestLease{conn: conn} + ch, cancel := startPump(context.Background(), lease, 5*time.Second) + defer cancel() + + // 模拟慢写入:每个事件处理需要 5ms + start := time.Now() + var events []openAIWSUpstreamPumpEvent + for evt := range ch { + events = append(events, evt) + time.Sleep(5 * time.Millisecond) // 模拟写入延迟 + } + elapsed := time.Since(start) + + require.Len(t, events, numEvents) + // 如果没有并发(串行读写),总时间 >= numEvents * 5ms = 50ms + // 有缓冲并发时,上游读取可以提前完成,总时间 < 串行预估 + // 此处验证所有事件都被传递即可 + t.Logf("处理 %d 个事件耗时: %v (慢消费模式)", numEvents, elapsed) +} diff --git a/backend/internal/service/openai_ws_v2/caddy_adapter.go b/backend/internal/service/openai_ws_v2/caddy_adapter.go new file mode 100644 index 000000000..1fecc231d --- /dev/null +++ b/backend/internal/service/openai_ws_v2/caddy_adapter.go @@ -0,0 +1,24 @@ +package openai_ws_v2 + +import ( + "context" +) + +// runCaddyStyleRelay 采用 Caddy reverseproxy 的双向隧道思想: +// 连接建立后并发复制两个方向,任一方向退出触发收敛关闭。 +// +// Reference: +// - Project: caddyserver/caddy (Apache-2.0) +// - Commit: f283062d37c50627d53ca682ebae2ce219b35515 +// - Files: +// - modules/caddyhttp/reverseproxy/streaming.go +// - modules/caddyhttp/reverseproxy/reverseproxy.go +func runCaddyStyleRelay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options) +} diff --git a/backend/internal/service/openai_ws_v2/entry.go b/backend/internal/service/openai_ws_v2/entry.go new file mode 100644 index 000000000..176298fe9 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/entry.go @@ -0,0 +1,23 @@ +package openai_ws_v2 + +import "context" + +// EntryInput 是 passthrough v2 数据面的入口参数。 +type EntryInput struct { + Ctx context.Context + ClientConn FrameConn + UpstreamConn FrameConn + FirstClientMessage []byte + Options RelayOptions +} + +// RunEntry 是 openai_ws_v2 包对外的统一入口。 +func RunEntry(input EntryInput) 
(RelayResult, *RelayExit) { + return runCaddyStyleRelay( + input.Ctx, + input.ClientConn, + input.UpstreamConn, + input.FirstClientMessage, + input.Options, + ) +} diff --git a/backend/internal/service/openai_ws_v2/metrics.go b/backend/internal/service/openai_ws_v2/metrics.go new file mode 100644 index 000000000..3708befdb --- /dev/null +++ b/backend/internal/service/openai_ws_v2/metrics.go @@ -0,0 +1,29 @@ +package openai_ws_v2 + +import ( + "sync/atomic" +) + +// MetricsSnapshot 是 OpenAI WS v2 passthrough 路径的轻量运行时指标快照。 +type MetricsSnapshot struct { + SemanticMutationTotal int64 `json:"semantic_mutation_total"` + UsageParseFailureTotal int64 `json:"usage_parse_failure_total"` +} + +var ( + // passthrough 路径默认不会做语义改写,该计数通常应保持为 0(保留用于未来防御性校验)。 + passthroughSemanticMutationTotal atomic.Int64 + passthroughUsageParseFailureTotal atomic.Int64 +) + +func recordUsageParseFailure() { + passthroughUsageParseFailureTotal.Add(1) +} + +// SnapshotMetrics 返回当前 passthrough 指标快照。 +func SnapshotMetrics() MetricsSnapshot { + return MetricsSnapshot{ + SemanticMutationTotal: passthroughSemanticMutationTotal.Load(), + UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(), + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go new file mode 100644 index 000000000..d147a0e18 --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go @@ -0,0 +1,869 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "strconv" + "strings" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/tidwall/gjson" +) + +type FrameConn interface { + ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) + WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error + Close() error +} + +type Usage struct { + InputTokens int + OutputTokens int + CacheCreationInputTokens int + CacheReadInputTokens int +} + +type 
RelayResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + FirstTokenMs *int + Duration time.Duration + ClientToUpstreamFrames int64 + UpstreamToClientFrames int64 + DroppedDownstreamFrames int64 +} + +type RelayTurnResult struct { + RequestModel string + Usage Usage + RequestID string + TerminalEventType string + Duration time.Duration + FirstTokenMs *int +} + +type RelayExit struct { + Stage string + Err error + WroteDownstream bool +} + +type RelayOptions struct { + WriteTimeout time.Duration + IdleTimeout time.Duration + UpstreamDrainTimeout time.Duration + FirstMessageType coderws.MessageType + InitialRequestModel string + DisableWriteTimeout bool + OnUsageParseFailure func(eventType string, usageRaw string) + OnTurnComplete func(turn RelayTurnResult) + OnTrace func(event RelayTraceEvent) + Now func() time.Time +} + +type RelayTraceEvent struct { + Stage string + Direction string + MessageType string + PayloadBytes int + Graceful bool + WroteDownstream bool + Error string +} + +type relayState struct { + usage Usage + requestModel string + lastResponseID string + terminalEventType string + firstTokenMs *int + currentTurnStart time.Time + currentTurnToken *int + turnTimingByID map[string]*relayTurnTiming +} + +type relayExitSignal struct { + stage string + err error + graceful bool + wroteDownstream bool +} + +type observedUpstreamEvent struct { + terminal bool + eventType string + responseID string + usage Usage + duration time.Duration + firstToken *int +} + +type relayTurnTiming struct { + startAt time.Time + firstTokenMs *int +} + +// ErrRelayIdleTimeout indicates relay inactivity timeout in passthrough mode. 
+var ErrRelayIdleTimeout = errors.New("openai ws v2 passthrough relay idle timeout") + +func Relay( + ctx context.Context, + clientConn FrameConn, + upstreamConn FrameConn, + firstClientMessage []byte, + options RelayOptions, +) (RelayResult, *RelayExit) { + initialRequestModel := strings.TrimSpace(options.InitialRequestModel) + if initialRequestModel == "" { + initialRequestModel = strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String()) + } + result := RelayResult{RequestModel: initialRequestModel} + if clientConn == nil || upstreamConn == nil { + return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")} + } + if ctx == nil { + ctx = context.Background() + } + + nowFn := options.Now + if nowFn == nil { + nowFn = time.Now + } + writeTimeout := options.WriteTimeout + if !options.DisableWriteTimeout && writeTimeout <= 0 { + writeTimeout = 2 * time.Minute + } + useWriteTimeout := !options.DisableWriteTimeout && writeTimeout > 0 + drainTimeout := options.UpstreamDrainTimeout + if drainTimeout <= 0 { + drainTimeout = 1200 * time.Millisecond + } + firstMessageType := options.FirstMessageType + if firstMessageType != coderws.MessageBinary { + firstMessageType = coderws.MessageText + } + startAt := nowFn() + state := &relayState{requestModel: result.RequestModel} + onTrace := options.OnTrace + + relayCtx, relayCancel := context.WithCancel(ctx) + defer relayCancel() + + lastActivity := atomic.Int64{} + lastActivity.Store(nowFn().UnixNano()) + markActivity := func() { + lastActivity.Store(nowFn().UnixNano()) + } + + writeUpstream := func(msgType coderws.MessageType, payload []byte) error { + if !useWriteTimeout { + return upstreamConn.WriteFrame(relayCtx, msgType, payload) + } + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return upstreamConn.WriteFrame(writeCtx, msgType, payload) + } + writeClient := func(msgType coderws.MessageType, payload []byte) error { + if !useWriteTimeout { + 
return clientConn.WriteFrame(relayCtx, msgType, payload) + } + writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout) + defer cancel() + return clientConn.WriteFrame(writeCtx, msgType, payload) + } + + clientToUpstreamFrames := &atomic.Int64{} + upstreamToClientFrames := &atomic.Int64{} + droppedDownstreamFrames := &atomic.Int64{} + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_start", + PayloadBytes: len(firstClientMessage), + MessageType: relayMessageTypeString(firstMessageType), + }) + + if err := writeUpstream(firstMessageType, firstClientMessage); err != nil { + result.Duration = nowFn().Sub(startAt) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_failed", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + Error: err.Error(), + }) + return result, &RelayExit{Stage: "write_upstream", Err: err} + } + clientToUpstreamFrames.Add(1) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_first_message_ok", + Direction: "client_to_upstream", + MessageType: relayMessageTypeString(firstMessageType), + PayloadBytes: len(firstClientMessage), + }) + markActivity() + + exitCh := make(chan relayExitSignal, 3) + dropDownstreamWrites := atomic.Bool{} + go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh) + go runUpstreamToClient( + relayCtx, + upstreamConn, + writeClient, + startAt, + nowFn, + state, + options.OnUsageParseFailure, + options.OnTurnComplete, + &dropDownstreamWrites, + upstreamToClientFrames, + droppedDownstreamFrames, + markActivity, + onTrace, + exitCh, + ) + go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh) + + firstExit := <-exitCh + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "first_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: firstExit.graceful, + WroteDownstream: firstExit.wroteDownstream, + 
Error: relayErrorString(firstExit.err), + }) + combinedWroteDownstream := firstExit.wroteDownstream + secondExit := relayExitSignal{graceful: true} + hasSecondExit := false + + // 客户端断开后尽力继续读取上游短窗口,捕获延迟 usage/terminal 事件用于计费。 + upstreamClosed := false + closeUpstream := func() { + if upstreamClosed { + return + } + upstreamClosed = true + _ = upstreamConn.Close() + } + if firstExit.stage == "read_client" && firstExit.graceful { + dropDownstreamWrites.Store(true) + secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout) + } else { + relayCancel() + closeUpstream() + secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond) + } + if hasSecondExit { + combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "second_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: secondExit.graceful, + WroteDownstream: secondExit.wroteDownstream, + Error: relayErrorString(secondExit.err), + }) + } + + relayCancel() + closeUpstream() + + enrichResult(&result, state, nowFn().Sub(startAt)) + result.ClientToUpstreamFrames = clientToUpstreamFrames.Load() + result.UpstreamToClientFrames = upstreamToClientFrames.Load() + result.DroppedDownstreamFrames = droppedDownstreamFrames.Load() + if firstExit.stage == "read_client" && firstExit.graceful { + stage := "client_disconnected" + exitErr := firstExit.err + if hasSecondExit && !secondExit.graceful { + stage = secondExit.stage + exitErr = secondExit.err + } + if exitErr == nil { + exitErr = io.EOF + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(exitErr), + }) + return result, &RelayExit{ + Stage: stage, + Err: exitErr, + WroteDownstream: combinedWroteDownstream, + } + } + if firstExit.graceful && (!hasSecondExit || secondExit.graceful) { + emitRelayTrace(onTrace, 
RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil + } + if !firstExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(firstExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(firstExit.err), + }) + return result, &RelayExit{ + Stage: firstExit.stage, + Err: firstExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + if hasSecondExit && !secondExit.graceful { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_exit", + Direction: relayDirectionFromStage(secondExit.stage), + Graceful: false, + WroteDownstream: combinedWroteDownstream, + Error: relayErrorString(secondExit.err), + }) + return result, &RelayExit{ + Stage: secondExit.stage, + Err: secondExit.err, + WroteDownstream: combinedWroteDownstream, + } + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "relay_complete", + Graceful: true, + WroteDownstream: combinedWroteDownstream, + }) + _ = clientConn.Close() + return result, nil +} + +func runClientToUpstream( + ctx context.Context, + clientConn FrameConn, + writeUpstream func(msgType coderws.MessageType, payload []byte) error, + markActivity func(), + forwardedFrames *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + for { + msgType, payload, err := clientConn.ReadFrame(ctx) + if err != nil { + graceful := isDisconnectError(err) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_client_failed", + Direction: "client_to_upstream", + Error: err.Error(), + Graceful: graceful, + }) + exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: graceful} + return + } + if err := writeUpstream(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_upstream_failed", + Direction: "client_to_upstream", + MessageType: 
relayMessageTypeString(msgType), + PayloadBytes: len(payload), + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_upstream", err: err} + return + } + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runUpstreamToClient( + ctx context.Context, + upstreamConn FrameConn, + writeClient func(msgType coderws.MessageType, payload []byte) error, + startAt time.Time, + nowFn func() time.Time, + state *relayState, + onUsageParseFailure func(eventType string, usageRaw string), + onTurnComplete func(turn RelayTurnResult), + dropDownstreamWrites *atomic.Bool, + forwardedFrames *atomic.Int64, + droppedFrames *atomic.Int64, + markActivity func(), + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + wroteDownstream := false + droppedSinceDisconnect := int64(0) + for { + msgType, payload, err := upstreamConn.ReadFrame(ctx) + if err != nil { + graceful := isDisconnectError(err) + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "read_upstream_failed", + Direction: "upstream_to_client", + Error: err.Error(), + Graceful: graceful, + WroteDownstream: wroteDownstream, + }) + exitCh <- relayExitSignal{ + stage: "read_upstream", + err: err, + graceful: graceful, + wroteDownstream: wroteDownstream, + } + return + } + observedEvent := observedUpstreamEvent{} + switch msgType { + case coderws.MessageText: + observedEvent = observeUpstreamMessage(state, payload, nowFn, onUsageParseFailure) + case coderws.MessageBinary: + // binary frame 直接透传,不进入 JSON 观测路径(避免无效解析开销)。 + } + emitTurnComplete(onTurnComplete, state, observedEvent) + if dropDownstreamWrites != nil && dropDownstreamWrites.Load() { + droppedSinceDisconnect++ + if droppedFrames != nil { + droppedFrames.Add(1) + } + if shouldTraceDroppedDownstreamFrame(droppedSinceDisconnect, observedEvent.terminal) { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "drop_downstream_frame", + Direction: "upstream_to_client", + MessageType: 
relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + }) + } + if observedEvent.terminal { + exitCh <- relayExitSignal{ + stage: "drain_terminal", + graceful: true, + wroteDownstream: wroteDownstream, + } + return + } + markActivity() + continue + } + if err := writeClient(msgType, payload); err != nil { + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "write_client_failed", + Direction: "upstream_to_client", + MessageType: relayMessageTypeString(msgType), + PayloadBytes: len(payload), + WroteDownstream: wroteDownstream, + Error: err.Error(), + }) + exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream} + return + } + wroteDownstream = true + if forwardedFrames != nil { + forwardedFrames.Add(1) + } + markActivity() + } +} + +func runIdleWatchdog( + ctx context.Context, + nowFn func() time.Time, + idleTimeout time.Duration, + lastActivity *atomic.Int64, + onTrace func(event RelayTraceEvent), + exitCh chan<- relayExitSignal, +) { + if idleTimeout <= 0 { + return + } + checkInterval := minDuration(idleTimeout/4, 5*time.Second) + if checkInterval < time.Second { + checkInterval = time.Second + } + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + last := time.Unix(0, lastActivity.Load()) + if nowFn().Sub(last) < idleTimeout { + continue + } + emitRelayTrace(onTrace, RelayTraceEvent{ + Stage: "idle_timeout_triggered", + Direction: "watchdog", + Error: ErrRelayIdleTimeout.Error(), + }) + exitCh <- relayExitSignal{stage: "idle_timeout", err: ErrRelayIdleTimeout} + return + } + } +} + +func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) { + if onTrace == nil { + return + } + onTrace(event) +} + +func relayMessageTypeString(msgType coderws.MessageType) string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + 
default: + return "unknown(" + strconv.Itoa(int(msgType)) + ")" + } +} + +func relayDirectionFromStage(stage string) string { + switch stage { + case "read_client", "write_upstream": + return "client_to_upstream" + case "read_upstream", "write_client", "drain_terminal": + return "upstream_to_client" + case "idle_timeout": + return "watchdog" + default: + return "" + } +} + +func relayErrorString(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func observeUpstreamMessage( + state *relayState, + message []byte, + nowFn func() time.Time, + onUsageParseFailure func(eventType string, usageRaw string), +) observedUpstreamEvent { + if state == nil || len(message) == 0 { + return observedUpstreamEvent{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id") + eventType := strings.TrimSpace(values[0].String()) + if eventType == "" { + return observedUpstreamEvent{} + } + responseID := strings.TrimSpace(values[1].String()) + if responseID == "" { + responseID = strings.TrimSpace(values[2].String()) + } + // 仅 terminal 事件兜底读取顶层 id,避免把 event_id 当成 response_id 关联到 turn。 + if responseID == "" && isTerminalEvent(eventType) { + responseID = strings.TrimSpace(values[3].String()) + } + now := nowFn() + if state.currentTurnStart.IsZero() { + state.currentTurnStart = now + state.currentTurnToken = nil + } + + if state.currentTurnToken == nil && isTokenEvent(eventType) { + ms := int(now.Sub(state.currentTurnStart).Milliseconds()) + if ms >= 0 { + state.currentTurnToken = &ms + } + } + parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure) + observed := observedUpstreamEvent{ + eventType: eventType, + responseID: responseID, + usage: parsedUsage, + } + if responseID != "" { + turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now) + if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) { + ms := int(now.Sub(turnTiming.startAt).Milliseconds()) + if 
ms >= 0 { + turnTiming.firstTokenMs = &ms + } + } + } + if !isTerminalEvent(eventType) { + return observed + } + observed.terminal = true + state.terminalEventType = eventType + if responseID != "" { + state.lastResponseID = responseID + if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok { + duration := now.Sub(turnTiming.startAt) + if duration < 0 { + duration = 0 + } + observed.duration = duration + observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs) + } + } + if observed.firstToken == nil { + observed.firstToken = openAIWSRelayCloneIntPtr(state.currentTurnToken) + } + if observed.duration <= 0 && !state.currentTurnStart.IsZero() { + duration := now.Sub(state.currentTurnStart) + if duration < 0 { + duration = 0 + } + observed.duration = duration + } + state.firstTokenMs = openAIWSRelayCloneIntPtr(observed.firstToken) + state.currentTurnStart = time.Time{} + state.currentTurnToken = nil + return observed +} + +func emitTurnComplete( + onTurnComplete func(turn RelayTurnResult), + state *relayState, + observed observedUpstreamEvent, +) { + if onTurnComplete == nil || !observed.terminal { + return + } + responseID := strings.TrimSpace(observed.responseID) + if responseID == "" { + return + } + requestModel := "" + if state != nil { + requestModel = state.requestModel + } + onTurnComplete(RelayTurnResult{ + RequestModel: requestModel, + Usage: observed.usage, + RequestID: responseID, + TerminalEventType: observed.eventType, + Duration: observed.duration, + FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken), + }) +} + +func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming { + if state == nil { + return nil + } + if state.turnTimingByID == nil { + state.turnTimingByID = make(map[string]*relayTurnTiming, 8) + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil || timing.startAt.IsZero() { + startAt := now + if !state.currentTurnStart.IsZero() 
{ + startAt = state.currentTurnStart + } + timing = &relayTurnTiming{startAt: startAt} + state.turnTimingByID[responseID] = timing + return timing + } + return timing +} + +func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) { + if state == nil || state.turnTimingByID == nil { + return relayTurnTiming{}, false + } + timing, ok := state.turnTimingByID[responseID] + if !ok || timing == nil { + return relayTurnTiming{}, false + } + delete(state.turnTimingByID, responseID) + return *timing, true +} + +func openAIWSRelayCloneIntPtr(v *int) *int { + if v == nil { + return nil + } + cloned := *v + return &cloned +} + +func parseUsageAndAccumulate( + state *relayState, + message []byte, + eventType string, + onParseFailure func(eventType string, usageRaw string), +) Usage { + if state == nil || len(message) == 0 || !shouldParseUsage(eventType) { + return Usage{} + } + usageResult := gjson.GetBytes(message, "response.usage") + if !usageResult.Exists() { + return Usage{} + } + usageRaw := strings.TrimSpace(usageResult.Raw) + if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + return Usage{} + } + + inputResult := usageResult.Get("input_tokens") + outputResult := usageResult.Get("output_tokens") + cachedResult := usageResult.Get("input_tokens_details.cached_tokens") + + inputTokens, inputOK := parseUsageIntField(inputResult, true) + outputTokens, outputOK := parseUsageIntField(outputResult, true) + cachedTokens, cachedOK := parseUsageIntField(cachedResult, false) + if !inputOK || !outputOK || !cachedOK { + recordUsageParseFailure() + if onParseFailure != nil { + onParseFailure(eventType, usageRaw) + } + // 解析失败时不做部分字段累加,避免计费 usage 出现“半有效”状态。 + return Usage{} + } + parsedUsage := Usage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + CacheReadInputTokens: cachedTokens, + } + + state.usage.InputTokens += 
parsedUsage.InputTokens + state.usage.OutputTokens += parsedUsage.OutputTokens + state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens + return parsedUsage +} + +func parseUsageIntField(value gjson.Result, required bool) (int, bool) { + if !value.Exists() { + return 0, !required + } + if value.Type != gjson.Number { + return 0, false + } + return int(value.Int()), true +} + +func enrichResult(result *RelayResult, state *relayState, duration time.Duration) { + if result == nil { + return + } + result.Duration = duration + if state == nil { + return + } + result.RequestModel = state.requestModel + result.Usage = state.usage + result.RequestID = state.lastResponseID + result.TerminalEventType = state.terminalEventType + result.FirstTokenMs = state.firstTokenMs +} + +func isDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func isTerminalEvent(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func shouldParseUsage(eventType string) bool { + switch eventType { + case "response.completed", "response.done", "response.failed": + return true + default: + return false + } +} + +func 
isTokenEvent(eventType string) bool { + if eventType == "" { + return false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return false +} + +func minDuration(a, b time.Duration) time.Duration { + if a <= 0 { + return b + } + if b <= 0 { + return a + } + if a < b { + return a + } + return b +} + +func shouldTraceDroppedDownstreamFrame(dropCount int64, terminal bool) bool { + if terminal { + return true + } + if dropCount <= 3 { + return true + } + return dropCount%128 == 0 +} + +func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) { + if timeout <= 0 { + timeout = 200 * time.Millisecond + } + timer := time.NewTimer(timeout) + defer timer.Stop() + select { + case sig := <-exitCh: + return sig, true + case <-timer.C: + return relayExitSignal{}, false + } +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go new file mode 100644 index 000000000..d0b38442e --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go @@ -0,0 +1,504 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "net" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestRunEntry_DelegatesRelay(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: 
[]byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + result, relayExit := RunEntry(EntryInput{ + Ctx: context.Background(), + ClientConn: clientConn, + UpstreamConn: upstreamConn, + FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`), + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_entry", result.RequestID) +} + +func TestRunClientToUpstream_ErrorPaths(t *testing.T) { + t.Parallel() + + t.Run("read client eof", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_client", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write upstream failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") }, + func() {}, + nil, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_upstream", sig.stage) + require.False(t, sig.graceful) + }) + + t.Run("forwarded counter and trace callback", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + forwarded := &atomic.Int64{} + traces := make([]RelayTraceEvent, 0, 2) + runClientToUpstream( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"x":1}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + func() {}, + forwarded, + func(event RelayTraceEvent) { + traces = append(traces, event) + }, + exitCh, + ) + sig := <-exitCh + 
require.Equal(t, "read_client", sig.stage) + require.Equal(t, int64(1), forwarded.Load()) + require.NotEmpty(t, traces) + }) +} + +func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) { + t.Parallel() + + t.Run("read upstream eof", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn(nil, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "read_upstream", sig.stage) + require.True(t, sig.graceful) + }) + + t.Run("write client failed", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(false) + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + {msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)}, + }, true), + func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + nil, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "write_client", sig.stage) + }) + + t.Run("drop downstream and stop on terminal", func(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + drop := &atomic.Bool{} + drop.Store(true) + dropped := &atomic.Int64{} + runUpstreamToClient( + context.Background(), + newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true), + func(_ coderws.MessageType, _ []byte) error { return nil }, + time.Now(), + time.Now, + &relayState{}, + nil, + nil, + drop, + nil, + 
dropped, + func() {}, + nil, + exitCh, + ) + sig := <-exitCh + require.Equal(t, "drain_terminal", sig.stage) + require.True(t, sig.graceful) + require.Equal(t, int64(1), dropped.Load()) + }) +} + +func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) { + t.Parallel() + + exitCh := make(chan relayExitSignal, 1) + lastActivity := &atomic.Int64{} + lastActivity.Store(time.Now().UnixNano()) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh) + select { + case <-exitCh: + t.Fatal("unexpected idle timeout signal") + case <-time.After(200 * time.Millisecond): + } +} + +func TestHelperFunctionsCoverage(t *testing.T) { + t.Parallel() + + require.Equal(t, "text", relayMessageTypeString(coderws.MessageText)) + require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary)) + require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(") + + require.Equal(t, "", relayErrorString(nil)) + require.Equal(t, "x", relayErrorString(errors.New("x"))) + + require.True(t, isDisconnectError(io.EOF)) + require.True(t, isDisconnectError(net.ErrClosed)) + require.True(t, isDisconnectError(context.Canceled)) + require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway})) + require.True(t, isDisconnectError(errors.New("broken pipe"))) + require.False(t, isDisconnectError(errors.New("unrelated"))) + + require.True(t, isTokenEvent("response.output_text.delta")) + require.True(t, isTokenEvent("response.output_audio.delta")) + require.False(t, isTokenEvent("response.completed")) + require.False(t, isTokenEvent("")) + require.False(t, isTokenEvent("response.created")) + + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second)) + require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second)) + require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0)) + 
+ ch := make(chan relayExitSignal, 1) + ch <- relayExitSignal{stage: "ok"} + sig, ok := waitRelayExit(ch, 10*time.Millisecond) + require.True(t, ok) + require.Equal(t, "ok", sig.stage) + ch <- relayExitSignal{stage: "ok2"} + sig, ok = waitRelayExit(ch, 0) + require.True(t, ok) + require.Equal(t, "ok2", sig.stage) + _, ok = waitRelayExit(ch, 10*time.Millisecond) + require.False(t, ok) + + n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true) + require.True(t, ok) + require.Equal(t, 3, n) + _, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true) + require.False(t, ok) + n, ok = parseUsageIntField(gjson.Result{}, false) + require.True(t, ok) + require.Equal(t, 0, n) + _, ok = parseUsageIntField(gjson.Result{}, true) + require.False(t, ok) +} + +func TestParseUsageAndEnrichCoverage(t *testing.T) { + t.Parallel() + + state := &relayState{} + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil) + require.Equal(t, 0, state.usage.InputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate( + state, + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`), + "response.completed", + nil, + ) + require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage") + require.Equal(t, 0, state.usage.OutputTokens) + require.Equal(t, 0, state.usage.CacheReadInputTokens) + + parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), 
"response.completed", nil) + require.Equal(t, 2, state.usage.InputTokens) + require.Equal(t, 1, state.usage.OutputTokens) + require.Equal(t, 1, state.usage.CacheReadInputTokens) + + result := &RelayResult{} + enrichResult(result, state, 5*time.Millisecond) + require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens) + require.Equal(t, 5*time.Millisecond, result.Duration) + parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil) + require.Equal(t, 2, state.usage.InputTokens) + enrichResult(nil, state, 0) +} + +func TestEmitTurnCompleteCoverage(t *testing.T) { + t.Parallel() + + // 非 terminal 事件不应触发。 + called := 0 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: false, + eventType: "response.output_text.delta", + responseID: "resp_ignored", + usage: Usage{InputTokens: 1}, + }) + require.Equal(t, 0, called) + + // 缺少 response_id 时不应触发。 + emitTurnComplete(func(turn RelayTurnResult) { + called++ + }, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + }) + require.Equal(t, 0, called) + + // terminal 且 response_id 存在,应该触发;state=nil 时 model 为空串。 + var got RelayTurnResult + emitTurnComplete(func(turn RelayTurnResult) { + called++ + got = turn + }, nil, observedUpstreamEvent{ + terminal: true, + eventType: "response.completed", + responseID: "resp_emit", + usage: Usage{InputTokens: 2, OutputTokens: 3}, + }) + require.Equal(t, 1, called) + require.Equal(t, "resp_emit", got.RequestID) + require.Equal(t, "response.completed", got.TerminalEventType) + require.Equal(t, 2, got.Usage.InputTokens) + require.Equal(t, 3, got.Usage.OutputTokens) + require.Equal(t, "", got.RequestModel) +} + +func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) { + t.Parallel() + + require.True(t, 
// TestRelayTurnTimingHelpersCoverage exercises the per-response turn-timing
// helpers: nil-state safety, lazy initialization, idempotent re-lookup, and
// delete semantics for both existing and missing keys.
func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
	t.Parallel()

	now := time.Unix(100, 0)
	// nil state: both helpers must be safe no-ops.
	require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
	_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
	require.False(t, ok)

	// First lookup lazily creates an entry stamped with the provided time.
	state := &relayState{}
	timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
	require.NotNil(t, timing)
	require.Equal(t, now, timing.startAt)

	// A second lookup returns the same entry; startAt is not re-stamped even
	// though a later "now" is passed.
	timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
	require.NotNil(t, timing2)
	require.Equal(t, now, timing2.startAt)

	// Deleting an existing key returns the stored entry by value.
	deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
	require.True(t, ok)
	require.Equal(t, now, deleted.startAt)

	// Deleting a missing key reports false.
	_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
	require.False(t, ok)
}
:= observeUpstreamMessage( + state, + []byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`), + nowFn, + nil, + ) + require.False(t, observed.terminal) + require.Equal(t, "", observed.responseID) + + // terminal:允许兜底用顶层 id(用于兼容少数字段变体)。 + observed = observeUpstreamMessage( + state, + []byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`), + nowFn, + nil, + ) + require.True(t, observed.terminal) + require.Equal(t, "resp_fallback", observed.responseID) +} + +type writeCtxCaptureFrameConn struct { + writeHasDeadline atomic.Bool +} + +func (c *writeCtxCaptureFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *writeCtxCaptureFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error { + if ctx == nil { + ctx = context.Background() + } + _, hasDeadline := ctx.Deadline() + c.writeHasDeadline.Store(hasDeadline) + return nil +} + +func (c *writeCtxCaptureFrameConn) Close() error { + return nil +} + +func TestRelay_WriteTimeoutDeadlineToggle(t *testing.T) { + t.Parallel() + + run := func(disable bool) bool { + clientConn := newPassthroughTestFrameConn(nil, true) + upstreamConn := &writeCtxCaptureFrameConn{} + _, _ = Relay(context.Background(), clientConn, upstreamConn, []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`), RelayOptions{ + WriteTimeout: 30 * time.Second, + DisableWriteTimeout: disable, + UpstreamDrainTimeout: 20 * time.Millisecond, + }) + return upstreamConn.writeHasDeadline.Load() + } + + require.False(t, run(true)) + require.True(t, run(false)) +} + +func TestRelay_InitialRequestModelOverride(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + 
payload: []byte(`{"type":"response.completed","response":{"id":"resp_override","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + result, relayExit := Relay(ctx, clientConn, upstreamConn, []byte(`{"type":"response.create","input":[]}`), RelayOptions{ + InitialRequestModel: "gpt-5.3-codex", + }) + require.Nil(t, relayExit) + require.Equal(t, "gpt-5.3-codex", result.RequestModel) +} + +func TestShouldTraceDroppedDownstreamFrame(t *testing.T) { + t.Parallel() + + require.True(t, shouldTraceDroppedDownstreamFrame(1, false)) + require.True(t, shouldTraceDroppedDownstreamFrame(3, false)) + require.False(t, shouldTraceDroppedDownstreamFrame(4, false)) + require.True(t, shouldTraceDroppedDownstreamFrame(128, false)) + require.True(t, shouldTraceDroppedDownstreamFrame(999, true)) +} diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go new file mode 100644 index 000000000..6503b3d0f --- /dev/null +++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go @@ -0,0 +1,804 @@ +package openai_ws_v2 + +import ( + "context" + "errors" + "io" + "sync" + "sync/atomic" + "testing" + "time" + + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +type passthroughTestFrame struct { + msgType coderws.MessageType + payload []byte +} + +type passthroughTestFrameConn struct { + mu sync.Mutex + writes []passthroughTestFrame + readCh chan passthroughTestFrame + once sync.Once +} + +type delayedReadFrameConn struct { + base FrameConn + firstDelay time.Duration + once sync.Once +} + +type closeSpyFrameConn struct { + closeCalls atomic.Int32 +} + +func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn { + c := &passthroughTestFrameConn{ + readCh: make(chan passthroughTestFrame, len(frames)+1), + } + for _, 
frame := range frames { + copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)} + c.readCh <- copied + } + if autoClose { + close(c.readCh) + } + return c +} + +func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return coderws.MessageText, nil, ctx.Err() + case frame, ok := <-c.readCh: + if !ok { + return coderws.MessageText, nil, io.EOF + } + return frame.msgType, append([]byte(nil), frame.payload...), nil + } +} + +func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)}) + return nil +} + +func (c *passthroughTestFrameConn) Close() error { + c.once.Do(func() { + defer func() { _ = recover() }() + close(c.readCh) + }) + return nil +} + +func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]passthroughTestFrame, len(c.writes)) + copy(out, c.writes) + return out +} + +func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.base == nil { + return coderws.MessageText, nil, io.EOF + } + c.once.Do(func() { + if c.firstDelay > 0 { + timer := time.NewTimer(c.firstDelay) + defer timer.Stop() + select { + case <-ctx.Done(): + case <-timer.C: + } + } + }) + return c.base.ReadFrame(ctx) +} + +func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.base == nil { + return io.EOF + } + return c.base.WriteFrame(ctx, msgType, payload) +} + +func (c 
*delayedReadFrameConn) Close() error { + if c == nil || c.base == nil { + return nil + } + return c.base.Close() +} + +func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if ctx == nil { + ctx = context.Background() + } + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func (c *closeSpyFrameConn) Close() error { + if c != nil { + c.closeCalls.Add(1) + } + return nil +} + +func (c *closeSpyFrameConn) CloseCalls() int32 { + if c == nil { + return 0 + } + return c.closeCalls.Load() +} + +func TestRelay_BasicRelayAndUsage(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "gpt-5.3-codex", result.RequestModel) + require.Equal(t, "resp_123", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 7, result.Usage.InputTokens) + require.Equal(t, 3, result.Usage.OutputTokens) + require.Equal(t, 2, result.Usage.CacheReadInputTokens) + require.Nil(t, result.FirstTokenMs) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(1), 
result.UpstreamToClientFrames) + require.Equal(t, int64(0), result.DroppedDownstreamFrames) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload)) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload)) +} + +func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UpstreamDisconnect(t *testing.T) { + t.Parallel() + + // 上游立即关闭(EOF),客户端不发送额外帧 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) // 立即 close -> EOF + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := 
// TestRelay_ClientDisconnect verifies that when the client side hits EOF
// immediately while the upstream never produces a frame, Relay reports an
// observable exit with stage "client_disconnected" and still records the
// request model parsed from the first payload.
func TestRelay_ClientDisconnect(t *testing.T) {
	t.Parallel()

	// Client closes immediately (EOF); upstream blocks on read until the
	// context is cancelled.
	clientConn := newPassthroughTestFrameConn(nil, true) // immediate close -> EOF
	upstreamConn := newPassthroughTestFrameConn(nil, false)

	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
	require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
	require.Equal(t, "client_disconnected", relayExit.Stage)
	require.Equal(t, "gpt-4o", result.RequestModel)
}
require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 6, result.Usage.InputTokens) + require.Equal(t, 4, result.Usage.OutputTokens) + require.Equal(t, 1, result.Usage.CacheReadInputTokens) + require.Equal(t, int64(1), result.ClientToUpstreamFrames) + require.Equal(t, int64(0), result.UpstreamToClientFrames) + require.Equal(t, int64(1), result.DroppedDownstreamFrames) +} + +func TestRelay_IdleTimeout(t *testing.T) { + t.Parallel() + + // 客户端和上游都不发送帧,idle timeout 应触发 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // 使用快进时间来加速 idle timeout + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + // 前几次调用返回正常时间(初始化阶段),之后快进 + if callCount <= 5 { + return now + } + return now.Add(time.Hour) // 快进到超时 + } + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + }) + require.NotNil(t, relayExit, "应因 idle timeout 退出") + require.Equal(t, "idle_timeout", relayExit.Stage) + require.ErrorIs(t, relayExit.Err, ErrRelayIdleTimeout) + require.NotErrorIs(t, relayExit.Err, context.DeadlineExceeded) + require.Equal(t, "gpt-4o", result.RequestModel) +} + +func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) { + t.Parallel() + + clientConn := &closeSpyFrameConn{} + upstreamConn := &closeSpyFrameConn{} + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + callCount := 0 + nowFn := func() time.Time { + callCount++ + if callCount <= 5 { + return now + } + return now.Add(time.Hour) + } + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, 
// TestRelay_NilConnections verifies that Relay fails fast during
// initialization when either side of the bridge is nil, surfacing stage
// "relay_init" and an error message that mentions the nil connection.
func TestRelay_NilConnections(t *testing.T) {
	t.Parallel()

	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
	ctx := context.Background()

	t.Run("nil client conn", func(t *testing.T) {
		upstreamConn := newPassthroughTestFrameConn(nil, true)
		_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
		require.NotNil(t, relayExit)
		require.Equal(t, "relay_init", relayExit.Stage)
		require.Contains(t, relayExit.Err.Error(), "nil")
	})

	t.Run("nil upstream conn", func(t *testing.T) {
		clientConn := newPassthroughTestFrameConn(nil, true)
		_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
		require.NotNil(t, relayExit)
		require.Equal(t, "relay_init", relayExit.Stage)
		require.Contains(t, relayExit.Err.Error(), "nil")
	})
}
[]byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, "resp_multi", result.RequestID) + require.Equal(t, "response.completed", result.TerminalEventType) + require.Equal(t, 10, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.NotNil(t, result.FirstTokenMs) + + // 验证所有 3 个上游帧都转发给了客户端 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 3) +} + +func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + turns := make([]RelayTurnResult, 0, 2) + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTurnComplete: func(turn RelayTurnResult) { + turns = append(turns, turn) + }, + }) + require.Nil(t, relayExit) + require.Len(t, turns, 2) + require.Equal(t, "resp_turn_1", turns[0].RequestID) + require.Equal(t, "response.completed", turns[0].TerminalEventType) + require.Equal(t, 2, turns[0].Usage.InputTokens) + require.Equal(t, 1, turns[0].Usage.OutputTokens) + require.Equal(t, "resp_turn_2", 
turns[1].RequestID) + require.Equal(t, "response.failed", turns[1].TerminalEventType) + require.Equal(t, 3, turns[1].Usage.InputTokens) + require.Equal(t, 4, turns[1].Usage.OutputTokens) + require.Equal(t, 5, result.Usage.InputTokens) + require.Equal(t, 5, result.Usage.OutputTokens) +} + +func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 5 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_metric", turn.RequestID) + require.Equal(t, "response.completed", turn.TerminalEventType) + require.NotNil(t, turn.FirstTokenMs) + require.GreaterOrEqual(t, *turn.FirstTokenMs, 0) + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + require.NotNil(t, result.FirstTokenMs) + require.Greater(t, result.Duration.Milliseconds(), int64(0)) +} + +func TestRelay_OnTurnComplete_FirstTokenFromTurnFallbackWhenDeltaHasNoResponseID(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := 
newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.created"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"hello"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_fallback_ttft","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + base := time.Unix(0, 0) + var nowTick atomic.Int64 + nowFn := func() time.Time { + step := nowTick.Add(1) + return base.Add(time.Duration(step) * 5 * time.Millisecond) + } + + var turn RelayTurnResult + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + Now: nowFn, + OnTurnComplete: func(current RelayTurnResult) { + turn = current + }, + }) + require.Nil(t, relayExit) + require.Equal(t, "resp_fallback_ttft", turn.RequestID) + require.NotNil(t, turn.FirstTokenMs) + require.NotNil(t, result.FirstTokenMs) + require.GreaterOrEqual(t, *turn.FirstTokenMs, 0) + require.Equal(t, *turn.FirstTokenMs, *result.FirstTokenMs) + require.Greater(t, turn.Duration.Milliseconds(), int64(0)) + require.Greater(t, turn.Duration.Milliseconds(), int64(*turn.FirstTokenMs)) +} + +func TestRelay_BinaryFramePassthrough(t *testing.T) { + t.Parallel() + + // 验证 binary frame 被透传但不进行 usage 解析 + binaryPayload := []byte{0x00, 0x01, 0x02, 0x03} + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: binaryPayload, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit 
:= Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // binary frame 不解析 usage + require.Equal(t, 0, result.Usage.InputTokens) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) + require.Equal(t, binaryPayload, clientWrites[0].payload) +} + +func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageBinary, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "", result.RequestID) + require.Equal(t, "", result.TerminalEventType) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType) +} + +func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: errorEvent, + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, 
RelayOptions{}) + require.Nil(t, relayExit) + + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.Equal(t, coderws.MessageText, clientWrites[0].msgType) + require.Equal(t, errorEvent, clientWrites[0].payload) +} + +func TestRelay_PreservesFirstMessageType(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + FirstMessageType: coderws.MessageBinary, + }) + require.Nil(t, relayExit) + + upstreamWrites := upstreamConn.Writes() + require.Len(t, upstreamWrites, 1) + require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType) + require.Equal(t, firstPayload, upstreamWrites[0].payload) +} + +func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) { + baseline := SnapshotMetrics().UsageParseFailureTotal + + // 上游发送无效 JSON(非 usage 格式),不应影响透传 + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + require.Nil(t, relayExit) + // usage 解析失败,值为 0 但不影响透传 + require.Equal(t, 0, result.Usage.InputTokens) + require.Equal(t, "response.completed", result.TerminalEventType) + + // 帧仍然被转发 + clientWrites := clientConn.Writes() + require.Len(t, clientWrites, 1) + require.GreaterOrEqual(t, 
SnapshotMetrics().UsageParseFailureTotal, baseline+1) +} + +func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) { + t.Parallel() + + // 上游连接立即关闭,首包写入失败 + upstreamConn := newPassthroughTestFrameConn(nil, true) + _ = upstreamConn.Close() + + // 覆盖 WriteFrame 使其返回错误 + errConn := &errorOnWriteFrameConn{} + clientConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + _, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{}) + require.NotNil(t, relayExit) + require.Equal(t, "write_upstream", relayExit.Stage) +} + +func TestRelay_ContextCanceled(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + + // 立即取消 context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{}) + // context 取消导致写首包失败 + require.NotNil(t, relayExit) +} + +func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, true) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = 
append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.Nil(t, relayExit) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) + stagesMu.Unlock() + require.Contains(t, capturedStages, "relay_start") + require.Contains(t, capturedStages, "write_first_message_ok") + require.Contains(t, capturedStages, "first_exit") + require.Contains(t, capturedStages, "relay_complete") +} + +func TestRelay_TraceEvents_IdleTimeout(t *testing.T) { + t.Parallel() + + clientConn := newPassthroughTestFrameConn(nil, false) + upstreamConn := newPassthroughTestFrameConn(nil, false) + + firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + now := time.Now() + var callCount atomic.Int64 + nowFn := func() time.Time { + step := callCount.Add(1) + if step <= 5 { + return now + } + return now.Add(time.Hour) + } + + stages := make([]string, 0, 8) + var stagesMu sync.Mutex + _, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{ + IdleTimeout: 2 * time.Second, + Now: nowFn, + OnTrace: func(event RelayTraceEvent) { + stagesMu.Lock() + stages = append(stages, event.Stage) + stagesMu.Unlock() + }, + }) + require.NotNil(t, relayExit) + require.Equal(t, "idle_timeout", relayExit.Stage) + require.ErrorIs(t, relayExit.Err, ErrRelayIdleTimeout) + stagesMu.Lock() + capturedStages := append([]string(nil), stages...) 
+ stagesMu.Unlock() + require.Contains(t, capturedStages, "idle_timeout_triggered") + require.Contains(t, capturedStages, "relay_exit") +} + +// errorOnWriteFrameConn 是一个写入总是失败的 FrameConn 实现,用于测试首包写入失败。 +type errorOnWriteFrameConn struct{} + +func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} + +func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error { + return errors.New("write failed: connection refused") +} + +func (c *errorOnWriteFrameConn) Close() error { + return nil +} diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go new file mode 100644 index 000000000..828c3692f --- /dev/null +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go @@ -0,0 +1,397 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" +) + +type openAIWSClientFrameConn struct { + conn *coderws.Conn +} + +const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2" + +var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil) + +func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c == nil || c.conn == nil { + return coderws.MessageText, nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Read(ctx) +} + +func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() 
+ } + return c.conn.Write(ctx, msgType, payload) +} + +func (c *openAIWSClientFrameConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} + +func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessageType coderws.MessageType, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, + wsDecision OpenAIWSProtocolDecision, +) error { + if s == nil { + return errors.New("service is nil") + } + if ctx == nil { + ctx = context.Background() + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + requestModel, requestPreviousResponseID, _ := ResolveOpenAIWSFirstMessageMeta(c, firstClientMessage) + logOpenAIWSV2Passthrough( + "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d", + account.ID, + truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen), + openaiwsv2RelayMessageTypeName(firstClientMessageType), + len(firstClientMessage), + ) + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost, wsPath := openAIWSHostPathForLogFromURL(wsURL) + logOpenAIWSV2Passthrough( + "relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + + isCodexCLI := false + if c != nil { + isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) + } + if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + isCodexCLI = true + } + headers, _ := s.buildOpenAIWSHeaders(c, 
account, token, wsDecision, isCodexCLI, "", "", "") + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + dialer := s.getOpenAIWSPassthroughDialer() + if dialer == nil { + return errors.New("openai ws passthrough dialer is nil") + } + + dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout()) + defer cancelDial() + upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL) + if err != nil { + logOpenAIWSV2Passthrough( + "relay_dial_failed account_id=%d status_code=%d err=%s", + account.ID, + statusCode, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders) + } + defer func() { + _ = upstreamConn.Close() + }() + logOpenAIWSV2Passthrough( + "relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s", + account.ID, + statusCode, + openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"), + ) + + upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn) + if !ok { + return errors.New("openai ws passthrough upstream connection does not support frame relay") + } + + completedTurns := atomic.Int32{} + var onTrace func(event openaiwsv2.RelayTraceEvent) + if isOpenAIWSModeDebugEnabled() { + onTrace = func(event openaiwsv2.RelayTraceEvent) { + logOpenAIWSV2PassthroughDebug( + "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s", + account.ID, + truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen), + event.PayloadBytes, + event.Graceful, + event.WroteDownstream, + truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen), + ) + } + } + relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{ + Ctx: ctx, + ClientConn: 
&openAIWSClientFrameConn{conn: clientConn}, + UpstreamConn: upstreamFrameConn, + FirstClientMessage: firstClientMessage, + Options: openaiwsv2.RelayOptions{ + WriteTimeout: s.openAIWSWriteTimeout(), + IdleTimeout: s.openAIWSPassthroughIdleTimeout(), + FirstMessageType: firstClientMessageType, + InitialRequestModel: requestModel, + // passthrough 链路走大帧高频转发,避免每帧创建超时 context/timer。 + // 由 relayCtx 取消 + idle watchdog 兜底释放。 + DisableWriteTimeout: true, + OnUsageParseFailure: func(eventType string, usageRaw string) { + logOpenAIWSV2Passthrough( + "usage_parse_failed event_type=%s usage_raw=%s", + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen), + ) + }, + OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) { + turnNo := int(completedTurns.Add(1)) + turnResult := &OpenAIForwardResult{ + RequestID: turn.RequestID, + Usage: OpenAIUsage{ + InputTokens: turn.Usage.InputTokens, + OutputTokens: turn.Usage.OutputTokens, + CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens, + CacheReadInputTokens: turn.Usage.CacheReadInputTokens, + }, + Model: turn.RequestModel, + Stream: true, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModePassthrough, + Duration: turn.Duration, + FirstTokenMs: turn.FirstTokenMs, + TerminalEventType: turn.TerminalEventType, + } + logOpenAIWSV2Passthrough( + "relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d", + account.ID, + turnNo, + truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(turnResult.TerminalEventType, openAIWSLogValueMaxLen), + turnResult.Duration.Milliseconds(), + openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs), + turnResult.Usage.InputTokens, + turnResult.Usage.OutputTokens, + turnResult.Usage.CacheReadInputTokens, + ) + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnNo, 
turnResult, nil) + } + }, + OnTrace: onTrace, + }, + }) + + result := &OpenAIForwardResult{ + RequestID: relayResult.RequestID, + Usage: OpenAIUsage{ + InputTokens: relayResult.Usage.InputTokens, + OutputTokens: relayResult.Usage.OutputTokens, + CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens, + CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens, + }, + Model: relayResult.RequestModel, + Stream: true, + OpenAIWSMode: true, + WSIngressMode: OpenAIWSIngressModePassthrough, + Duration: relayResult.Duration, + FirstTokenMs: relayResult.FirstTokenMs, + TerminalEventType: relayResult.TerminalEventType, + } + + turnCount := int(completedTurns.Load()) + if relayExit == nil { + logOpenAIWSV2Passthrough( + "relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(result.TerminalEventType, openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + // 正常路径按 terminal 事件逐 turn 已回调;仅在零 turn 场景兜底回调一次。 + if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(1, result, nil) + } + return nil + } + logOpenAIWSV2Passthrough( + "relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d", + account.ID, + truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen), + relayExit.WroteDownstream, + truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen), + result.Duration.Milliseconds(), + relayResult.ClientToUpstreamFrames, + relayResult.UpstreamToClientFrames, + relayResult.DroppedDownstreamFrames, + turnCount, + ) + + relayErr := relayExit.Err + if relayExit.Stage == "idle_timeout" { + relayErr = 
NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "client websocket idle timeout", + relayErr, + ) + } + turnErr := error(nil) + if turnCount == 0 { + turnErr = wrapOpenAIWSIngressTurnErrorWithPartial( + relayExit.Stage, + relayErr, + relayExit.WroteDownstream, + result, + ) + } else { + // 已按 turn 回调 usage 时,错误路径不再附带 partial,避免重复记账同一 request_id。 + turnErr = wrapOpenAIWSIngressTurnError( + relayExit.Stage, + relayErr, + relayExit.WroteDownstream, + ) + } + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turnCount+1, nil, turnErr) + } + return turnErr +} + +func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError( + err error, + statusCode int, + handshakeHeaders http.Header, +) error { + if err == nil { + return nil + } + wrappedErr := err + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) { + wrappedErr = &openAIWSDialError{ + StatusCode: statusCode, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + + if errors.Is(err, context.Canceled) { + return err + } + if errors.Is(err, context.DeadlineExceeded) { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket connect timeout", + wrappedErr, + ) + } + if statusCode == http.StatusTooManyRequests { + return NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + wrappedErr, + ) + } + if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket authentication failed", + wrappedErr, + ) + } + if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError { + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream websocket handshake rejected", + wrappedErr, + ) + } + return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr) +} + +func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) 
string { + switch msgType { + case coderws.MessageText: + return "text" + case coderws.MessageBinary: + return "binary" + default: + return fmt.Sprintf("unknown(%d)", msgType) + } +} + +func relayErrorText(err error) string { + if err == nil { + return "" + } + return err.Error() +} + +func openAIWSFirstTokenMsForLog(firstTokenMs *int) int { + if firstTokenMs == nil { + return -1 + } + return *firstTokenMs +} + +func logOpenAIWSV2Passthrough(format string, args ...any) { + logger.LegacyPrintf( + "service.openai_ws_v2", + "[OpenAI WS v2 passthrough] "+openaiWSV2PassthroughModeFields+" "+format, + args..., + ) +} + +func logOpenAIWSV2PassthroughDebug(format string, args ...any) { + if !isOpenAIWSModeDebugEnabled() { + return + } + logger.LegacyPrintf( + "service.openai_ws_v2", + "[debug] [OpenAI WS v2 passthrough] "+openaiWSV2PassthroughModeFields+" "+format, + args..., + ) +} diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter_test.go b/backend/internal/service/openai_ws_v2_passthrough_adapter_test.go new file mode 100644 index 000000000..3557356c0 --- /dev/null +++ b/backend/internal/service/openai_ws_v2_passthrough_adapter_test.go @@ -0,0 +1,1194 @@ +package service + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +var openAIWSModeDebugLogTestMu sync.Mutex + +func TestOpenAIWSClientFrameConn_NilGuards(t *testing.T) { + t.Parallel() + + var nilReceiver *openAIWSClientFrameConn + msgType, payload, err := nilReceiver.ReadFrame(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.Equal(t, coderws.MessageText, msgType) + require.Nil(t, payload) + + err = nilReceiver.WriteFrame(context.Background(), 
coderws.MessageText, []byte("x")) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.NoError(t, nilReceiver.Close()) + + empty := &openAIWSClientFrameConn{} + var nilCtx context.Context + _, _, err = empty.ReadFrame(nilCtx) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + err = empty.WriteFrame(nilCtx, coderws.MessageText, []byte("x")) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.NoError(t, empty.Close()) +} + +func TestOpenAIWSClientFrameConn_NilContextWithLiveConn(t *testing.T) { + gin.SetMode(gin.TestMode) + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, payload, err := conn.Read(readCtx) + cancelRead() + if err != nil { + serverErrCh <- err + return + } + writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second) + err = conn.Write(writeCtx, msgType, payload) + cancelWrite() + serverErrCh <- err + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + frameConn := &openAIWSClientFrameConn{conn: clientConn} + payload := []byte(`{"type":"response.create","model":"gpt-5.3-codex"}`) + var nilCtx context.Context + + require.NoError(t, frameConn.WriteFrame(nilCtx, coderws.MessageText, payload)) + msgType, gotPayload, err := frameConn.ReadFrame(nilCtx) + require.NoError(t, err) + require.Equal(t, coderws.MessageText, msgType) + require.Equal(t, payload, gotPayload) + require.NoError(t, <-serverErrCh) +} + +func TestOpenAIWSV2PassthroughHelpers(t *testing.T) { + t.Parallel() + + 
require.Equal(t, "text", openaiwsv2RelayMessageTypeName(coderws.MessageText)) + require.Equal(t, "binary", openaiwsv2RelayMessageTypeName(coderws.MessageBinary)) + require.Contains(t, openaiwsv2RelayMessageTypeName(coderws.MessageType(99)), "unknown(") + + require.Equal(t, "", relayErrorText(nil)) + require.Equal(t, "boom", relayErrorText(errors.New("boom"))) + require.Equal(t, -1, openAIWSFirstTokenMsForLog(nil)) + ms := 12 + require.Equal(t, 12, openAIWSFirstTokenMsForLog(&ms)) + + require.NotPanics(t, func() { + logOpenAIWSV2Passthrough("helper_test account_id=%d", 1) + }) +} + +func TestLogOpenAIWSV2PassthroughDebug_Disabled(t *testing.T) { + openAIWSModeDebugLogTestMu.Lock() + defer openAIWSModeDebugLogTestMu.Unlock() + + logger.InitBootstrap() + require.NoError(t, logger.SetLevel("info")) + require.False(t, isOpenAIWSModeDebugEnabled()) + + require.NotPanics(t, func() { + logOpenAIWSV2PassthroughDebug("helper_test_debug account_id=%d", 1) + }) +} + +func TestLogOpenAIWSV2PassthroughDebug_Enabled(t *testing.T) { + openAIWSModeDebugLogTestMu.Lock() + defer openAIWSModeDebugLogTestMu.Unlock() + + logger.InitBootstrap() + require.NoError(t, logger.SetLevel("debug")) + require.True(t, isOpenAIWSModeDebugEnabled()) + t.Cleanup(func() { + _ = logger.SetLevel("info") + }) + + require.NotPanics(t, func() { + logOpenAIWSV2PassthroughDebug("helper_test_debug_enabled account_id=%d", 1) + }) +} + +func TestProxyResponsesWebSocketV2Passthrough_InvalidInputs(t *testing.T) { + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + firstMessage := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + var nilSvc *OpenAIGatewayService + err := nilSvc.proxyResponsesWebSocketV2Passthrough( + context.Background(), + nil, + nil, + account, + "sk-test", + coderws.MessageText, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: 
OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "service is nil") + + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + dummyClient := &coderws.Conn{} + + err = svc.proxyResponsesWebSocketV2Passthrough( + context.Background(), + nil, + nil, + account, + "sk-test", + coderws.MessageText, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "client websocket is nil") + + err = svc.proxyResponsesWebSocketV2Passthrough( + context.Background(), + nil, + dummyClient, + nil, + "sk-test", + coderws.MessageText, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "account is nil") + + err = svc.proxyResponsesWebSocketV2Passthrough( + context.Background(), + nil, + dummyClient, + account, + " ", + coderws.MessageText, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "token is empty") +} + +func TestProxyResponsesWebSocketV2Passthrough_DialFailure(t *testing.T) { + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + dummyClient := &coderws.Conn{} + firstMessage := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`) + + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + err: errors.New("dial failed"), + statusCode: 503, + } + err := svc.proxyResponsesWebSocketV2Passthrough( + context.Background(), + nil, + dummyClient, + account, + "sk-test", + coderws.MessageText, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: 
OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "openai ws passthrough dial") +} + +type passthroughDialerStub struct { + conn openAIWSClientConn + statusCode int + headers http.Header + err error + dialCount atomic.Int32 +} + +func (d *passthroughDialerStub) Dial( + _ context.Context, + _ string, + _ http.Header, + _ string, +) (openAIWSClientConn, int, http.Header, error) { + d.dialCount.Add(1) + return d.conn, d.statusCode, d.headers, d.err +} + +type passthroughUpstreamConn struct { + mu sync.Mutex + reads []struct { + msgType coderws.MessageType + payload []byte + } + readDelay time.Duration + readErr error + writes [][]byte + closed bool +} + +func (c *passthroughUpstreamConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + if c.readDelay > 0 { + timer := time.NewTimer(c.readDelay) + defer timer.Stop() + select { + case <-ctx.Done(): + return coderws.MessageText, nil, ctx.Err() + case <-timer.C: + } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return coderws.MessageText, nil, io.EOF + } + if len(c.reads) == 0 { + if c.readErr != nil { + return coderws.MessageText, nil, c.readErr + } + return coderws.MessageText, nil, io.EOF + } + item := c.reads[0] + c.reads = c.reads[1:] + return item.msgType, append([]byte(nil), item.payload...), nil +} + +func (c *passthroughUpstreamConn) WriteFrame(_ context.Context, _ coderws.MessageType, payload []byte) error { + c.mu.Lock() + defer c.mu.Unlock() + c.writes = append(c.writes, append([]byte(nil), payload...)) + return nil +} + +func (c *passthroughUpstreamConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func (c *passthroughUpstreamConn) WriteJSON(context.Context, any) error { return nil } +func (c *passthroughUpstreamConn) ReadMessage(context.Context) ([]byte, error) { + return nil, io.EOF +} +func (c *passthroughUpstreamConn) Ping(context.Context) error { return nil } + +type 
passthroughBlockingUpstreamConn struct{} + +func (c *passthroughBlockingUpstreamConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) { + <-ctx.Done() + return coderws.MessageText, nil, ctx.Err() +} +func (c *passthroughBlockingUpstreamConn) WriteFrame(context.Context, coderws.MessageType, []byte) error { + return nil +} +func (c *passthroughBlockingUpstreamConn) Close() error { return nil } +func (c *passthroughBlockingUpstreamConn) WriteJSON(context.Context, any) error { + return nil +} +func (c *passthroughBlockingUpstreamConn) ReadMessage(context.Context) ([]byte, error) { + return nil, io.EOF +} +func (c *passthroughBlockingUpstreamConn) Ping(context.Context) error { return nil } + +func TestProxyResponsesWebSocketV2Passthrough_Success(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &passthroughUpstreamConn{ + reads: []struct { + msgType coderws.MessageType + payload []byte + }{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_passthrough","usage":{"input_tokens":3,"output_tokens":2}}}`), + }, + }, + } + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: upstream, + statusCode: 101, + headers: http.Header{ + "X-Request-ID": []string{"req-passthrough"}, + }, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + var ( + afterTurnCalled bool + afterTurnErr error + afterTurnResult *OpenAIForwardResult + ) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + afterTurnCalled = true + afterTurnErr = turnErr + afterTurnResult = result + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + 
require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + hooks, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, payload, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_passthrough","usage":{"input_tokens":3,"output_tokens":2}}}`, string(payload)) + + require.NoError(t, <-serverErrCh) + require.True(t, afterTurnCalled) + require.NoError(t, afterTurnErr) + require.NotNil(t, afterTurnResult) + require.Equal(t, "resp_passthrough", afterTurnResult.RequestID) +} + +func TestProxyResponsesWebSocketV2Passthrough_NilContext(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &passthroughUpstreamConn{ + reads: []struct { + msgType coderws.MessageType + payload []byte + }{ + { + msgType: coderws.MessageText, + payload: 
[]byte(`{"type":"response.completed","response":{"id":"resp_nil_ctx","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + }, + } + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: upstream, + statusCode: 101, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + context.TODO(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, payload, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.JSONEq(t, 
`{"type":"response.completed","response":{"id":"resp_nil_ctx","usage":{"input_tokens":1,"output_tokens":1}}}`, string(payload)) + + require.NoError(t, <-serverErrCh) +} + +func TestProxyResponsesWebSocketV2Passthrough_ZeroTurnFallbackCallback(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &passthroughUpstreamConn{ + reads: []struct { + msgType coderws.MessageType + payload []byte + }{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","delta":"hello"}`), + }, + }, + } + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: upstream, + statusCode: 101, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + var ( + afterTurnCalled bool + afterTurnErr error + afterTurnResult *OpenAIForwardResult + ) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + afterTurnCalled = true + afterTurnErr = turnErr + afterTurnResult = result + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + hooks, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + 
clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + msgType, payload, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.Equal(t, coderws.MessageText, msgType) + require.JSONEq(t, `{"type":"response.output_text.delta","delta":"hello"}`, string(payload)) + + require.NoError(t, <-serverErrCh) + require.True(t, afterTurnCalled) + require.NoError(t, afterTurnErr) + require.NotNil(t, afterTurnResult) + require.Equal(t, "", afterTurnResult.RequestID) + require.Equal(t, 0, afterTurnResult.Usage.InputTokens) + require.Equal(t, "", afterTurnResult.TerminalEventType) +} + +func TestProxyResponsesWebSocketV2Passthrough_TurnMetricsPropagated(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &passthroughUpstreamConn{ + readDelay: 15 * time.Millisecond, + reads: []struct { + msgType coderws.MessageType + payload []byte + }{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metrics","delta":"hello"}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_metrics","usage":{"input_tokens":4,"output_tokens":2}}}`), + }, + }, + } + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: upstream, + statusCode: 101, + headers: http.Header{ + "X-Request-ID": []string{"req-passthrough"}, + }, + } + account := buildIngressPolicyTestAccount(map[string]any{ + 
"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + var ( + afterTurnCalled bool + afterTurnResult *OpenAIForwardResult + afterTurnErr error + ) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + afterTurnCalled = true + afterTurnResult = result + afterTurnErr = turnErr + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + hooks, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + // 读取透传下发事件直到 terminal 到达。 + readCtx1, cancelRead1 := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx1) + cancelRead1() + require.NoError(t, err) + readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second) + _, payload2, err := 
clientConn.Read(readCtx2) + cancelRead2() + require.NoError(t, err) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_metrics","usage":{"input_tokens":4,"output_tokens":2}}}`, string(payload2)) + + require.NoError(t, <-serverErrCh) + require.True(t, afterTurnCalled) + require.NoError(t, afterTurnErr) + require.NotNil(t, afterTurnResult) + require.NotNil(t, afterTurnResult.FirstTokenMs) + require.GreaterOrEqual(t, *afterTurnResult.FirstTokenMs, 0) + require.Greater(t, afterTurnResult.Duration.Milliseconds(), int64(0)) +} + +func TestProxyResponsesWebSocketV2Passthrough_MultiTurnAfterTurnOnTerminal(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &passthroughUpstreamConn{ + reads: []struct { + msgType coderws.MessageType + payload []byte + }{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":3,"output_tokens":2}}}`), + }, + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":1,"output_tokens":4}}}`), + }, + }, + } + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: upstream, + statusCode: 101, + headers: http.Header{ + "X-Request-ID": []string{"req-passthrough"}, + }, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + type afterTurnCall struct { + turn int + result *OpenAIForwardResult + err error + } + var ( + callsMu sync.Mutex + calls = make([]afterTurnCall, 0, 3) + ) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(turn int, result *OpenAIForwardResult, turnErr error) { + callsMu.Lock() + defer callsMu.Unlock() + calls = append(calls, afterTurnCall{ + turn: turn, + result: result, + err: turnErr, + }) + }, + } + + serverErrCh := make(chan error, 1) + wsServer := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + hooks, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + readCtx1, cancelRead1 := context.WithTimeout(context.Background(), 3*time.Second) + _, payload1, err := clientConn.Read(readCtx1) + cancelRead1() + require.NoError(t, err) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":3,"output_tokens":2}}}`, string(payload1)) + + readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second) + _, payload2, err := clientConn.Read(readCtx2) + cancelRead2() + require.NoError(t, err) + require.JSONEq(t, `{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":1,"output_tokens":4}}}`, string(payload2)) + + require.NoError(t, <-serverErrCh) + + callsMu.Lock() + gotCalls := 
append([]afterTurnCall(nil), calls...) + callsMu.Unlock() + require.Len(t, gotCalls, 2) + require.Equal(t, 1, gotCalls[0].turn) + require.NoError(t, gotCalls[0].err) + require.NotNil(t, gotCalls[0].result) + require.Equal(t, "resp_turn_1", gotCalls[0].result.RequestID) + require.Equal(t, 3, gotCalls[0].result.Usage.InputTokens) + require.Equal(t, 2, gotCalls[0].result.Usage.OutputTokens) + require.Equal(t, 2, gotCalls[1].turn) + require.NoError(t, gotCalls[1].err) + require.NotNil(t, gotCalls[1].result) + require.Equal(t, "resp_turn_2", gotCalls[1].result.RequestID) + require.Equal(t, 1, gotCalls[1].result.Usage.InputTokens) + require.Equal(t, 4, gotCalls[1].result.Usage.OutputTokens) +} + +func TestProxyResponsesWebSocketV2Passthrough_ErrorAfterTurn_NoPartialDuplicate(t *testing.T) { + gin.SetMode(gin.TestMode) + + upstream := &passthroughUpstreamConn{ + reads: []struct { + msgType coderws.MessageType + payload []byte + }{ + { + msgType: coderws.MessageText, + payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_ok","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + }, + readErr: errors.New("upstream stream broken"), + } + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: upstream, + statusCode: 101, + headers: http.Header{ + "X-Request-ID": []string{"req-passthrough"}, + }, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + type afterTurnCall struct { + turn int + result *OpenAIForwardResult + err error + } + var ( + callsMu sync.Mutex + calls = make([]afterTurnCall, 0, 3) + ) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(turn int, result *OpenAIForwardResult, turnErr error) { + callsMu.Lock() + defer callsMu.Unlock() + calls = append(calls, afterTurnCall{ + turn: turn, + result: result, + err: turnErr, + }) + }, + } + + serverErrCh := 
make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + hooks, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, payload, err := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, err) + require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_turn_ok","usage":{"input_tokens":2,"output_tokens":1}}}`, string(payload)) + + serverErr := <-serverErrCh + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "upstream stream broken") + + callsMu.Lock() + gotCalls := append([]afterTurnCall(nil), calls...) 
+ callsMu.Unlock() + require.Len(t, gotCalls, 2) + require.Equal(t, 1, gotCalls[0].turn) + require.NoError(t, gotCalls[0].err) + require.NotNil(t, gotCalls[0].result) + require.Equal(t, "resp_turn_ok", gotCalls[0].result.RequestID) + require.Equal(t, 2, gotCalls[1].turn) + require.Error(t, gotCalls[1].err) + require.Nil(t, gotCalls[1].result) + + partial, ok := OpenAIWSIngressTurnPartialResult(gotCalls[1].err) + require.False(t, ok) + require.Nil(t, partial) +} + +func TestProxyResponsesWebSocketV2Passthrough_UpstreamWithoutFrameConn(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: &openAIWSFakeConn{}, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + err = svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + nil, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "does not support frame relay") + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + 
}() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) +} + +func TestProxyResponsesWebSocketV2Passthrough_BuildWSURLError(t *testing.T) { + cfg := buildIngressPolicyTestConfig() + svc := buildIngressPolicyTestService(cfg) + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + account.Credentials["base_url"] = "http://[::1" + err := svc.proxyResponsesWebSocketV2Passthrough( + context.Background(), + nil, + &coderws.Conn{}, + account, + "sk-test", + coderws.MessageText, + []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`), + nil, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + require.Error(t, err) + require.Contains(t, err.Error(), "build ws url") +} + +func TestProxyResponsesWebSocketV2Passthrough_IdleTimeoutErrorPath(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := buildIngressPolicyTestConfig() + cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 1 + svc := buildIngressPolicyTestService(cfg) + svc.openaiWSPassthroughDialer = &passthroughDialerStub{ + conn: &passthroughBlockingUpstreamConn{}, + } + account := buildIngressPolicyTestAccount(map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough, + }) + + var ( + afterTurnCalled bool + afterTurnErr error + ) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, _ *OpenAIForwardResult, turnErr error) { + afterTurnCalled = true + afterTurnErr = turnErr + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.CloseNow() + }() + + 
readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, err := conn.Read(readCtx) + cancelRead() + require.NoError(t, err) + + ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder()) + ginCtx.Request = r + serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough( + r.Context(), + ginCtx, + conn, + account, + "sk-test", + msgType, + firstMessage, + hooks, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + ) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial( + dialCtx, + "ws"+strings.TrimPrefix(wsServer.URL, "http"), + nil, + ) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)) + cancelWrite() + require.NoError(t, err) + + serverErr := <-serverErrCh + require.Error(t, serverErr) + require.Contains(t, serverErr.Error(), "idle timeout") + require.True(t, afterTurnCalled) + require.Error(t, afterTurnErr) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, afterTurnErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "idle timeout") + require.ErrorIs(t, afterTurnErr, openaiwsv2.ErrRelayIdleTimeout) + require.NotErrorIs(t, afterTurnErr, context.DeadlineExceeded) +} + +func TestMapOpenAIWSPassthroughDialError_StatusMapping(t *testing.T) { + t.Parallel() + + svc := &OpenAIGatewayService{} + + t.Run("nil error returns nil", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(nil, 0, nil) + require.NoError(t, err) + }) + + t.Run("401 maps to policy violation", func(t *testing.T) { + err := 
svc.mapOpenAIWSPassthroughDialError(errors.New("unauthorized"), 401, nil) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + }) + + t.Run("403 maps to policy violation", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(errors.New("forbidden"), 403, nil) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + }) + + t.Run("429 maps to try again later", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(errors.New("rate limited"), 429, nil) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusTryAgainLater, closeErr.StatusCode()) + }) + + t.Run("4xx generic maps to policy violation", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(errors.New("bad request"), 400, nil) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + }) + + t.Run("5xx wraps as generic error without CloseError", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(errors.New("server error"), 500, nil) + require.Error(t, err) + var closeErr *OpenAIWSClientCloseError + require.False(t, errors.As(err, &closeErr), "5xx 不应封装为 CloseError") + require.Contains(t, err.Error(), "openai ws passthrough dial") + }) + + t.Run("deadline exceeded maps to try again later", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(context.DeadlineExceeded, 0, nil) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusTryAgainLater, closeErr.StatusCode()) + }) + + t.Run("context canceled returns original error", func(t *testing.T) { + err := svc.mapOpenAIWSPassthroughDialError(context.Canceled, 0, nil) + require.ErrorIs(t, err, 
context.Canceled) + var closeErr *OpenAIWSClientCloseError + require.False(t, errors.As(err, &closeErr), "context.Canceled 不应封装为 CloseError") + }) +} diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go index b22da7522..225477d52 100644 --- a/backend/internal/service/redeem_service.go +++ b/backend/internal/service/redeem_service.go @@ -15,11 +15,12 @@ import ( ) var ( - ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") - ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") - ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") - ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") - ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") + ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found") + ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used") + ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance") + ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later") + ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again") + ErrBalanceCacheNotFound = errors.New("balance cache key not found") ) const ( diff --git a/backend/internal/service/setting_bulk_edit_template.go b/backend/internal/service/setting_bulk_edit_template.go new file mode 100644 index 000000000..dd28e1633 --- /dev/null +++ b/backend/internal/service/setting_bulk_edit_template.go @@ -0,0 +1,770 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "sort" + "strings" + "time" + + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + BulkEditTemplateShareScopePrivate = "private" + BulkEditTemplateShareScopeTeam = "team" + BulkEditTemplateShareScopeGroups = "groups" +) + +var ( + ErrBulkEditTemplateNotFound = infraerrors.NotFound("BULK_EDIT_TEMPLATE_NOT_FOUND", "bulk edit template not found") + ErrBulkEditTemplateVersionNotFound = infraerrors.NotFound( + "BULK_EDIT_TEMPLATE_VERSION_NOT_FOUND", + "bulk edit template version not found", + ) + ErrBulkEditTemplateForbidden = infraerrors.Forbidden( + "BULK_EDIT_TEMPLATE_FORBIDDEN", + "no permission to modify this bulk edit template", + ) + bulkEditTemplateRandRead = rand.Read +) + +type BulkEditTemplate struct { + ID string `json:"id"` + Name string `json:"name"` + ScopePlatform string `json:"scope_platform"` + ScopeType string `json:"scope_type"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State map[string]any `json:"state"` + CreatedBy int64 `json:"created_by"` + UpdatedBy int64 `json:"updated_by"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +type BulkEditTemplateQuery struct { + ScopePlatform string + ScopeType string + ScopeGroupIDs []int64 + RequesterUserID int64 +} + +type BulkEditTemplateVersion struct { + VersionID string `json:"version_id"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State map[string]any `json:"state"` + UpdatedBy int64 `json:"updated_by"` + UpdatedAt int64 `json:"updated_at"` +} + +type BulkEditTemplateVersionQuery struct { + TemplateID string + ScopeGroupIDs []int64 + RequesterUserID int64 +} + +type BulkEditTemplateUpsertInput struct { + ID string + Name string + ScopePlatform string + ScopeType string + ShareScope string + GroupIDs []int64 + State map[string]any + RequesterUserID int64 +} + +type BulkEditTemplateRollbackInput struct { + TemplateID string + VersionID string + ScopeGroupIDs []int64 + RequesterUserID int64 +} + +type 
bulkEditTemplateLibraryStore struct { + Items []bulkEditTemplateStoreItem `json:"items"` +} + +type bulkEditTemplateVersionStoreItem struct { + VersionID string `json:"version_id"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State json.RawMessage `json:"state"` + UpdatedBy int64 `json:"updated_by"` + UpdatedAt int64 `json:"updated_at"` +} + +type bulkEditTemplateStoreItem struct { + ID string `json:"id"` + Name string `json:"name"` + ScopePlatform string `json:"scope_platform"` + ScopeType string `json:"scope_type"` + ShareScope string `json:"share_scope"` + GroupIDs []int64 `json:"group_ids"` + State json.RawMessage `json:"state"` + Versions []bulkEditTemplateVersionStoreItem `json:"versions"` + CreatedBy int64 `json:"created_by"` + UpdatedBy int64 `json:"updated_by"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} + +func (s *SettingService) ListBulkEditTemplates(ctx context.Context, query BulkEditTemplateQuery) ([]BulkEditTemplate, error) { + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + scopePlatform := strings.TrimSpace(strings.ToLower(query.ScopePlatform)) + scopeType := strings.TrimSpace(strings.ToLower(query.ScopeType)) + scopeGroupIDs := normalizeBulkEditTemplateGroupIDs(query.ScopeGroupIDs) + scopeGroupSet := make(map[int64]struct{}, len(scopeGroupIDs)) + for _, groupID := range scopeGroupIDs { + scopeGroupSet[groupID] = struct{}{} + } + + out := make([]BulkEditTemplate, 0, len(store.Items)) + for idx := range store.Items { + item := store.Items[idx] + if scopePlatform != "" && item.ScopePlatform != scopePlatform { + continue + } + if scopeType != "" && item.ScopeType != scopeType { + continue + } + if !isBulkEditTemplateVisible(item, query.RequesterUserID, scopeGroupSet) { + continue + } + out = append(out, toBulkEditTemplate(item)) + } + + sort.Slice(out, func(i, j int) bool { + if out[i].UpdatedAt == out[j].UpdatedAt { + return out[i].ID < 
out[j].ID + } + return out[i].UpdatedAt > out[j].UpdatedAt + }) + + return out, nil +} + +func (s *SettingService) UpsertBulkEditTemplate(ctx context.Context, input BulkEditTemplateUpsertInput) (*BulkEditTemplate, error) { + name := strings.TrimSpace(input.Name) + if name == "" { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template name is required") + } + if input.RequesterUserID <= 0 { + return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + scopePlatform := strings.TrimSpace(strings.ToLower(input.ScopePlatform)) + scopeType := strings.TrimSpace(strings.ToLower(input.ScopeType)) + if scopePlatform == "" || scopeType == "" { + return nil, infraerrors.BadRequest( + "BULK_EDIT_TEMPLATE_INVALID_INPUT", + "scope_platform and scope_type are required", + ) + } + + shareScope, shareScopeErr := validateBulkEditTemplateShareScope(input.ShareScope) + if shareScopeErr != nil { + return nil, shareScopeErr + } + + groupIDs := normalizeBulkEditTemplateGroupIDs(input.GroupIDs) + if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 { + return nil, infraerrors.BadRequest( + "BULK_EDIT_TEMPLATE_INVALID_INPUT", + "group_ids is required when share_scope=groups", + ) + } + + stateRaw, err := json.Marshal(input.State) + if err != nil { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid template state") + } + if len(stateRaw) == 0 || string(stateRaw) == "null" { + stateRaw = json.RawMessage("{}") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + templateID := strings.TrimSpace(input.ID) + matchIndex := -1 + if templateID != "" { + for idx := range store.Items { + if store.Items[idx].ID == templateID { + matchIndex = idx + break + } + } + if matchIndex < 0 { + return nil, ErrBulkEditTemplateNotFound + } + } + + if matchIndex < 0 && templateID == "" { + for idx := range store.Items { + item := store.Items[idx] + if item.ScopePlatform != 
scopePlatform || item.ScopeType != scopeType { + continue + } + if !strings.EqualFold(strings.TrimSpace(item.Name), name) { + continue + } + if !canModifyBulkEditTemplate(item, input.RequesterUserID) { + continue + } + matchIndex = idx + break + } + } + + nowMS := time.Now().UnixMilli() + if matchIndex >= 0 { + item := store.Items[matchIndex] + if !canModifyBulkEditTemplate(item, input.RequesterUserID) { + return nil, ErrBulkEditTemplateForbidden + } + + previousVersion := snapshotBulkEditTemplateVersion(item) + item.Versions = append(item.Versions, previousVersion) + item.Name = name + item.ScopePlatform = scopePlatform + item.ScopeType = scopeType + item.ShareScope = shareScope + item.GroupIDs = groupIDs + item.State = cloneBulkEditTemplateStateRaw(stateRaw) + if item.CreatedBy <= 0 { + item.CreatedBy = input.RequesterUserID + } + if item.CreatedAt <= 0 { + item.CreatedAt = nowMS + } + item.UpdatedBy = input.RequesterUserID + item.UpdatedAt = nowMS + store.Items[matchIndex] = item + + if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil { + return nil, err + } + output := toBulkEditTemplate(item) + return &output, nil + } + + if templateID == "" { + templateID = generateBulkEditTemplateID() + } + + created := bulkEditTemplateStoreItem{ + ID: templateID, + Name: name, + ScopePlatform: scopePlatform, + ScopeType: scopeType, + ShareScope: shareScope, + GroupIDs: groupIDs, + State: cloneBulkEditTemplateStateRaw(stateRaw), + Versions: []bulkEditTemplateVersionStoreItem{}, + CreatedBy: input.RequesterUserID, + UpdatedBy: input.RequesterUserID, + CreatedAt: nowMS, + UpdatedAt: nowMS, + } + store.Items = append(store.Items, created) + + if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil { + return nil, err + } + + output := toBulkEditTemplate(created) + return &output, nil +} + +func (s *SettingService) DeleteBulkEditTemplate(ctx context.Context, templateID string, requesterUserID int64) error { + id := strings.TrimSpace(templateID) + if id 
== "" { + return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required") + } + if requesterUserID <= 0 { + return infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return err + } + + idx := -1 + for index := range store.Items { + if store.Items[index].ID == id { + idx = index + break + } + } + if idx < 0 { + return ErrBulkEditTemplateNotFound + } + + target := store.Items[idx] + if target.ShareScope == BulkEditTemplateShareScopePrivate && target.CreatedBy > 0 && target.CreatedBy != requesterUserID { + return ErrBulkEditTemplateForbidden + } + + store.Items = append(store.Items[:idx], store.Items[idx+1:]...) + return s.persistBulkEditTemplateLibrary(ctx, store) +} + +func (s *SettingService) ListBulkEditTemplateVersions( + ctx context.Context, + query BulkEditTemplateVersionQuery, +) ([]BulkEditTemplateVersion, error) { + templateID := strings.TrimSpace(query.TemplateID) + if templateID == "" { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required") + } + if query.RequesterUserID <= 0 { + return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + scopeGroupSet := toBulkEditTemplateScopeGroupSet(query.ScopeGroupIDs) + target := findBulkEditTemplateStoreItemByID(store.Items, templateID) + if target == nil { + return nil, ErrBulkEditTemplateNotFound + } + if !isBulkEditTemplateVisible(*target, query.RequesterUserID, scopeGroupSet) { + return nil, ErrBulkEditTemplateForbidden + } + + versions := make([]BulkEditTemplateVersion, 0, len(target.Versions)) + for idx := range target.Versions { + versions = append(versions, toBulkEditTemplateVersion(target.Versions[idx])) + } + + sort.Slice(versions, func(i, j int) bool { + if versions[i].UpdatedAt == versions[j].UpdatedAt { + return versions[i].VersionID < 
versions[j].VersionID + } + return versions[i].UpdatedAt > versions[j].UpdatedAt + }) + + return versions, nil +} + +func (s *SettingService) RollbackBulkEditTemplate( + ctx context.Context, + input BulkEditTemplateRollbackInput, +) (*BulkEditTemplate, error) { + templateID := strings.TrimSpace(input.TemplateID) + versionID := strings.TrimSpace(input.VersionID) + if templateID == "" || versionID == "" { + return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template_id and version_id are required") + } + if input.RequesterUserID <= 0 { + return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized") + } + + store, err := s.loadBulkEditTemplateLibrary(ctx) + if err != nil { + return nil, err + } + + scopeGroupSet := toBulkEditTemplateScopeGroupSet(input.ScopeGroupIDs) + templateIndex := findBulkEditTemplateStoreItemIndexByID(store.Items, templateID) + if templateIndex < 0 { + return nil, ErrBulkEditTemplateNotFound + } + + item := store.Items[templateIndex] + if !isBulkEditTemplateVisible(item, input.RequesterUserID, scopeGroupSet) { + return nil, ErrBulkEditTemplateForbidden + } + + versionIndex := findBulkEditTemplateVersionIndexByID(item.Versions, versionID) + if versionIndex < 0 { + return nil, ErrBulkEditTemplateVersionNotFound + } + + targetVersion := item.Versions[versionIndex] + previousVersion := snapshotBulkEditTemplateVersion(item) + item.Versions = append(item.Versions, previousVersion) + item.ShareScope = targetVersion.ShareScope + item.GroupIDs = append([]int64(nil), targetVersion.GroupIDs...) 
+ item.State = cloneBulkEditTemplateStateRaw(targetVersion.State) + item.UpdatedBy = input.RequesterUserID + item.UpdatedAt = time.Now().UnixMilli() + + store.Items[templateIndex] = item + if persistErr := s.persistBulkEditTemplateLibrary(ctx, store); persistErr != nil { + return nil, persistErr + } + + output := toBulkEditTemplate(item) + return &output, nil +} + +func (s *SettingService) loadBulkEditTemplateLibrary(ctx context.Context) (*bulkEditTemplateLibraryStore, error) { + raw, err := s.settingRepo.GetValue(ctx, SettingKeyBulkEditTemplateLibrary) + if err != nil { + if errors.Is(err, ErrSettingNotFound) { + return &bulkEditTemplateLibraryStore{}, nil + } + return nil, fmt.Errorf("get bulk edit template library: %w", err) + } + + raw = strings.TrimSpace(raw) + if raw == "" { + return &bulkEditTemplateLibraryStore{}, nil + } + + store := bulkEditTemplateLibraryStore{} + if err := json.Unmarshal([]byte(raw), &store); err != nil { + return nil, fmt.Errorf("parse bulk edit template library: %w", err) + } + + normalized := normalizeBulkEditTemplateLibraryStore(store) + return &normalized, nil +} + +func (s *SettingService) persistBulkEditTemplateLibrary(ctx context.Context, store *bulkEditTemplateLibraryStore) error { + if store == nil { + return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template library cannot be nil") + } + + normalized := normalizeBulkEditTemplateLibraryStore(*store) + data, err := json.Marshal(normalized) + if err != nil { + return fmt.Errorf("marshal bulk edit template library: %w", err) + } + + return s.settingRepo.Set(ctx, SettingKeyBulkEditTemplateLibrary, string(data)) +} + +func validateBulkEditTemplateShareScope(scope string) (string, error) { + normalized := strings.TrimSpace(strings.ToLower(scope)) + if normalized == "" { + return BulkEditTemplateShareScopePrivate, nil + } + switch normalized { + case BulkEditTemplateShareScopePrivate, + BulkEditTemplateShareScopeTeam, + BulkEditTemplateShareScopeGroups: + return 
normalized, nil + default: + return "", infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid share_scope") + } +} + +func normalizeBulkEditTemplateLibraryStore(store bulkEditTemplateLibraryStore) bulkEditTemplateLibraryStore { + if len(store.Items) == 0 { + return bulkEditTemplateLibraryStore{Items: []bulkEditTemplateStoreItem{}} + } + + nowMS := time.Now().UnixMilli() + items := make([]bulkEditTemplateStoreItem, 0, len(store.Items)) + seenID := make(map[string]struct{}, len(store.Items)) + + for _, raw := range store.Items { + name := strings.TrimSpace(raw.Name) + scopePlatform := strings.TrimSpace(strings.ToLower(raw.ScopePlatform)) + scopeType := strings.TrimSpace(strings.ToLower(raw.ScopeType)) + if name == "" || scopePlatform == "" || scopeType == "" { + continue + } + + shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope) + groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs) + if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 { + shareScope = BulkEditTemplateShareScopePrivate + } + + templateID := strings.TrimSpace(raw.ID) + if templateID == "" { + templateID = generateBulkEditTemplateID() + } + if _, exists := seenID[templateID]; exists { + continue + } + seenID[templateID] = struct{}{} + + state := raw.State + if len(state) == 0 || string(state) == "null" { + state = json.RawMessage("{}") + } + + createdAt := raw.CreatedAt + if createdAt <= 0 { + createdAt = nowMS + } + updatedAt := raw.UpdatedAt + if updatedAt <= 0 { + updatedAt = createdAt + } + + items = append(items, bulkEditTemplateStoreItem{ + ID: templateID, + Name: name, + ScopePlatform: scopePlatform, + ScopeType: scopeType, + ShareScope: shareScope, + GroupIDs: groupIDs, + State: cloneBulkEditTemplateStateRaw(state), + Versions: normalizeBulkEditTemplateVersionStoreItems(raw.Versions), + CreatedBy: raw.CreatedBy, + UpdatedBy: raw.UpdatedBy, + CreatedAt: createdAt, + UpdatedAt: updatedAt, + }) + } + + return 
bulkEditTemplateLibraryStore{Items: items} +} + +func toBulkEditTemplate(item bulkEditTemplateStoreItem) BulkEditTemplate { + state := map[string]any{} + if err := json.Unmarshal(item.State, &state); err != nil || state == nil { + state = map[string]any{} + } + + return BulkEditTemplate{ + ID: item.ID, + Name: item.Name, + ScopePlatform: item.ScopePlatform, + ScopeType: item.ScopeType, + ShareScope: item.ShareScope, + GroupIDs: append([]int64(nil), item.GroupIDs...), + State: state, + CreatedBy: item.CreatedBy, + UpdatedBy: item.UpdatedBy, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + } +} + +func toBulkEditTemplateVersion(item bulkEditTemplateVersionStoreItem) BulkEditTemplateVersion { + state := map[string]any{} + if err := json.Unmarshal(item.State, &state); err != nil || state == nil { + state = map[string]any{} + } + + return BulkEditTemplateVersion{ + VersionID: item.VersionID, + ShareScope: item.ShareScope, + GroupIDs: append([]int64(nil), item.GroupIDs...), + State: state, + UpdatedBy: item.UpdatedBy, + UpdatedAt: item.UpdatedAt, + } +} + +func normalizeBulkEditTemplateVersionStoreItems( + rawVersions []bulkEditTemplateVersionStoreItem, +) []bulkEditTemplateVersionStoreItem { + if len(rawVersions) == 0 { + return []bulkEditTemplateVersionStoreItem{} + } + + nowMS := time.Now().UnixMilli() + seen := make(map[string]struct{}, len(rawVersions)) + out := make([]bulkEditTemplateVersionStoreItem, 0, len(rawVersions)) + for _, raw := range rawVersions { + versionID := strings.TrimSpace(raw.VersionID) + if versionID == "" { + versionID = generateBulkEditTemplateVersionID() + } + if _, exists := seen[versionID]; exists { + continue + } + seen[versionID] = struct{}{} + + shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope) + groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs) + if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 { + shareScope = BulkEditTemplateShareScopePrivate + } + + updatedAt := 
raw.UpdatedAt + if updatedAt <= 0 { + updatedAt = nowMS + } + + out = append(out, bulkEditTemplateVersionStoreItem{ + VersionID: versionID, + ShareScope: shareScope, + GroupIDs: groupIDs, + State: cloneBulkEditTemplateStateRaw(raw.State), + UpdatedBy: raw.UpdatedBy, + UpdatedAt: updatedAt, + }) + } + + sort.Slice(out, func(i, j int) bool { + if out[i].UpdatedAt == out[j].UpdatedAt { + return out[i].VersionID < out[j].VersionID + } + return out[i].UpdatedAt > out[j].UpdatedAt + }) + return out +} + +func snapshotBulkEditTemplateVersion(item bulkEditTemplateStoreItem) bulkEditTemplateVersionStoreItem { + updatedAt := item.UpdatedAt + if updatedAt <= 0 { + updatedAt = time.Now().UnixMilli() + } + return bulkEditTemplateVersionStoreItem{ + VersionID: generateBulkEditTemplateVersionID(), + ShareScope: normalizeBulkEditTemplateShareScopeOrDefault(item.ShareScope), + GroupIDs: normalizeBulkEditTemplateGroupIDs(item.GroupIDs), + State: cloneBulkEditTemplateStateRaw(item.State), + UpdatedBy: item.UpdatedBy, + UpdatedAt: updatedAt, + } +} + +func cloneBulkEditTemplateStateRaw(raw json.RawMessage) json.RawMessage { + if len(raw) == 0 || string(raw) == "null" { + return json.RawMessage("{}") + } + cloned := make(json.RawMessage, len(raw)) + copy(cloned, raw) + return cloned +} + +func toBulkEditTemplateScopeGroupSet(raw []int64) map[int64]struct{} { + groupIDs := normalizeBulkEditTemplateGroupIDs(raw) + scopeGroupSet := make(map[int64]struct{}, len(groupIDs)) + for _, groupID := range groupIDs { + scopeGroupSet[groupID] = struct{}{} + } + return scopeGroupSet +} + +func findBulkEditTemplateStoreItemByID( + items []bulkEditTemplateStoreItem, + templateID string, +) *bulkEditTemplateStoreItem { + for idx := range items { + if items[idx].ID == templateID { + return &items[idx] + } + } + return nil +} + +func findBulkEditTemplateStoreItemIndexByID(items []bulkEditTemplateStoreItem, templateID string) int { + for idx := range items { + if items[idx].ID == templateID { + return idx 
+ } + } + return -1 +} + +func findBulkEditTemplateVersionIndexByID( + versions []bulkEditTemplateVersionStoreItem, + versionID string, +) int { + for idx := range versions { + if versions[idx].VersionID == versionID { + return idx + } + } + return -1 +} + +func canModifyBulkEditTemplate(item bulkEditTemplateStoreItem, requesterUserID int64) bool { + if requesterUserID <= 0 { + return false + } + if item.ShareScope != BulkEditTemplateShareScopePrivate { + return true + } + if item.CreatedBy <= 0 { + return true + } + return item.CreatedBy == requesterUserID +} + +func isBulkEditTemplateVisible( + item bulkEditTemplateStoreItem, + requesterUserID int64, + scopeGroupSet map[int64]struct{}, +) bool { + switch item.ShareScope { + case BulkEditTemplateShareScopeTeam: + return true + case BulkEditTemplateShareScopeGroups: + if len(scopeGroupSet) == 0 || len(item.GroupIDs) == 0 { + return false + } + for _, groupID := range item.GroupIDs { + if _, ok := scopeGroupSet[groupID]; ok { + return true + } + } + return false + default: + return requesterUserID > 0 && item.CreatedBy == requesterUserID + } +} + +func normalizeBulkEditTemplateShareScopeOrDefault(scope string) string { + normalized, err := validateBulkEditTemplateShareScope(scope) + if err != nil { + return BulkEditTemplateShareScopePrivate + } + return normalized +} + +func normalizeBulkEditTemplateGroupIDs(raw []int64) []int64 { + if len(raw) == 0 { + return []int64{} + } + + seen := make(map[int64]struct{}, len(raw)) + groupIDs := make([]int64, 0, len(raw)) + for _, groupID := range raw { + if groupID <= 0 { + continue + } + if _, exists := seen[groupID]; exists { + continue + } + seen[groupID] = struct{}{} + groupIDs = append(groupIDs, groupID) + } + sort.Slice(groupIDs, func(i, j int) bool { + return groupIDs[i] < groupIDs[j] + }) + return groupIDs +} + +func generateBulkEditTemplateID() string { + buf := make([]byte, 12) + if _, err := bulkEditTemplateRandRead(buf); err == nil { + return "btpl-" + 
hex.EncodeToString(buf) + } + return fmt.Sprintf("btpl-%d", time.Now().UnixNano()) +} + +func generateBulkEditTemplateVersionID() string { + buf := make([]byte, 12) + if _, err := bulkEditTemplateRandRead(buf); err == nil { + return "btplv-" + hex.EncodeToString(buf) + } + return fmt.Sprintf("btplv-%d", time.Now().UnixNano()) +} diff --git a/backend/internal/service/setting_bulk_edit_template_test.go b/backend/internal/service/setting_bulk_edit_template_test.go new file mode 100644 index 000000000..c3a0351fa --- /dev/null +++ b/backend/internal/service/setting_bulk_edit_template_test.go @@ -0,0 +1,860 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +type bulkTemplateSettingRepoStub struct { + values map[string]string +} + +func newBulkTemplateSettingRepoStub() *bulkTemplateSettingRepoStub { + return &bulkTemplateSettingRepoStub{values: map[string]string{}} +} + +func (s *bulkTemplateSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + value, err := s.GetValue(ctx, key) + if err != nil { + return nil, err + } + return &Setting{Key: key, Value: value}, nil +} + +func (s *bulkTemplateSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + value, ok := s.values[key] + if !ok { + return "", ErrSettingNotFound + } + return value, nil +} + +func (s *bulkTemplateSettingRepoStub) Set(ctx context.Context, key, value string) error { + s.values[key] = value + return nil +} + +func (s *bulkTemplateSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + out := make(map[string]string, len(keys)) + for _, key := range keys { + if value, ok := s.values[key]; ok { + out[key] = value + } + } + return out, nil +} + +func (s *bulkTemplateSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + for key, value := range settings 
{ + s.values[key] = value + } + return nil +} + +func (s *bulkTemplateSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + out := make(map[string]string, len(s.values)) + for key, value := range s.values { + out[key] = value + } + return out, nil +} + +func (s *bulkTemplateSettingRepoStub) Delete(ctx context.Context, key string) error { + delete(s.values, key) + return nil +} + +type bulkTemplateFailingRepoStub struct{} + +func (s *bulkTemplateFailingRepoStub) Get(ctx context.Context, key string) (*Setting, error) { + return nil, errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) GetValue(ctx context.Context, key string) (string, error) { + return "", errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) Set(ctx context.Context, key, value string) error { + return errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { + return errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) GetAll(ctx context.Context) (map[string]string, error) { + return nil, errors.New("boom") +} +func (s *bulkTemplateFailingRepoStub) Delete(ctx context.Context, key string) error { + return errors.New("boom") +} + +func TestSettingServiceBulkEditTemplate_UpsertAndPrivateVisibility(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "OpenAI OAuth Baseline", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"enableOpenAIWSMode": true}, + RequesterUserID: 11, + }) + require.NoError(t, err) + require.NotEmpty(t, created.ID) + require.Equal(t, BulkEditTemplateShareScopePrivate, 
created.ShareScope) + + listByOwner, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 11, + }) + require.NoError(t, err) + require.Len(t, listByOwner, 1) + require.Equal(t, created.ID, listByOwner[0].ID) + require.Equal(t, true, listByOwner[0].State["enableOpenAIWSMode"]) + + listByOther, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 22, + }) + require.NoError(t, err) + require.Len(t, listByOther, 0) +} + +func TestSettingServiceBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Shared By Group", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{10, 20}, + State: map[string]any{"enableBaseUrl": true}, + RequesterUserID: 9, + }) + require.NoError(t, err) + + invisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ScopeGroupIDs: []int64{99}, + RequesterUserID: 8, + }) + require.NoError(t, err) + require.Len(t, invisible, 0) + + visible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ScopeGroupIDs: []int64{20, 100}, + RequesterUserID: 8, + }) + require.NoError(t, err) + require.Len(t, visible, 1) + require.Equal(t, BulkEditTemplateShareScopeGroups, visible[0].ShareScope) + require.Equal(t, []int64{10, 20}, visible[0].GroupIDs) +} + +func TestSettingServiceBulkEditTemplate_UpsertByNameReplacesSameScopeRecord(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := 
NewSettingService(repo, nil) + + first, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Team Baseline", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 1}, + RequesterUserID: 7, + }) + require.NoError(t, err) + + second, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "team baseline", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 9}, + RequesterUserID: 7, + }) + require.NoError(t, err) + require.Equal(t, first.ID, second.ID) + + items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 3, + }) + require.NoError(t, err) + require.Len(t, items, 1) + require.EqualValues(t, 9, items[0].State["priority"]) +} + +func TestSettingServiceBulkEditTemplate_DeletePermissionAndNotFound(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Private Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"status": "active"}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 2) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err)) + + err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 1) + require.NoError(t, err) + + ownerList, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 1, + }) + require.NoError(t, err) + require.Len(t, ownerList, 
0) + + err = svc.DeleteBulkEditTemplate(context.Background(), "missing-id", 1) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) +} + +func TestSettingServiceBulkEditTemplate_ValidatesInput(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Groups No IDs", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Invalid Scope", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: "invalid", + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) +} + +func TestSettingServiceBulkEditTemplate_CoversInternalHelpers(t *testing.T) { + store := normalizeBulkEditTemplateLibraryStore(bulkEditTemplateLibraryStore{ + Items: []bulkEditTemplateStoreItem{ + { + ID: "same-id", + Name: " One ", + ScopePlatform: "OPENAI", + ScopeType: "OAUTH", + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{}, + State: nil, + }, + { + ID: "same-id", + Name: "Duplicate ID", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopeTeam, + }, + { + ID: "", + Name: "Two", + ScopePlatform: "openai", + ScopeType: "apikey", + ShareScope: "invalid", + GroupIDs: []int64{5, 5, 1}, + State: []byte(`{"ok":true}`), + }, + { + ID: "invalid-entry", + Name: "", + ScopePlatform: "openai", + ScopeType: "oauth", + }, + 
}, + }) + require.Len(t, store.Items, 2) + require.Equal(t, "same-id", store.Items[0].ID) + require.Equal(t, BulkEditTemplateShareScopePrivate, store.Items[0].ShareScope) + require.Equal(t, []int64{1, 5}, store.Items[1].GroupIDs) + require.NotEmpty(t, store.Items[1].ID) + + require.Equal(t, BulkEditTemplateShareScopeTeam, normalizeBulkEditTemplateShareScopeOrDefault("team")) + require.Equal(t, BulkEditTemplateShareScopePrivate, normalizeBulkEditTemplateShareScopeOrDefault("bad")) + require.Equal(t, []int64{}, normalizeBulkEditTemplateGroupIDs(nil)) + + scope, err := validateBulkEditTemplateShareScope("") + require.NoError(t, err) + require.Equal(t, BulkEditTemplateShareScopePrivate, scope) + _, err = validateBulkEditTemplateShareScope("bad") + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + require.True(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeTeam}, + 1, + map[int64]struct{}{}, + )) + require.False(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}}, + 1, + map[int64]struct{}{}, + )) + require.True(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}}, + 1, + map[int64]struct{}{2: {}}, + )) + require.True(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9}, + 9, + nil, + )) + require.False(t, isBulkEditTemplateVisible( + bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9}, + 1, + nil, + )) + + converted := toBulkEditTemplate(bulkEditTemplateStoreItem{ + ID: "id-1", + Name: "Demo", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopePrivate, + GroupIDs: []int64{3}, + State: []byte(`invalid-json`), + }) + require.Equal(t, map[string]any{}, converted.State) +} + +func 
TestSettingServiceBulkEditTemplate_LoadPersistBranches(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + err := svc.persistBulkEditTemplateLibrary(context.Background(), nil) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json" + store, err := svc.loadBulkEditTemplateLibrary(context.Background()) + require.Error(t, err) + require.Nil(t, store) + + delete(repo.values, SettingKeyBulkEditTemplateLibrary) + store, err = svc.loadBulkEditTemplateLibrary(context.Background()) + require.NoError(t, err) + require.NotNil(t, store) + require.Empty(t, store.Items) +} + +func TestSettingServiceBulkEditTemplate_UpsertByMismatchedID(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 1, + }) + require.NoError(t, err) + require.NotEmpty(t, created.ID) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: "another-id", + Name: "Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) + + err = svc.DeleteBulkEditTemplate(context.Background(), "", 1) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) +} + +func TestSettingServiceBulkEditTemplate_PrivateTemplateIsolationAcrossUsers(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + ownerTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), 
BulkEditTemplateUpsertInput{ + Name: "Private Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 1}, + RequesterUserID: 101, + }) + require.NoError(t, err) + + otherTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Private Scoped Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 9}, + RequesterUserID: 202, + }) + require.NoError(t, err) + require.NotEqual(t, ownerTemplate.ID, otherTemplate.ID, "不同用户的私有同名模板不应互相覆盖") + + ownerVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 101, + }) + require.NoError(t, err) + require.Len(t, ownerVisible, 1) + require.Equal(t, ownerTemplate.ID, ownerVisible[0].ID) + require.EqualValues(t, 1, ownerVisible[0].State["priority"]) + + otherVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + RequesterUserID: 202, + }) + require.NoError(t, err) + require.Len(t, otherVisible, 1) + require.Equal(t, otherTemplate.ID, otherVisible[0].ID) + require.EqualValues(t, 9, otherVisible[0].State["priority"]) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: ownerTemplate.ID, + Name: ownerTemplate.Name, + ScopePlatform: ownerTemplate.ScopePlatform, + ScopeType: ownerTemplate.ScopeType, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 99}, + RequesterUserID: 202, + }) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err), "非 owner 不允许通过 template ID 修改私有模板") +} + +func TestSettingServiceBulkEditTemplate_UpsertFailsWhenStoredLibraryCorrupted(t *testing.T) { + repo := 
newBulkTemplateSettingRepoStub() + repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json" + svc := NewSettingService(repo, nil) + + _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Should Fail", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"ok": true}, + RequesterUserID: 1, + }) + require.Error(t, err) +} + +func TestSettingServiceBulkEditTemplate_ListFilteringAndSorting(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + store := bulkEditTemplateLibraryStore{ + Items: []bulkEditTemplateStoreItem{ + { + ID: "b", + Name: "Second", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopeTeam, + State: []byte(`{}`), + CreatedBy: 1, + UpdatedBy: 1, + CreatedAt: 1, + UpdatedAt: 100, + }, + { + ID: "a", + Name: "First", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopeTeam, + State: []byte(`{}`), + CreatedBy: 1, + UpdatedBy: 1, + CreatedAt: 1, + UpdatedAt: 100, + }, + { + ID: "skip-type", + Name: "Skip Type", + ScopePlatform: "openai", + ScopeType: "apikey", + ShareScope: BulkEditTemplateShareScopeTeam, + State: []byte(`{}`), + CreatedBy: 1, + UpdatedBy: 1, + CreatedAt: 1, + UpdatedAt: 999, + }, + { + ID: "skip-private", + Name: "Skip Private", + ScopePlatform: "openai", + ScopeType: "oauth", + ShareScope: BulkEditTemplateShareScopePrivate, + State: []byte(`{}`), + CreatedBy: 99, + UpdatedBy: 99, + CreatedAt: 1, + UpdatedAt: 1000, + }, + }, + } + require.NoError(t, svc.persistBulkEditTemplateLibrary(context.Background(), &store)) + + items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{ + ScopePlatform: "openai", + ScopeType: "oauth", + RequesterUserID: 1, + }) + require.NoError(t, err) + require.Len(t, items, 2) + require.Equal(t, []string{"a", "b"}, []string{items[0].ID, 
items[1].ID}) +} + +func TestSettingServiceBulkEditTemplate_UpsertByIDAndMarshalError(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "By ID", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 1}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: created.ID, + Name: "By ID", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 2}, + RequesterUserID: 1, + }) + require.NoError(t, err) + require.Equal(t, created.ID, updated.ID) + require.EqualValues(t, 2, updated.State["priority"]) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Marshal Error", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"bad": make(chan int)}, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Unauthorized", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 0, + }) + require.Error(t, err) + require.True(t, infraerrors.IsUnauthorized(err)) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Missing Scope", + ScopePlatform: "", + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{}, + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, 
infraerrors.IsBadRequest(err)) +} + +func TestSettingServiceBulkEditTemplate_LoadErrorFromRepository(t *testing.T) { + svc := NewSettingService(&bulkTemplateFailingRepoStub{}, nil) + store, err := svc.loadBulkEditTemplateLibrary(context.Background()) + require.Error(t, err) + require.Nil(t, store) +} + +func TestGenerateBulkEditTemplateID_Fallback(t *testing.T) { + original := bulkEditTemplateRandRead + bulkEditTemplateRandRead = func(_ []byte) (int, error) { + return 0, errors.New("rand fail") + } + defer func() { + bulkEditTemplateRandRead = original + }() + + id := generateBulkEditTemplateID() + require.NotEmpty(t, id) + require.Contains(t, id, "btpl-") + + versionID := generateBulkEditTemplateVersionID() + require.NotEmpty(t, versionID) + require.Contains(t, versionID, "btplv-") +} + +func TestSettingServiceBulkEditTemplate_VersionLifecycleAndRollback(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Versioned Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopePrivate, + State: map[string]any{"priority": 1}, + RequesterUserID: 88, + }) + require.NoError(t, err) + + updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: created.ID, + Name: "Versioned Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeTeam, + State: map[string]any{"priority": 9}, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Equal(t, BulkEditTemplateShareScopeTeam, updated.ShareScope) + + versions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Len(t, versions, 1) + require.Equal(t, BulkEditTemplateShareScopePrivate, 
versions[0].ShareScope) + require.EqualValues(t, 1, versions[0].State["priority"]) + + rollbacked, err := svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: versions[0].VersionID, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Equal(t, BulkEditTemplateShareScopePrivate, rollbacked.ShareScope) + require.EqualValues(t, 1, rollbacked.State["priority"]) + + versionsAfterRollback, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + RequesterUserID: 88, + }) + require.NoError(t, err) + require.Len(t, versionsAfterRollback, 2) +} + +func TestSettingServiceBulkEditTemplate_VersionVisibilityAndErrors(t *testing.T) { + repo := newBulkTemplateSettingRepoStub() + svc := NewSettingService(repo, nil) + + created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + Name: "Group Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{7}, + State: map[string]any{"enableBaseUrl": true}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{ + ID: created.ID, + Name: "Group Template", + ScopePlatform: PlatformOpenAI, + ScopeType: AccountTypeOAuth, + ShareScope: BulkEditTemplateShareScopeGroups, + GroupIDs: []int64{7, 9}, + State: map[string]any{"enableBaseUrl": false}, + RequesterUserID: 1, + }) + require.NoError(t, err) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + ScopeGroupIDs: []int64{8}, + RequesterUserID: 2, + }) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err)) + + visibleVersions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + ScopeGroupIDs: 
[]int64{7}, + RequesterUserID: 2, + }) + require.NoError(t, err) + require.Len(t, visibleVersions, 1) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: "missing", + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: created.ID, + RequesterUserID: 0, + }) + require.Error(t, err) + require.True(t, infraerrors.IsUnauthorized(err)) + + _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: "missing-version", + ScopeGroupIDs: []int64{7}, + RequesterUserID: 2, + }) + require.Error(t, err) + require.True(t, infraerrors.IsNotFound(err)) + + _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: visibleVersions[0].VersionID, + ScopeGroupIDs: []int64{8}, + RequesterUserID: 2, + }) + require.Error(t, err) + require.True(t, infraerrors.IsForbidden(err)) + + _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{ + TemplateID: " ", + RequesterUserID: 1, + }) + require.Error(t, err) + require.True(t, infraerrors.IsBadRequest(err)) + + _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{ + TemplateID: created.ID, + VersionID: visibleVersions[0].VersionID, + RequesterUserID: 0, + }) + require.Error(t, err) + require.True(t, infraerrors.IsUnauthorized(err)) +} + +func TestSettingServiceBulkEditTemplate_VersionHelpers(t *testing.T) { + normalized := normalizeBulkEditTemplateVersionStoreItems([]bulkEditTemplateVersionStoreItem{ + { + VersionID: "", + ShareScope: "groups", + GroupIDs: []int64{}, + State: nil, + UpdatedBy: 1, + UpdatedAt: 0, + }, + { + VersionID: "v-1", + ShareScope: "team", + GroupIDs: []int64{4, 4, 2}, + State: 
[]byte(`{"ok":true}`), + UpdatedBy: 2, + UpdatedAt: 20, + }, + { + VersionID: "v-1", + ShareScope: "team", + State: []byte(`{}`), + UpdatedAt: 30, + }, + }) + require.Len(t, normalized, 2) + privateCount := 0 + teamCount := 0 + for _, item := range normalized { + if item.ShareScope == BulkEditTemplateShareScopePrivate { + privateCount++ + } + if item.ShareScope == BulkEditTemplateShareScopeTeam { + teamCount++ + require.Equal(t, []int64{2, 4}, item.GroupIDs) + } + } + require.Equal(t, 1, privateCount) + require.Equal(t, 1, teamCount) + + item := bulkEditTemplateStoreItem{ + ID: "tpl-1", + ShareScope: BulkEditTemplateShareScopeTeam, + GroupIDs: []int64{3}, + State: []byte(`{"priority":3}`), + UpdatedBy: 10, + UpdatedAt: 123, + } + version := snapshotBulkEditTemplateVersion(item) + require.NotEmpty(t, version.VersionID) + require.Equal(t, BulkEditTemplateShareScopeTeam, version.ShareScope) + require.EqualValues(t, 123, version.UpdatedAt) + + versionDTO := toBulkEditTemplateVersion(bulkEditTemplateVersionStoreItem{ + VersionID: "ver-1", + ShareScope: BulkEditTemplateShareScopePrivate, + GroupIDs: []int64{9}, + State: []byte(`invalid`), + UpdatedBy: 1, + UpdatedAt: 2, + }) + require.Equal(t, map[string]any{}, versionDTO.State) + + require.Equal(t, -1, findBulkEditTemplateVersionIndexByID(nil, "x")) + require.Equal(t, -1, findBulkEditTemplateStoreItemIndexByID(nil, "x")) + require.Nil(t, findBulkEditTemplateStoreItemByID(nil, "x")) + + scopeSet := toBulkEditTemplateScopeGroupSet([]int64{4, 4, 2, -1}) + _, has2 := scopeSet[2] + _, has4 := scopeSet[4] + require.True(t, has2) + require.True(t, has4) + + cloned := cloneBulkEditTemplateStateRaw(json.RawMessage(`{"x":1}`)) + require.Equal(t, `{"x":1}`, string(cloned)) + require.Equal(t, `{}`, string(cloneBulkEditTemplateStateRaw(nil))) +} diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f7e4fb6be..776cab85a 100644 --- a/backend/internal/service/setting_service.go +++ 
b/backend/internal/service/setting_service.go @@ -7,31 +7,22 @@ import ( "encoding/json" "errors" "fmt" - "log/slog" "net/url" + "sort" "strconv" "strings" - "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "golang.org/x/sync/singleflight" + "github.com/tidwall/gjson" ) var ( - ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") - ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") - ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") - ErrDefaultSubGroupInvalid = infraerrors.BadRequest( - "DEFAULT_SUBSCRIPTION_GROUP_INVALID", - "default subscription group must exist and be subscription type", - ) - ErrDefaultSubGroupDuplicate = infraerrors.BadRequest( - "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", - "default subscription group cannot be duplicated", - ) + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") + ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") ) type SettingRepository interface { @@ -44,40 +35,13 @@ type SettingRepository interface { Delete(ctx context.Context, key string) error } -// cachedMinVersion 缓存最低 Claude Code 版本号(进程内缓存,60s TTL) -type cachedMinVersion struct { - value string // 空字符串 = 不检查 - expiresAt int64 // unix nano -} - -// minVersionCache 最低版本号进程内缓存 -var minVersionCache atomic.Value // *cachedMinVersion - -// minVersionSF 防止缓存过期时 thundering herd -var minVersionSF singleflight.Group - -// minVersionCacheTTL 
缓存有效期 -const minVersionCacheTTL = 60 * time.Second - -// minVersionErrorTTL DB 错误时的短缓存,快速重试 -const minVersionErrorTTL = 5 * time.Second - -// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context -const minVersionDBTimeout = 5 * time.Second - -// DefaultSubscriptionGroupReader validates group references used by default subscriptions. -type DefaultSubscriptionGroupReader interface { - GetByID(ctx context.Context, id int64) (*Group, error) -} - // SettingService 系统设置服务 type SettingService struct { - settingRepo SettingRepository - defaultSubGroupReader DefaultSubscriptionGroupReader - cfg *config.Config - onUpdate func() // Callback when settings are updated (for cache invalidation) - onS3Update func() // Callback when Sora S3 settings are updated - version string // Application version + settingRepo SettingRepository + cfg *config.Config + onUpdate func() // Callback when settings are updated (for cache invalidation) + onS3Update func() // Callback when Sora S3 settings are updated + version string // Application version } // NewSettingService 创建系统设置服务实例 @@ -88,11 +52,6 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti } } -// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation. 
-func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) { - s.defaultSubGroupReader = reader -} - // GetAllSettings 获取所有系统设置 func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) { settings, err := s.settingRepo.GetAll(ctx) @@ -125,7 +84,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, SettingKeySoraClientEnabled, - SettingKeyCustomMenuItems, SettingKeyLinuxDoConnectEnabled, } @@ -165,7 +123,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", - CustomMenuItems: settings[SettingKeyCustomMenuItems], LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -196,28 +153,27 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any // Return a struct that matches the frontend's expected format return &struct { - RegistrationEnabled bool `json:"registration_enabled"` - EmailVerifyEnabled bool `json:"email_verify_enabled"` - PromoCodeEnabled bool `json:"promo_code_enabled"` - PasswordResetEnabled bool `json:"password_reset_enabled"` - InvitationCodeEnabled bool `json:"invitation_code_enabled"` - TotpEnabled bool `json:"totp_enabled"` - TurnstileEnabled bool `json:"turnstile_enabled"` - TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo,omitempty"` - SiteSubtitle string `json:"site_subtitle,omitempty"` - APIBaseURL string `json:"api_base_url,omitempty"` - ContactInfo string `json:"contact_info,omitempty"` - DocURL string `json:"doc_url,omitempty"` - HomeContent string `json:"home_content,omitempty"` - 
HideCcsImportButton bool `json:"hide_ccs_import_button"` - PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` - PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` - SoraClientEnabled bool `json:"sora_client_enabled"` - CustomMenuItems json.RawMessage `json:"custom_menu_items"` - LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` - Version string `json:"version,omitempty"` + RegistrationEnabled bool `json:"registration_enabled"` + EmailVerifyEnabled bool `json:"email_verify_enabled"` + PromoCodeEnabled bool `json:"promo_code_enabled"` + PasswordResetEnabled bool `json:"password_reset_enabled"` + InvitationCodeEnabled bool `json:"invitation_code_enabled"` + TotpEnabled bool `json:"totp_enabled"` + TurnstileEnabled bool `json:"turnstile_enabled"` + TurnstileSiteKey string `json:"turnstile_site_key,omitempty"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo,omitempty"` + SiteSubtitle string `json:"site_subtitle,omitempty"` + APIBaseURL string `json:"api_base_url,omitempty"` + ContactInfo string `json:"contact_info,omitempty"` + DocURL string `json:"doc_url,omitempty"` + HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` + PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` + PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` + LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + Version string `json:"version,omitempty"` }{ RegistrationEnabled: settings.RegistrationEnabled, EmailVerifyEnabled: settings.EmailVerifyEnabled, @@ -238,125 +194,65 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, SoraClientEnabled: settings.SoraClientEnabled, - CustomMenuItems: 
filterUserVisibleMenuItems(settings.CustomMenuItems), LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: s.version, }, nil } -// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON -// array string, returning only items with visibility != "admin". -func filterUserVisibleMenuItems(raw string) json.RawMessage { - raw = strings.TrimSpace(raw) - if raw == "" || raw == "[]" { - return json.RawMessage("[]") - } - var items []struct { - Visibility string `json:"visibility"` - } - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return json.RawMessage("[]") - } - - // Parse full items to preserve all fields - var fullItems []json.RawMessage - if err := json.Unmarshal([]byte(raw), &fullItems); err != nil { - return json.RawMessage("[]") - } - - var filtered []json.RawMessage - for i, item := range items { - if item.Visibility != "admin" { - filtered = append(filtered, fullItems[i]) - } - } - if len(filtered) == 0 { - return json.RawMessage("[]") +// GetFrameSrcOrigins 提取需要注入 CSP frame-src 的外部域名来源。 +func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { + keys := []string{ + SettingKeyPurchaseSubscriptionURL, + SettingKeyHomeContent, + SettingKeyCustomMenuItems, } - result, err := json.Marshal(filtered) + settings, err := s.settingRepo.GetMultiple(ctx, keys) if err != nil { - return json.RawMessage("[]") + return nil, fmt.Errorf("get frame src settings: %w", err) } - return result -} -// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url -// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection. 
-func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) { - settings, err := s.GetPublicSettings(ctx) - if err != nil { - return nil, err + originsSet := make(map[string]struct{}, 8) + addOrigin := func(raw string) { + raw = strings.TrimSpace(raw) + if raw == "" { + return + } + u, parseErr := url.Parse(raw) + if parseErr != nil || u == nil { + return + } + if u.Scheme != "http" && u.Scheme != "https" { + return + } + if u.Host == "" { + return + } + originsSet[u.Scheme+"://"+u.Host] = struct{}{} } - seen := make(map[string]struct{}) - var origins []string + addOrigin(settings[SettingKeyPurchaseSubscriptionURL]) + addOrigin(settings[SettingKeyHomeContent]) - addOrigin := func(rawURL string) { - if origin := extractOriginFromURL(rawURL); origin != "" { - if _, ok := seen[origin]; !ok { - seen[origin] = struct{}{} - origins = append(origins, origin) + customMenuRaw := strings.TrimSpace(settings[SettingKeyCustomMenuItems]) + if customMenuRaw != "" { + menuItems := gjson.Parse(customMenuRaw) + if menuItems.IsArray() { + for _, item := range menuItems.Array() { + addOrigin(item.Get("url").String()) } } } - // purchase subscription URL - if settings.PurchaseSubscriptionEnabled { - addOrigin(settings.PurchaseSubscriptionURL) - } - - // all custom menu items (including admin-only, since CSP must allow all iframes) - for _, item := range parseCustomMenuItemURLs(settings.CustomMenuItems) { - addOrigin(item) + origins := make([]string, 0, len(originsSet)) + for origin := range originsSet { + origins = append(origins, origin) } - + sort.Strings(origins) return origins, nil } -// extractOriginFromURL returns the scheme+host origin from rawURL. -// Only http and https schemes are accepted. 
-func extractOriginFromURL(rawURL string) string { - rawURL = strings.TrimSpace(rawURL) - if rawURL == "" { - return "" - } - u, err := url.Parse(rawURL) - if err != nil || u.Host == "" { - return "" - } - if u.Scheme != "http" && u.Scheme != "https" { - return "" - } - return u.Scheme + "://" + u.Host -} - -// parseCustomMenuItemURLs extracts URLs from a raw JSON array of custom menu items. -func parseCustomMenuItemURLs(raw string) []string { - raw = strings.TrimSpace(raw) - if raw == "" || raw == "[]" { - return nil - } - var items []struct { - URL string `json:"url"` - } - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return nil - } - urls := make([]string, 0, len(items)) - for _, item := range items { - if item.URL != "" { - urls = append(urls, item.URL) - } - } - return urls -} - // UpdateSettings 更新系统设置 func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error { - if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil { - return err - } - updates := make(map[string]string) // 注册设置 @@ -405,16 +301,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) - updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) - defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) - if err != nil { - return fmt.Errorf("marshal default subscriptions: %w", err) - } - updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON) // Model fallback configuration 
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback) @@ -435,63 +325,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds) } - // Claude Code version check - updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion - - err = s.settingRepo.SetMultiple(ctx, updates) - if err == nil { - // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口 - minVersionSF.Forget("min_version") - minVersionCache.Store(&cachedMinVersion{ - value: settings.MinClaudeCodeVersion, - expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), - }) - if s.onUpdate != nil { - s.onUpdate() // Invalidate cache after settings update - } + err := s.settingRepo.SetMultiple(ctx, updates) + if err == nil && s.onUpdate != nil { + s.onUpdate() // Invalidate cache after settings update } return err } -func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error { - if len(items) == 0 { - return nil - } - - checked := make(map[int64]struct{}, len(items)) - for _, item := range items { - if item.GroupID <= 0 { - continue - } - if _, ok := checked[item.GroupID]; ok { - return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{ - "group_id": strconv.FormatInt(item.GroupID, 10), - }) - } - checked[item.GroupID] = struct{}{} - if s.defaultSubGroupReader == nil { - continue - } - - group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID) - if err != nil { - if errors.Is(err, ErrGroupNotFound) { - return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ - "group_id": strconv.FormatInt(item.GroupID, 10), - }) - } - return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err) - } - if !group.IsSubscriptionType() { - return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{ - "group_id": strconv.FormatInt(item.GroupID, 10), - }) - } - 
} - - return nil -} - // IsRegistrationEnabled 检查是否开放注册 func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool { value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled) @@ -591,15 +431,6 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } -// GetDefaultSubscriptions 获取新用户默认订阅配置列表。 -func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { - value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) - if err != nil { - return nil - } - return parseDefaultSubscriptions(value) -} - // InitializeDefaultSettings 初始化默认设置 func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { // 检查是否已有设置 @@ -622,10 +453,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", SettingKeySoraClientEnabled: "false", - SettingKeyCustomMenuItems: "[]", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), - SettingKeyDefaultSubscriptions: "[]", SettingKeySMTPPort: "587", SettingKeySMTPUseTLS: "false", // Model fallback defaults @@ -643,9 +472,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOpsRealtimeMonitoringEnabled: "true", SettingKeyOpsQueryModeDefault: "auto", SettingKeyOpsMetricsIntervalSeconds: "60", - - // Claude Code version check (default: empty = disabled) - SettingKeyMinClaudeCodeVersion: "", } return s.settingRepo.SetMultiple(ctx, defaults) @@ -681,7 +507,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), SoraClientEnabled: 
settings[SettingKeySoraClientEnabled] == "true", - CustomMenuItems: settings[SettingKeyCustomMenuItems], } // 解析整数类型 @@ -703,7 +528,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } else { result.DefaultBalance = s.cfg.Default.UserBalance } - result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions]) // 敏感信息直接返回,方便测试连接时使用 result.SMTPPassword = settings[SettingKeySMTPPassword] @@ -773,9 +597,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin } } - // Claude Code version check - result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion] - return result } @@ -788,31 +609,6 @@ func isFalseSettingValue(value string) bool { } } -func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting { - raw = strings.TrimSpace(raw) - if raw == "" { - return nil - } - - var items []DefaultSubscriptionSetting - if err := json.Unmarshal([]byte(raw), &items); err != nil { - return nil - } - - normalized := make([]DefaultSubscriptionSetting, 0, len(items)) - for _, item := range items { - if item.GroupID <= 0 || item.ValidityDays <= 0 { - continue - } - if item.ValidityDays > MaxValidityDays { - item.ValidityDays = MaxValidityDays - } - normalized = append(normalized, item) - } - - return normalized -} - // getStringOrDefault 获取字符串值或默认值 func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string { if value, ok := settings[key]; ok && value != "" { @@ -1098,53 +894,6 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT return &settings, nil } -// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求 -// 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销 -// singleflight 防止缓存过期时 thundering herd -// 返回空字符串表示不做版本检查 -func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string { - if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { - if time.Now().UnixNano() < 
cached.expiresAt { - return cached.value - } - } - // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果 - result, err, _ := minVersionSF.Do("min_version", func() (any, error) { - // 二次检查,避免排队的 goroutine 重复查询 - if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok { - if time.Now().UnixNano() < cached.expiresAt { - return cached.value, nil - } - } - // 使用独立 context:断开请求取消链,避免客户端断连导致空值被长期缓存 - dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), minVersionDBTimeout) - defer cancel() - value, err := s.settingRepo.GetValue(dbCtx, SettingKeyMinClaudeCodeVersion) - if err != nil { - // fail-open: DB 错误时不阻塞请求,但记录日志并使用短 TTL 快速重试 - slog.Warn("failed to get min claude code version setting, skipping version check", "error", err) - minVersionCache.Store(&cachedMinVersion{ - value: "", - expiresAt: time.Now().Add(minVersionErrorTTL).UnixNano(), - }) - return "", nil - } - minVersionCache.Store(&cachedMinVersion{ - value: value, - expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(), - }) - return value, nil - }) - if err != nil { - return "" - } - ver, ok := result.(string) - if !ok { - return "" - } - return ver -} - // SetStreamTimeoutSettings 设置流超时处理配置 func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error { if settings == nil { diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go deleted file mode 100644 index ec64511f2..000000000 --- a/backend/internal/service/setting_service_update_test.go +++ /dev/null @@ -1,182 +0,0 @@ -//go:build unit - -package service - -import ( - "context" - "encoding/json" - "testing" - - "github.com/Wei-Shaw/sub2api/internal/config" - infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" - "github.com/stretchr/testify/require" -) - -type settingUpdateRepoStub struct { - updates map[string]string -} - -func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) { - 
panic("unexpected Get call") -} - -func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) { - panic("unexpected GetValue call") -} - -func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error { - panic("unexpected Set call") -} - -func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) { - panic("unexpected GetMultiple call") -} - -func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error { - s.updates = make(map[string]string, len(settings)) - for k, v := range settings { - s.updates[k] = v - } - return nil -} - -func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) { - panic("unexpected GetAll call") -} - -func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error { - panic("unexpected Delete call") -} - -type defaultSubGroupReaderStub struct { - byID map[int64]*Group - errBy map[int64]error - calls []int64 -} - -func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) { - s.calls = append(s.calls, id) - if err, ok := s.errBy[id]; ok { - return nil, err - } - if g, ok := s.byID[id]; ok { - return g, nil - } - return nil, ErrGroupNotFound -} - -func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) { - repo := &settingUpdateRepoStub{} - groupReader := &defaultSubGroupReaderStub{ - byID: map[int64]*Group{ - 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, - }, - } - svc := NewSettingService(repo, &config.Config{}) - svc.SetDefaultSubscriptionGroupReader(groupReader) - - err := svc.UpdateSettings(context.Background(), &SystemSettings{ - DefaultSubscriptions: []DefaultSubscriptionSetting{ - {GroupID: 11, ValidityDays: 30}, - }, - }) - require.NoError(t, err) - require.Equal(t, []int64{11}, groupReader.calls) - - raw, ok := repo.updates[SettingKeyDefaultSubscriptions] - require.True(t, ok) - - 
var got []DefaultSubscriptionSetting - require.NoError(t, json.Unmarshal([]byte(raw), &got)) - require.Equal(t, []DefaultSubscriptionSetting{ - {GroupID: 11, ValidityDays: 30}, - }, got) -} - -func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) { - repo := &settingUpdateRepoStub{} - groupReader := &defaultSubGroupReaderStub{ - byID: map[int64]*Group{ - 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard}, - }, - } - svc := NewSettingService(repo, &config.Config{}) - svc.SetDefaultSubscriptionGroupReader(groupReader) - - err := svc.UpdateSettings(context.Background(), &SystemSettings{ - DefaultSubscriptions: []DefaultSubscriptionSetting{ - {GroupID: 12, ValidityDays: 7}, - }, - }) - require.Error(t, err) - require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) - require.Nil(t, repo.updates) -} - -func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) { - repo := &settingUpdateRepoStub{} - groupReader := &defaultSubGroupReaderStub{ - errBy: map[int64]error{ - 13: ErrGroupNotFound, - }, - } - svc := NewSettingService(repo, &config.Config{}) - svc.SetDefaultSubscriptionGroupReader(groupReader) - - err := svc.UpdateSettings(context.Background(), &SystemSettings{ - DefaultSubscriptions: []DefaultSubscriptionSetting{ - {GroupID: 13, ValidityDays: 7}, - }, - }) - require.Error(t, err) - require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err)) - require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"]) - require.Nil(t, repo.updates) -} - -func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) { - repo := &settingUpdateRepoStub{} - groupReader := &defaultSubGroupReaderStub{ - byID: map[int64]*Group{ - 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription}, - }, - } - svc := NewSettingService(repo, &config.Config{}) - svc.SetDefaultSubscriptionGroupReader(groupReader) - - err 
:= svc.UpdateSettings(context.Background(), &SystemSettings{ - DefaultSubscriptions: []DefaultSubscriptionSetting{ - {GroupID: 11, ValidityDays: 30}, - {GroupID: 11, ValidityDays: 60}, - }, - }) - require.Error(t, err) - require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) - require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) - require.Nil(t, repo.updates) -} - -func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) { - repo := &settingUpdateRepoStub{} - svc := NewSettingService(repo, &config.Config{}) - - err := svc.UpdateSettings(context.Background(), &SystemSettings{ - DefaultSubscriptions: []DefaultSubscriptionSetting{ - {GroupID: 11, ValidityDays: 30}, - {GroupID: 11, ValidityDays: 60}, - }, - }) - require.Error(t, err) - require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err)) - require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"]) - require.Nil(t, repo.updates) -} - -func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) { - got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`) - require.Equal(t, []DefaultSubscriptionSetting{ - {GroupID: 11, ValidityDays: 30}, - {GroupID: 11, ValidityDays: 60}, - {GroupID: 12, ValidityDays: MaxValidityDays}, - }, got) -} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 9f0de6000..6611f9901 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -42,9 +42,8 @@ type SystemSettings struct { SoraClientEnabled bool CustomMenuItems string // JSON array of custom menu items - DefaultConcurrency int - DefaultBalance float64 - DefaultSubscriptions []DefaultSubscriptionSetting + DefaultConcurrency int + DefaultBalance float64 // Model fallback 
configuration EnableModelFallback bool `json:"enable_model_fallback"` @@ -62,14 +61,6 @@ type SystemSettings struct { OpsRealtimeMonitoringEnabled bool OpsQueryModeDefault string OpsMetricsIntervalSeconds int - - // Claude Code version check - MinClaudeCodeVersion string -} - -type DefaultSubscriptionSetting struct { - GroupID int64 `json:"group_id"` - ValidityDays int `json:"validity_days"` } type PublicSettings struct { diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go index 46f322c82..a5c0c890d 100644 --- a/backend/internal/service/sora_generation_service_test.go +++ b/backend/internal/service/sora_generation_service_test.go @@ -162,12 +162,12 @@ func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, err func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { return 0, nil } -func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil } -func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil } -func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil } func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } +func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil } // ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ==================== diff --git a/backend/internal/service/token_refresh_parallel_test.go b/backend/internal/service/token_refresh_parallel_test.go new file mode 100644 index 000000000..c844ef934 --- /dev/null +++ b/backend/internal/service/token_refresh_parallel_test.go @@ -0,0 +1,439 @@ +//go:build unit + +package service + +import ( + 
"context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// --- 并行刷新专用 stub --- + +// concurrentTokenRefresherStub 记录并发度和调用次数 +type concurrentTokenRefresherStub struct { + canRefreshFn func(*Account) bool + needsRefreshFn func(*Account, time.Duration) bool + refreshDelay time.Duration + refreshErr error + credentials map[string]any + refreshCalls atomic.Int64 + maxConcurrent atomic.Int64 + currentActive atomic.Int64 +} + +func (r *concurrentTokenRefresherStub) CanRefresh(account *Account) bool { + if r.canRefreshFn != nil { + return r.canRefreshFn(account) + } + return true +} + +func (r *concurrentTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool { + if r.needsRefreshFn != nil { + return r.needsRefreshFn(account, window) + } + return true +} + +func (r *concurrentTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + r.refreshCalls.Add(1) + active := r.currentActive.Add(1) + // 记录峰值并发 + for { + old := r.maxConcurrent.Load() + if active <= old || r.maxConcurrent.CompareAndSwap(old, active) { + break + } + } + if r.refreshDelay > 0 { + time.Sleep(r.refreshDelay) + } + r.currentActive.Add(-1) + if r.refreshErr != nil { + return nil, r.refreshErr + } + // 每次返回新 map,避免多 goroutine 共享同一 map 实例引发竞态 + creds := make(map[string]any, len(r.credentials)) + for k, v := range r.credentials { + creds[k] = v + } + return creds, nil +} + +// concurrentTokenRefreshAccountRepo 线程安全的 account repo stub +type concurrentTokenRefreshAccountRepo struct { + mockAccountRepoForGemini + mu sync.Mutex + updateCalls int + setErrorCalls int + activeAccounts []Account + updateErr error +} + +func (r *concurrentTokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { + r.mu.Lock() + defer r.mu.Unlock() + r.updateCalls++ + return r.updateErr +} + +func (r *concurrentTokenRefreshAccountRepo) SetError(ctx 
context.Context, id int64, errorMsg string) error { + r.mu.Lock() + defer r.mu.Unlock() + r.setErrorCalls++ + return nil +} + +func (r *concurrentTokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) { + out := make([]Account, len(r.activeAccounts)) + copy(out, r.activeAccounts) + return out, nil +} + +// --- 测试用例 --- + +func TestProcessRefresh_ParallelExecution(t *testing.T) { + accounts := make([]Account, 20) + for i := range accounts { + accounts[i] = Account{ + ID: int64(100 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshDelay: 20 * time.Millisecond, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + start := time.Now() + svc.processRefresh() + elapsed := time.Since(start) + + // 20 个账号每个 20ms,串行至少 400ms;并行(maxConcurrency=10)约 40-60ms + require.Equal(t, int64(20), refresher.refreshCalls.Load()) + require.Less(t, elapsed, 300*time.Millisecond, "并行刷新应显著快于串行") + require.Greater(t, refresher.maxConcurrent.Load(), int64(1), "应有多个账号并发刷新") + require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过信号量限制") + + repo.mu.Lock() + require.Equal(t, 20, repo.updateCalls) + repo.mu.Unlock() +} + +func TestProcessRefresh_SemaphoreLimitsConcurrency(t *testing.T) { + accounts := make([]Account, 15) + for i := range accounts { + accounts[i] = Account{ + ID: int64(200 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + + repo := 
&concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshDelay: 50 * time.Millisecond, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + require.Equal(t, int64(15), refresher.refreshCalls.Load()) + require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过 maxConcurrency=10") +} + +func TestProcessRefresh_StopInterruptsPhase2(t *testing.T) { + accounts := make([]Account, 30) + for i := range accounts { + accounts[i] = Account{ + ID: int64(300 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshDelay: 100 * time.Millisecond, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + done := make(chan struct{}) + go func() { + svc.processRefresh() + close(done) + }() + + // 短暂等待让部分 goroutine 启动 + time.Sleep(30 * time.Millisecond) + svc.Stop() + + select { + case <-done: + // ok + case <-time.After(3 * time.Second): + t.Fatal("processRefresh 应在收到 stop 信号后及时退出") + } + + // 因中断,不应刷新全部 30 个账号 + require.Less(t, refresher.refreshCalls.Load(), int64(30), "stop 应中断后续任务提交") +} + +func 
TestProcessRefresh_EmptyAccounts(t *testing.T) { + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: nil} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + refresher := &concurrentTokenRefresherStub{} + svc.refreshers = []TokenRefresher{refresher} + + // 不应 panic + require.NotPanics(t, func() { + svc.processRefresh() + }) + require.Equal(t, int64(0), refresher.refreshCalls.Load()) +} + +func TestProcessRefresh_NoAccountsNeedRefresh(t *testing.T) { + accounts := []Account{ + {ID: 401, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 402, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + needsRefreshFn: func(a *Account, d time.Duration) bool { return false }, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + require.Equal(t, int64(0), refresher.refreshCalls.Load()) +} + +func TestProcessRefresh_MixedSuccessAndFailure(t *testing.T) { + accounts := []Account{ + {ID: 501, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 502, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 503, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 504, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: 
accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + + // 偶数 ID 成功,奇数 ID 失败 + refresher := &concurrentTokenRefresherStub{ + credentials: map[string]any{"access_token": "tok"}, + } + + failRefresher := &concurrentTokenRefresherStub{ + refreshErr: errors.New("refresh failed"), + } + + // 使用 selectiveRefresher 按 ID 分流 + selectiveRefresher := &selectiveTokenRefresherStub{ + successRefresher: refresher, + failRefresher: failRefresher, + failIDs: map[int64]bool{501: true, 503: true}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{selectiveRefresher} + + svc.processRefresh() + + totalCalls := refresher.refreshCalls.Load() + failRefresher.refreshCalls.Load() + require.Equal(t, int64(4), totalCalls) + require.Equal(t, int64(2), refresher.refreshCalls.Load()) + require.Equal(t, int64(2), failRefresher.refreshCalls.Load()) +} + +// selectiveTokenRefresherStub 按账号 ID 分流到不同的 refresher +type selectiveTokenRefresherStub struct { + successRefresher *concurrentTokenRefresherStub + failRefresher *concurrentTokenRefresherStub + failIDs map[int64]bool +} + +func (r *selectiveTokenRefresherStub) CanRefresh(account *Account) bool { + return true +} + +func (r *selectiveTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool { + return true +} + +func (r *selectiveTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + if r.failIDs[account.ID] { + return r.failRefresher.Refresh(ctx, account) + } + return r.successRefresher.Refresh(ctx, account) +} + +func TestProcessRefresh_SingleAccount(t *testing.T) { + accounts := []Account{ + {ID: 601, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: 
accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + require.Equal(t, int64(1), refresher.refreshCalls.Load()) + require.Equal(t, int64(1), refresher.maxConcurrent.Load()) +} + +func TestProcessRefresh_AllFailed(t *testing.T) { + accounts := make([]Account, 5) + for i := range accounts { + accounts[i] = Account{ + ID: int64(700 + i), + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + } + } + repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + refreshErr: errors.New("all fail"), + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + // 不应 panic + require.NotPanics(t, func() { + svc.processRefresh() + }) + require.Equal(t, int64(5), refresher.refreshCalls.Load()) + + repo.mu.Lock() + require.Equal(t, 5, repo.setErrorCalls) + require.Equal(t, 0, repo.updateCalls) + repo.mu.Unlock() +} + +func TestProcessRefresh_CanRefreshFilters(t *testing.T) { + accounts := []Account{ + {ID: 801, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + {ID: 802, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive}, + {ID: 803, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive}, + } + repo := 
&concurrentTokenRefreshAccountRepo{activeAccounts: accounts} + lockStub := &tokenRefreshSchedulerLockStub{} + refresher := &concurrentTokenRefresherStub{ + canRefreshFn: func(a *Account) bool { return a.Type == AccountTypeOAuth }, + credentials: map[string]any{"access_token": "tok"}, + } + + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 5, + RefreshBeforeExpiryHours: 1, + }, + } + svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + svc.refreshers = []TokenRefresher{refresher} + + svc.processRefresh() + + // 只有 OAuth 账号(ID 801, 803)应被刷新 + require.Equal(t, int64(2), refresher.refreshCalls.Load()) +} diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index a37e0d0ac..ae373a831 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -6,11 +6,19 @@ import ( "log/slog" "strings" "sync" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" ) +const ( + tokenRefreshDistributedLockPlatform = "token_refresh" + tokenRefreshDistributedLockMinTTL = 30 * time.Second + tokenRefreshDistributedLockMaxTTL = 10 * time.Minute + tokenRefreshDistributedLockTimeout = 2 * time.Second +) + // TokenRefreshService OAuth token自动刷新服务 // 定期检查并刷新即将过期的token type TokenRefreshService struct { @@ -20,8 +28,9 @@ type TokenRefreshService struct { cacheInvalidator TokenCacheInvalidator schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题 - stopCh chan struct{} - wg sync.WaitGroup + stopCh chan struct{} + stopOnce sync.Once + wg sync.WaitGroup } // NewTokenRefreshService 创建token刷新服务 @@ -87,7 +96,12 @@ func (s *TokenRefreshService) Start() { // Stop 停止刷新服务 func (s *TokenRefreshService) Stop() { - close(s.stopCh) + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + }) s.wg.Wait() 
slog.Info("token_refresh.service_stopped") } @@ -118,9 +132,24 @@ func (s *TokenRefreshService) refreshLoop() { } } +// refreshTask 封装一个待刷新的账号及其对应的刷新器 +type refreshTask struct { + account *Account + refresher TokenRefresher +} + // processRefresh 执行一次刷新检查 +// 分两阶段:先串行收集需刷新的账号,再并行执行刷新(信号量限制并发数) func (s *TokenRefreshService) processRefresh() { - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + go func() { + select { + case <-s.stopCh: + cancel() + case <-ctx.Done(): + } + }() // 计算刷新窗口 refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour)) @@ -128,68 +157,188 @@ func (s *TokenRefreshService) processRefresh() { // 获取所有active状态的账号 accounts, err := s.listActiveAccounts(ctx) if err != nil { + if ctx.Err() != nil { + slog.Info("token_refresh.cycle_interrupted_by_stop") + return + } slog.Error("token_refresh.list_accounts_failed", "error", err) return } totalAccounts := len(accounts) - oauthAccounts := 0 // 可刷新的OAuth账号数 - needsRefresh := 0 // 需要刷新的账号数 - refreshed, failed := 0, 0 + + // Phase 1: 收集需要刷新的账号(轻量操作,仍串行) + var tasks []refreshTask + oauthAccounts := 0 for i := range accounts { account := &accounts[i] - - // 遍历所有刷新器,找到能处理此账号的 for _, refresher := range s.refreshers { if !refresher.CanRefresh(account) { continue } - oauthAccounts++ - - // 检查是否需要刷新 - if !refresher.NeedsRefresh(account, refreshWindow) { - break // 不需要刷新,跳过 + if refresher.NeedsRefresh(account, refreshWindow) { + if !s.tryAcquireDistributedRefreshLock(ctx, account) { + break + } + tasks = append(tasks, refreshTask{account: account, refresher: refresher}) } + break // 每个账号只由一个refresher处理 + } + } - needsRefresh++ + needsRefresh := len(tasks) + if needsRefresh == 0 { + slog.Debug("token_refresh.cycle_completed", + "total", totalAccounts, "oauth", oauthAccounts, + "needs_refresh", 0, "refreshed", 0, "failed", 0) + return + } + + // Phase 2: 并行刷新(带信号量限制) + const maxConcurrency = 10 + sem := make(chan struct{}, maxConcurrency) 
+ var refreshed, failed atomic.Int64 + var wg sync.WaitGroup + interrupted := false + +submitLoop: + for _, task := range tasks { + // 检查停止信号 + select { + case <-s.stopCh: + slog.Info("token_refresh.cycle_interrupted_by_stop") + interrupted = true + break submitLoop + default: + } + + select { + case sem <- struct{}{}: // 获取信号量 + case <-s.stopCh: + slog.Info("token_refresh.cycle_interrupted_by_stop") + interrupted = true + break submitLoop + case <-ctx.Done(): + interrupted = true + break submitLoop + } - // 执行刷新 - if err := s.refreshWithRetry(ctx, account, refresher); err != nil { + wg.Add(1) + go func(t refreshTask) { + defer func() { + <-sem // 释放信号量 + wg.Done() + }() + + if err := s.refreshWithRetry(ctx, t.account, t.refresher); err != nil { slog.Warn("token_refresh.account_refresh_failed", - "account_id", account.ID, - "account_name", account.Name, + "account_id", t.account.ID, + "account_name", t.account.Name, "error", err, ) - failed++ + failed.Add(1) } else { slog.Info("token_refresh.account_refreshed", - "account_id", account.ID, - "account_name", account.Name, + "account_id", t.account.ID, + "account_name", t.account.Name, ) - refreshed++ + refreshed.Add(1) } - - // 每个账号只由一个refresher处理 - break - } + }(task) } - // 无刷新活动时降级为 Debug,有实际刷新活动时保持 Info - if needsRefresh == 0 && failed == 0 { + wg.Wait() + + r, f := int(refreshed.Load()), int(failed.Load()) + if interrupted { + slog.Info("token_refresh.cycle_wait_completed_after_stop", + "needs_refresh", needsRefresh, "refreshed", r, "failed", f) + } + if needsRefresh == 0 && f == 0 { slog.Debug("token_refresh.cycle_completed", "total", totalAccounts, "oauth", oauthAccounts, - "needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed) + "needs_refresh", needsRefresh, "refreshed", r, "failed", f) } else { slog.Info("token_refresh.cycle_completed", - "total", totalAccounts, - "oauth", oauthAccounts, - "needs_refresh", needsRefresh, - "refreshed", refreshed, - "failed", failed, + "total", 
totalAccounts, "oauth", oauthAccounts, + "needs_refresh", needsRefresh, "refreshed", r, "failed", f) + } +} + +func (s *TokenRefreshService) tryAcquireDistributedRefreshLock(ctx context.Context, account *Account) bool { + if s == nil || account == nil || account.ID <= 0 || s.schedulerCache == nil { + return true + } + lockTTL := s.tokenRefreshDistributedLockTTL() + if lockTTL <= 0 { + return true + } + lockCtx := ctx + if lockCtx == nil { + lockCtx = context.Background() + } + lockCtx, cancel := context.WithTimeout(lockCtx, tokenRefreshDistributedLockTimeout) + defer cancel() + + lockBucket := SchedulerBucket{ + GroupID: account.ID, + Platform: tokenRefreshDistributedLockPlatform, + Mode: normalizeTokenRefreshLockMode(account.Platform), + } + locked, err := s.schedulerCache.TryLockBucket(lockCtx, lockBucket, lockTTL) + if err != nil { + if ctx != nil && ctx.Err() != nil { + slog.Info("token_refresh.distributed_lock_canceled", + "account_id", account.ID, + "platform", account.Platform, + "error", err, + ) + return false + } + slog.Warn("token_refresh.distributed_lock_failed", + "account_id", account.ID, + "platform", account.Platform, + "error", err, + "fail_open", true, + ) + return true + } + if !locked { + slog.Debug("token_refresh.distributed_lock_held", + "account_id", account.ID, + "platform", account.Platform, ) + return false + } + return true +} + +func normalizeTokenRefreshLockMode(platform string) string { + mode := strings.TrimSpace(platform) + if mode == "" { + return "unknown" + } + return mode +} + +func (s *TokenRefreshService) tokenRefreshDistributedLockTTL() time.Duration { + if s == nil || s.cfg == nil { + return tokenRefreshDistributedLockMinTTL } + checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute + if checkInterval <= 0 { + return tokenRefreshDistributedLockMinTTL + } + ttl := checkInterval / 2 + if ttl < tokenRefreshDistributedLockMinTTL { + ttl = tokenRefreshDistributedLockMinTTL + } + if ttl > 
tokenRefreshDistributedLockMaxTTL { + ttl = tokenRefreshDistributedLockMaxTTL + } + return ttl } // listActiveAccounts 获取所有active状态的账号 @@ -281,7 +430,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc if attempt < s.cfg.MaxRetries { // 指数退避:2^(attempt-1) * baseSeconds backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1)) - time.Sleep(backoff) + if err := s.waitRetryBackoff(ctx, backoff); err != nil { + return err + } } } @@ -306,6 +457,23 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc return lastErr } +func (s *TokenRefreshService) waitRetryBackoff(ctx context.Context, backoff time.Duration) error { + if backoff <= 0 { + return nil + } + timer := time.NewTimer(backoff) + defer timer.Stop() + + select { + case <-timer.C: + return nil + case <-s.stopCh: + return context.Canceled + case <-ctx.Done(): + return ctx.Err() + } +} + // isNonRetryableRefreshError 判断是否为不可重试的刷新错误 // 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权 // 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误 diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go index 8e16c6f5d..142a5c169 100644 --- a/backend/internal/service/token_refresh_service_test.go +++ b/backend/internal/service/token_refresh_service_test.go @@ -14,10 +14,11 @@ import ( type tokenRefreshAccountRepo struct { mockAccountRepoForGemini - updateCalls int - setErrorCalls int - lastAccount *Account - updateErr error + updateCalls int + setErrorCalls int + lastAccount *Account + updateErr error + activeAccounts []Account } func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error { @@ -31,6 +32,15 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorM return nil } +func (r *tokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) { + if len(r.activeAccounts) == 0 { + return nil, nil + } + out 
:= make([]Account, 0, len(r.activeAccounts)) + out = append(out, r.activeAccounts...) + return out, nil +} + type tokenCacheInvalidatorStub struct { calls int err error @@ -42,8 +52,9 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account } type tokenRefresherStub struct { - credentials map[string]any - err error + credentials map[string]any + err error + refreshCalls int } func (r *tokenRefresherStub) CanRefresh(account *Account) bool { @@ -55,12 +66,41 @@ func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuratio } func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) { + r.refreshCalls++ if r.err != nil { return nil, r.err } return r.credentials, nil } +type tokenRefreshSchedulerLockStub struct { + SchedulerCache + lockByAccount map[int64]bool + err error + calls []SchedulerBucket +} + +func (s *tokenRefreshSchedulerLockStub) TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) { + _ = ctx + _ = ttl + s.calls = append(s.calls, bucket) + if s.err != nil { + return false, s.err + } + if s.lockByAccount != nil { + if ok, exists := s.lockByAccount[bucket.GroupID]; exists { + return ok, nil + } + } + return true, nil +} + +func (s *tokenRefreshSchedulerLockStub) SetAccount(ctx context.Context, account *Account) error { + _ = ctx + _ = account + return nil +} + func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{} @@ -89,6 +129,96 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) { require.Equal(t, "new-token", account.GetCredential("access_token")) } +func TestTokenRefreshService_ProcessRefresh_SkipsWhenDistributedLockHeld(t *testing.T) { + repo := &tokenRefreshAccountRepo{ + activeAccounts: []Account{ + {ID: 31, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: 
true}, + }, + } + lockStub := &tokenRefreshSchedulerLockStub{ + lockByAccount: map[int64]bool{31: false}, + } + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 2, + RefreshBeforeExpiryHours: 1, + SyncLinkedSoraAccounts: false, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}} + service.refreshers = []TokenRefresher{refresher} + + service.processRefresh() + + require.Len(t, lockStub.calls, 1) + require.Equal(t, int64(31), lockStub.calls[0].GroupID) + require.Equal(t, tokenRefreshDistributedLockPlatform, lockStub.calls[0].Platform) + require.Equal(t, PlatformOpenAI, lockStub.calls[0].Mode) + require.Equal(t, 0, refresher.refreshCalls, "lock held by another instance should skip refresh") + require.Equal(t, 0, repo.updateCalls) +} + +func TestTokenRefreshService_ProcessRefresh_RefreshesWhenDistributedLockAcquired(t *testing.T) { + repo := &tokenRefreshAccountRepo{ + activeAccounts: []Account{ + {ID: 32, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}, + }, + } + lockStub := &tokenRefreshSchedulerLockStub{ + lockByAccount: map[int64]bool{32: true}, + } + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 2, + RefreshBeforeExpiryHours: 1, + SyncLinkedSoraAccounts: false, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}} + service.refreshers = []TokenRefresher{refresher} + + service.processRefresh() + + require.Len(t, lockStub.calls, 1) + require.Equal(t, 1, refresher.refreshCalls) + require.Equal(t, 1, repo.updateCalls, "lock acquired should allow refresh") +} + +func 
TestTokenRefreshService_ProcessRefresh_FailOpenWhenDistributedLockError(t *testing.T) { + repo := &tokenRefreshAccountRepo{ + activeAccounts: []Account{ + {ID: 33, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true}, + }, + } + lockStub := &tokenRefreshSchedulerLockStub{ + err: errors.New("redis unavailable"), + } + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + CheckIntervalMinutes: 2, + RefreshBeforeExpiryHours: 1, + SyncLinkedSoraAccounts: false, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg) + refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}} + service.refreshers = []TokenRefresher{refresher} + + service.processRefresh() + + require.Len(t, lockStub.calls, 1) + require.Equal(t, 1, refresher.refreshCalls, "lock backend error should fail-open and continue refresh") + require.Equal(t, 1, repo.updateCalls) +} + func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) { repo := &tokenRefreshAccountRepo{} invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")} @@ -359,3 +489,56 @@ func TestIsNonRetryableRefreshError(t *testing.T) { }) } } + +func TestTokenRefreshService_Stop_Idempotent(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 1, + RetryBackoffSeconds: 0, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + + require.NotPanics(t, func() { + service.Stop() + service.Stop() + }) +} + +func TestTokenRefreshService_RefreshWithRetry_StopInterruptsBackoff(t *testing.T) { + repo := &tokenRefreshAccountRepo{} + cfg := &config.Config{ + TokenRefresh: config.TokenRefreshConfig{ + MaxRetries: 3, + RetryBackoffSeconds: 5, + }, + } + service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg) + account := 
&Account{ + ID: 21, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + refresher := &tokenRefresherStub{ + err: errors.New("refresh failed"), + } + + start := time.Now() + done := make(chan error, 1) + go func() { + done <- service.refreshWithRetry(context.Background(), account, refresher) + }() + + time.Sleep(80 * time.Millisecond) + service.Stop() + + select { + case err := <-done: + require.ErrorIs(t, err, context.Canceled) + case <-time.After(600 * time.Millisecond): + t.Fatal("refreshWithRetry should exit quickly after service stop") + } + require.Less(t, time.Since(start), time.Second) + require.Equal(t, 0, repo.setErrorCalls, "stop 中断时不应落错误状态") +} diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go index 0dd3cf45d..a4fc1c70e 100644 --- a/backend/internal/service/token_refresher.go +++ b/backend/internal/service/token_refresher.go @@ -7,6 +7,18 @@ import ( "time" ) +const openAISoraSyncConcurrencyLimit = 8 + +// needsRefreshWithoutExpiry 在 expires_at 缺失时判断是否需要刷新。 +// 通过 Account.UpdatedAt 避免每轮刷新周期都发起无效刷新: +// 如果账号在 refreshWindow 内曾被更新,说明最近可能已刷新过,跳过本轮。 +func needsRefreshWithoutExpiry(account *Account, refreshWindow time.Duration) bool { + if refreshWindow <= 0 { + return true + } + return time.Since(account.UpdatedAt) >= refreshWindow +} + // TokenRefresher 定义平台特定的token刷新策略接口 // 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini) type TokenRefresher interface { @@ -46,7 +58,8 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool { func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { - return false + // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮 + return needsRefreshWithoutExpiry(account, refreshWindow) } return time.Until(*expiresAt) < refreshWindow } @@ -87,6 +100,10 @@ type OpenAITokenRefresher struct { accountRepo AccountRepository soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步 
syncLinkedSora bool + syncLinkedSoraSem chan struct{} + + // test hook: override sync execution target when needed. + syncLinkedSoraAccountsFn func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) } // NewOpenAITokenRefresher 创建 OpenAI token刷新器 @@ -94,6 +111,7 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo return &OpenAITokenRefresher{ openaiOAuthService: openaiOAuthService, accountRepo: accountRepo, + syncLinkedSoraSem: make(chan struct{}, openAISoraSyncConcurrencyLimit), } } @@ -120,7 +138,8 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool { func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool { expiresAt := account.GetCredentialAsTime("expires_at") if expiresAt == nil { - return false + // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮 + return needsRefreshWithoutExpiry(account, refreshWindow) } return time.Until(*expiresAt) < refreshWindow @@ -147,12 +166,58 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m // 异步同步关联的 Sora 账号(不阻塞主流程) if r.accountRepo != nil && r.syncLinkedSora { - go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials) + syncCredentials := copyCredentialsMap(newCredentials) + syncFn := r.syncLinkedSoraAccounts + if r.syncLinkedSoraAccountsFn != nil { + syncFn = r.syncLinkedSoraAccountsFn + } + if r.tryAcquireSyncLinkedSoraSlot() { + go func() { + defer r.releaseSyncLinkedSoraSlot() + syncFn(context.Background(), account.ID, syncCredentials) + }() + } else { + // 达到并发上限时回退为同步执行,避免 goroutine 无界堆积。 + syncFn(ctx, account.ID, syncCredentials) + } } return newCredentials, nil } +func copyCredentialsMap(src map[string]any) map[string]any { + if len(src) == 0 { + return map[string]any{} + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + +func (r *OpenAITokenRefresher) tryAcquireSyncLinkedSoraSlot() bool { + if r == nil || 
r.syncLinkedSoraSem == nil { + return false + } + select { + case r.syncLinkedSoraSem <- struct{}{}: + return true + default: + return false + } +} + +func (r *OpenAITokenRefresher) releaseSyncLinkedSoraSlot() { + if r == nil || r.syncLinkedSoraSem == nil { + return + } + select { + case <-r.syncLinkedSoraSem: + default: + } +} + // syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步) // 该方法异步执行,失败只记录日志,不影响主流程 // diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go index 264d79125..27d786634 100644 --- a/backend/internal/service/token_refresher_test.go +++ b/backend/internal/service/token_refresher_test.go @@ -3,13 +3,38 @@ package service import ( + "context" "strconv" "testing" "time" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/stretchr/testify/require" ) +type openAIOAuthClientStubForRefresher struct { + tokenResp *openai.TokenResponse + err error +} + +func (s *openAIOAuthClientStubForRefresher) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, s.err +} + +func (s *openAIOAuthClientStubForRefresher) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + if s.err != nil { + return nil, s.err + } + return s.tokenResp, nil +} + +func (s *openAIOAuthClientStubForRefresher) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + if s.err != nil { + return nil, s.err + } + return s.tokenResp, nil +} + func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { refresher := &ClaudeTokenRefresher{} refreshWindow := 30 * time.Minute @@ -64,26 +89,26 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) { { name: "expires_at missing", credentials: map[string]any{}, - wantRefresh: false, + wantRefresh: true, }, { name: "expires_at is nil", credentials: map[string]any{ "expires_at": nil, 
}, - wantRefresh: false, + wantRefresh: true, }, { name: "expires_at is invalid string", credentials: map[string]any{ "expires_at": "invalid", }, - wantRefresh: false, + wantRefresh: true, }, { name: "credentials is nil", credentials: nil, - wantRefresh: false, + wantRefresh: true, }, } @@ -179,6 +204,36 @@ func TestClaudeTokenRefresher_NeedsRefresh_OutsideWindow(t *testing.T) { } } +func TestNeedsRefreshWithoutExpiry_RecentlyUpdated(t *testing.T) { + refreshWindow := 30 * time.Minute + + t.Run("recently_updated_skips_refresh", func(t *testing.T) { + // 账号近期更新过(5 分钟前),不需要刷新 + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Credentials: map[string]any{}, + UpdatedAt: time.Now().Add(-5 * time.Minute), + } + refresher := &ClaudeTokenRefresher{} + require.False(t, refresher.NeedsRefresh(account, refreshWindow), + "近期更新过的账号无 expires_at 时不应刷新") + }) + + t.Run("old_updated_needs_refresh", func(t *testing.T) { + // 账号很久没更新(2 小时前),需要刷新 + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{}, + UpdatedAt: time.Now().Add(-2 * time.Hour), + } + refresher := &OpenAITokenRefresher{} + require.True(t, refresher.NeedsRefresh(account, refreshWindow), + "长期未更新的账号无 expires_at 时应刷新") + }) +} + func TestClaudeTokenRefresher_CanRefresh(t *testing.T) { refresher := &ClaudeTokenRefresher{} @@ -266,3 +321,159 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) { }) } } + +func TestOpenAITokenRefresher_NeedsRefresh(t *testing.T) { + refresher := &OpenAITokenRefresher{} + refreshWindow := 30 * time.Minute + + tests := []struct { + name string + credentials map[string]any + wantRefresh bool + }{ + { + name: "expires_at missing", + credentials: map[string]any{ + "access_token": "token", + }, + wantRefresh: true, + }, + { + name: "expires_at invalid", + credentials: map[string]any{ + "expires_at": "invalid", + }, + wantRefresh: true, + }, + { + name: "expires_at expired", + credentials: map[string]any{ + 
"expires_at": strconv.FormatInt(time.Now().Add(-time.Minute).Unix(), 10), + }, + wantRefresh: true, + }, + { + name: "expires_at far future", + credentials: map[string]any{ + "expires_at": strconv.FormatInt(time.Now().Add(2*time.Hour).Unix(), 10), + }, + wantRefresh: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: tt.credentials, + } + require.Equal(t, tt.wantRefresh, refresher.NeedsRefresh(account, refreshWindow)) + }) + } +} + +func TestOpenAITokenRefresher_Refresh_AsyncSyncUsesCopiedCredentials(t *testing.T) { + oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{ + tokenResp: &openai.TokenResponse{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", + ExpiresIn: 3600, + }, + }) + refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{}) + refresher.SetSyncLinkedSoraAccounts(true) + refresher.syncLinkedSoraSem = make(chan struct{}, 1) + + readNow := make(chan struct{}) + seenValue := make(chan string, 1) + refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) { + <-readNow + if v, ok := newCredentials["custom"].(string); ok { + seenValue <- v + return + } + seenValue <- "" + } + + account := &Account{ + ID: 1001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old_refresh_token", + "client_id": "test-client", + "custom": "original", + }, + } + + newCredentials, err := refresher.Refresh(context.Background(), account) + require.NoError(t, err) + require.NotNil(t, newCredentials) + + newCredentials["custom"] = "mutated_after_return" + close(readNow) + + select { + case got := <-seenValue: + require.Equal(t, "original", got, "异步同步应使用 credentials 副本,避免并发写污染") + case <-time.After(500 * time.Millisecond): + t.Fatal("timed out waiting for sync hook") + } +} + +func 
TestOpenAITokenRefresher_Refresh_FallsBackToSyncWhenLimiterFull(t *testing.T) { + oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{ + tokenResp: &openai.TokenResponse{ + AccessToken: "new_access_token", + RefreshToken: "new_refresh_token", + ExpiresIn: 3600, + }, + }) + refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{}) + refresher.SetSyncLinkedSoraAccounts(true) + refresher.syncLinkedSoraSem = make(chan struct{}, 1) + refresher.syncLinkedSoraSem <- struct{}{} // 填满 limiter,强制走同步降级路径 + + entered := make(chan struct{}) + releaseSync := make(chan struct{}) + refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) { + close(entered) + <-releaseSync + } + + account := &Account{ + ID: 1002, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "refresh_token": "old_refresh_token", + "client_id": "test-client", + }, + } + + done := make(chan struct{}) + go func() { + _, _ = refresher.Refresh(context.Background(), account) + close(done) + }() + + select { + case <-entered: + case <-time.After(500 * time.Millisecond): + t.Fatal("sync hook was not invoked") + } + + select { + case <-done: + t.Fatal("Refresh should block when falling back to synchronous linked-sora sync") + default: + } + + close(releaseSync) + select { + case <-done: + case <-time.After(500 * time.Millisecond): + t.Fatal("Refresh did not finish after releasing synchronous sync hook") + } +} diff --git a/backend/internal/service/usage_billing_compensation_service.go b/backend/internal/service/usage_billing_compensation_service.go new file mode 100644 index 000000000..47ea71618 --- /dev/null +++ b/backend/internal/service/usage_billing_compensation_service.go @@ -0,0 +1,256 @@ +package service + +import ( + "context" + "errors" + "log/slog" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +const ( + 
defaultUsageBillingCompensationInterval = 20 * time.Second + defaultUsageBillingCompensationBatchSize = 64 + defaultUsageBillingCompensationTaskTimout = 8 * time.Second + defaultUsageBillingCompensationStaleAfter = 3 * time.Minute +) + +// UsageBillingCompensationService retries pending usage charges in billing_usage_entries. +// It only runs when usageLogRepo supports UsageBillingEntryStore. +type UsageBillingCompensationService struct { + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + billingCache *BillingCacheService + cfg *config.Config + + startOnce sync.Once + stopOnce sync.Once + stopCh chan struct{} +} + +func NewUsageBillingCompensationService( + usageLogRepo UsageLogRepository, + userRepo UserRepository, + userSubRepo UserSubscriptionRepository, + billingCache *BillingCacheService, + cfg *config.Config, +) *UsageBillingCompensationService { + return &UsageBillingCompensationService{ + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + billingCache: billingCache, + cfg: cfg, + stopCh: make(chan struct{}), + } +} + +func (s *UsageBillingCompensationService) Start() { + if s == nil || s.store() == nil { + return + } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + return + } + s.startOnce.Do(func() { + slog.Info("usage_billing_compensation.started", + "interval", defaultUsageBillingCompensationInterval.String(), + "batch_size", defaultUsageBillingCompensationBatchSize, + ) + go s.runLoop() + }) +} + +func (s *UsageBillingCompensationService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + close(s.stopCh) + slog.Info("usage_billing_compensation.stopped") + }) +} + +func (s *UsageBillingCompensationService) runLoop() { + ticker := time.NewTicker(defaultUsageBillingCompensationInterval) + defer ticker.Stop() + + s.processOnce() + + for { + select { + case <-ticker.C: + s.processOnce() + case <-s.stopCh: + return + } + } +} + +func (s 
*UsageBillingCompensationService) processOnce() { + store := s.store() + if store == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), defaultUsageBillingCompensationTaskTimout) + defer cancel() + go func() { + select { + case <-s.stopCh: + cancel() + case <-ctx.Done(): + } + }() + + entries, err := store.ClaimUsageBillingEntries(ctx, defaultUsageBillingCompensationBatchSize, defaultUsageBillingCompensationStaleAfter) + if err != nil { + slog.Warn("usage_billing_compensation.claim_failed", "error", err) + return + } + for i := range entries { + if ctx.Err() != nil { + return + } + s.processEntry(ctx, entries[i]) + } +} + +func (s *UsageBillingCompensationService) processEntry(ctx context.Context, entry UsageBillingEntry) { + if entry.Applied || entry.DeltaUSD <= 0 { + s.markApplied(ctx, entry) + return + } + + if err := s.applyEntry(ctx, entry); err != nil { + s.markRetry(ctx, entry, err) + return + } +} + +func (s *UsageBillingCompensationService) applyEntry(ctx context.Context, entry UsageBillingEntry) error { + switch entry.BillingType { + case BillingTypeSubscription: + return s.applySubscriptionEntry(ctx, entry) + default: + return s.applyBalanceEntry(ctx, entry) + } +} + +func (s *UsageBillingCompensationService) applyBalanceEntry(ctx context.Context, entry UsageBillingEntry) error { + if s.userRepo == nil { + return errors.New("user repository unavailable") + } + + cacheDeducted := false + if s.billingCache != nil { + if err := s.billingCache.DeductBalanceCache(ctx, entry.UserID, entry.DeltaUSD); err != nil { + slog.Warn("usage_billing_compensation.balance_cache_deduct_failed", + "entry_id", entry.ID, + "user_id", entry.UserID, + "amount", entry.DeltaUSD, + "error", err, + ) + _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID) + } else { + cacheDeducted = true + } + } + + if err := s.runWithTx(ctx, func(txCtx context.Context) error { + if err := s.userRepo.DeductBalance(txCtx, entry.UserID, entry.DeltaUSD); err != nil { 
+ return err + } + return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID) + }); err != nil { + if s.billingCache != nil && cacheDeducted { + _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID) + } + return err + } + + return nil +} + +func (s *UsageBillingCompensationService) applySubscriptionEntry(ctx context.Context, entry UsageBillingEntry) error { + if s.userSubRepo == nil { + return errors.New("subscription repository unavailable") + } + if entry.SubscriptionID == nil { + return errors.New("subscription_id is nil for subscription billing") + } + + if err := s.runWithTx(ctx, func(txCtx context.Context) error { + if err := s.userSubRepo.IncrementUsage(txCtx, *entry.SubscriptionID, entry.DeltaUSD); err != nil { + return err + } + return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID) + }); err != nil { + return err + } + + if s.billingCache != nil { + sub, err := s.userSubRepo.GetByID(ctx, *entry.SubscriptionID) + if err == nil && sub != nil { + _ = s.billingCache.InvalidateSubscription(ctx, entry.UserID, sub.GroupID) + } + } + + return nil +} + +func (s *UsageBillingCompensationService) markApplied(ctx context.Context, entry UsageBillingEntry) { + store := s.store() + if store == nil { + return + } + if err := store.MarkUsageBillingEntryApplied(ctx, entry.ID); err != nil { + slog.Warn("usage_billing_compensation.mark_applied_failed", "entry_id", entry.ID, "error", err) + } +} + +func (s *UsageBillingCompensationService) markRetry(ctx context.Context, entry UsageBillingEntry, cause error) { + store := s.store() + if store == nil { + return + } + errMsg := strings.TrimSpace(cause.Error()) + if len(errMsg) > 500 { + errMsg = errMsg[:500] + } + backoff := usageBillingRetryBackoff(entry.AttemptCount) + nextRetryAt := time.Now().Add(backoff) + if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil { + slog.Warn("usage_billing_compensation.mark_retry_failed", + "entry_id", entry.ID, + "next_retry_at", 
nextRetryAt, + "error", err, + ) + return + } + slog.Warn("usage_billing_compensation.requeued", + "entry_id", entry.ID, + "attempt", entry.AttemptCount, + "next_retry_at", nextRetryAt, + "error", errMsg, + ) +} + +func (s *UsageBillingCompensationService) runWithTx(ctx context.Context, fn func(txCtx context.Context) error) error { + if runner, ok := s.usageLogRepo.(UsageBillingTxRunner); ok && runner != nil { + return runner.WithUsageBillingTx(ctx, fn) + } + return fn(ctx) +} + +func (s *UsageBillingCompensationService) store() UsageBillingEntryStore { + store, ok := s.usageLogRepo.(UsageBillingEntryStore) + if !ok { + return nil + } + return store +} diff --git a/backend/internal/service/usage_billing_compensation_service_test.go b/backend/internal/service/usage_billing_compensation_service_test.go new file mode 100644 index 000000000..f45c78efe --- /dev/null +++ b/backend/internal/service/usage_billing_compensation_service_test.go @@ -0,0 +1,230 @@ +package service + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type usageBillingCompRepoStub struct { + UsageLogRepository + + claimErr error + claims []UsageBillingEntry + + markAppliedCalls int + markRetryCalls int + lastRetryID int64 + lastRetryAt time.Time + lastRetryErr string + lastMarkAppliedCtx context.Context + lastTxCtx context.Context +} + +func (s *usageBillingCompRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) { + return nil, ErrUsageBillingEntryNotFound +} + +func (s *usageBillingCompRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) { + return entry, true, nil +} + +func (s *usageBillingCompRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error { + s.markAppliedCalls++ + s.lastMarkAppliedCtx = ctx + return nil +} + +func (s *usageBillingCompRepoStub) 
MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error { + s.markRetryCalls++ + s.lastRetryID = entryID + s.lastRetryAt = nextRetryAt + s.lastRetryErr = lastError + _ = ctx + return nil +} + +func (s *usageBillingCompRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) { + if s.claimErr != nil { + return nil, s.claimErr + } + out := make([]UsageBillingEntry, len(s.claims)) + copy(out, s.claims) + s.claims = nil + return out, nil +} + +func (s *usageBillingCompRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error { + s.lastTxCtx = ctx + if fn == nil { + return nil + } + return fn(ctx) +} + +type usageBillingCompUserRepoStub struct { + UserRepository + + deductCalls int + deductErr error + lastDeductCtx context.Context +} + +func (s *usageBillingCompUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error { + s.deductCalls++ + s.lastDeductCtx = ctx + return s.deductErr +} + +type usageBillingCompSubRepoStub struct { + UserSubscriptionRepository + + incrementCalls int + incrementErr error + lastIncrementCtx context.Context +} + +func (s *usageBillingCompSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + s.incrementCalls++ + s.lastIncrementCtx = ctx + return s.incrementErr +} + +type usageBillingCompCtxKey string + +func TestUsageBillingCompensationService_ProcessOnceBalanceSuccess(t *testing.T) { + repo := &usageBillingCompRepoStub{ + claims: []UsageBillingEntry{ + { + ID: 1, + UsageLogID: 1001, + UserID: 2001, + BillingType: BillingTypeBalance, + DeltaUSD: 1.23, + AttemptCount: 1, + }, + }, + } + userRepo := &usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + svc.processOnce() + + require.Equal(t, 1, 
userRepo.deductCalls) + require.Equal(t, 1, repo.markAppliedCalls) + require.Equal(t, 0, repo.markRetryCalls) +} + +func TestUsageBillingCompensationService_ProcessOnceBalanceFailureRequeues(t *testing.T) { + repo := &usageBillingCompRepoStub{ + claims: []UsageBillingEntry{ + { + ID: 2, + UsageLogID: 1002, + UserID: 2002, + BillingType: BillingTypeBalance, + DeltaUSD: 2.34, + AttemptCount: 2, + }, + }, + } + userRepo := &usageBillingCompUserRepoStub{deductErr: errors.New("db down")} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + svc.processOnce() + + require.Equal(t, 1, userRepo.deductCalls) + require.Equal(t, 0, repo.markAppliedCalls) + require.Equal(t, 1, repo.markRetryCalls) + require.Equal(t, int64(2), repo.lastRetryID) + require.NotZero(t, repo.lastRetryAt) + require.Contains(t, repo.lastRetryErr, "db down") +} + +func TestUsageBillingCompensationService_ProcessOnceSubscriptionSuccess(t *testing.T) { + subID := int64(4003) + repo := &usageBillingCompRepoStub{ + claims: []UsageBillingEntry{ + { + ID: 3, + UsageLogID: 1003, + UserID: 2003, + SubscriptionID: &subID, + BillingType: BillingTypeSubscription, + DeltaUSD: 3.45, + AttemptCount: 1, + }, + }, + } + userRepo := &usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + svc.processOnce() + + require.Equal(t, 1, subRepo.incrementCalls) + require.Equal(t, 1, repo.markAppliedCalls) + require.Equal(t, 0, repo.markRetryCalls) +} + +func TestUsageBillingCompensationService_ApplyBalanceEntryPropagatesContext(t *testing.T) { + repo := &usageBillingCompRepoStub{} + userRepo := &usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + entry := UsageBillingEntry{ + ID: 10, + UsageLogID: 1010, + UserID: 
2010, + BillingType: BillingTypeBalance, + DeltaUSD: 1.11, + } + key := usageBillingCompCtxKey("trace") + ctx := context.WithValue(context.Background(), key, "balance") + + err := svc.applyBalanceEntry(ctx, entry) + require.NoError(t, err) + require.Equal(t, 1, userRepo.deductCalls) + require.NotNil(t, repo.lastTxCtx) + require.NotNil(t, userRepo.lastDeductCtx) + require.NotNil(t, repo.lastMarkAppliedCtx) + require.Equal(t, "balance", repo.lastTxCtx.Value(key)) + require.Equal(t, "balance", userRepo.lastDeductCtx.Value(key)) + require.Equal(t, "balance", repo.lastMarkAppliedCtx.Value(key)) +} + +func TestUsageBillingCompensationService_ApplySubscriptionEntryPropagatesContext(t *testing.T) { + subID := int64(4010) + repo := &usageBillingCompRepoStub{} + userRepo := &usageBillingCompUserRepoStub{} + subRepo := &usageBillingCompSubRepoStub{} + svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{}) + + entry := UsageBillingEntry{ + ID: 11, + UsageLogID: 1011, + UserID: 2011, + SubscriptionID: &subID, + BillingType: BillingTypeSubscription, + DeltaUSD: 2.22, + } + key := usageBillingCompCtxKey("trace") + ctx := context.WithValue(context.Background(), key, "subscription") + + err := svc.applySubscriptionEntry(ctx, entry) + require.NoError(t, err) + require.Equal(t, 1, subRepo.incrementCalls) + require.NotNil(t, repo.lastTxCtx) + require.NotNil(t, subRepo.lastIncrementCtx) + require.NotNil(t, repo.lastMarkAppliedCtx) + require.Equal(t, "subscription", repo.lastTxCtx.Value(key)) + require.Equal(t, "subscription", subRepo.lastIncrementCtx.Value(key)) + require.Equal(t, "subscription", repo.lastMarkAppliedCtx.Value(key)) +} diff --git a/backend/internal/service/usage_billing_entry.go b/backend/internal/service/usage_billing_entry.go new file mode 100644 index 000000000..a24714082 --- /dev/null +++ b/backend/internal/service/usage_billing_entry.go @@ -0,0 +1,60 @@ +package service + +import ( + "context" + "errors" + "time" +) + +var 
ErrUsageBillingEntryNotFound = errors.New("usage billing entry not found") + +type UsageBillingEntryStatus int16 + +const ( + UsageBillingEntryStatusPending UsageBillingEntryStatus = 0 + UsageBillingEntryStatusProcessing UsageBillingEntryStatus = 1 + UsageBillingEntryStatusApplied UsageBillingEntryStatus = 2 +) + +type UsageBillingEntry struct { + ID int64 + UsageLogID int64 + UserID int64 + APIKeyID int64 + SubscriptionID *int64 + BillingType int8 + Applied bool + DeltaUSD float64 + Status UsageBillingEntryStatus + AttemptCount int + NextRetryAt time.Time + UpdatedAt time.Time + CreatedAt time.Time + LastError *string +} + +type UsageBillingEntryStore interface { + GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) + UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) + MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error + MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error + ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) +} + +type UsageBillingTxRunner interface { + WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error +} + +func usageBillingRetryBackoff(attempt int) time.Duration { + if attempt <= 1 { + return 30 * time.Second + } + backoff := 30 * time.Second + for i := 1; i < attempt && backoff < 30*time.Minute; i++ { + backoff *= 2 + } + if backoff > 30*time.Minute { + return 30 * time.Minute + } + return backoff +} diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 05fe50560..c8661c915 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -46,7 +46,7 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int return 0, nil } func (m 
*mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil } -func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil } func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil } func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil } diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index c71851901..2ada41adf 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -293,13 +293,6 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC return apiKeyService } -// ProvideSettingService wires SettingService with group reader for default subscription validation. -func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService { - svc := NewSettingService(settingRepo, cfg) - svc.SetDefaultSubscriptionGroupReader(groupRepo) - return svc -} - // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -342,7 +335,7 @@ var ProviderSet = wire.NewSet( ProvideRateLimitService, NewAccountUsageService, NewAccountTestService, - ProvideSettingService, + NewSettingService, NewDataManagementService, ProvideOpsSystemLogSink, NewOpsService, @@ -355,7 +348,6 @@ var ProviderSet = wire.NewSet( ProvideEmailQueueService, NewTurnstileService, NewSubscriptionService, - wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)), ProvideConcurrencyService, ProvideUserMessageQueueService, NewUsageRecordWorkerPool, diff --git a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql index d0ed5d6dd..93af0da7f 100644 --- a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql +++ 
b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql @@ -1,46 +1,37 @@ --- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping +-- Add gemini-3.1-flash-image mapping keys without wiping existing custom mappings. -- -- Background: --- Antigravity now supports gemini-3.1-flash-image as the latest image generation model, --- replacing the previous gemini-3-pro-image. +-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model. +-- Existing accounts may still contain gemini-3-pro-image aliases. -- -- Strategy: --- Directly overwrite the entire model_mapping with updated mappings --- This ensures consistency with DefaultAntigravityModelMapping in constants.go +-- Incrementally upsert only image-related keys in credentials.model_mapping: +-- 1) add canonical 3.1 image keys +-- 2) keep legacy 3-pro-image keys but remap them to 3.1 image for compatibility +-- This preserves user custom mappings and avoids full mapping overwrite. 
UPDATE accounts SET credentials = jsonb_set( - credentials, - '{model_mapping}', - '{ - "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", - "claude-opus-4-6": "claude-opus-4-6-thinking", - "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", - "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", - "claude-sonnet-4-6": "claude-sonnet-4-6", - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", - "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", - "claude-haiku-4-5": "claude-sonnet-4-5", - "claude-haiku-4-5-20251001": "claude-sonnet-4-5", - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-2.5-pro": "gemini-2.5-pro", - "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-low": "gemini-3-pro-low", - "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3-pro-high", - "gemini-3.1-pro-high": "gemini-3.1-pro-high", - "gemini-3.1-pro-low": "gemini-3.1-pro-low", - "gemini-3.1-pro-preview": "gemini-3.1-pro-high", - "gemini-3.1-flash-image": "gemini-3.1-flash-image", - "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", - "gpt-oss-120b-medium": "gpt-oss-120b-medium", - "tab_flash_lite_preview": "tab_flash_lite_preview" - }'::jsonb + jsonb_set( + jsonb_set( + jsonb_set( + credentials, + '{model_mapping,gemini-3.1-flash-image}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3.1-flash-image-preview}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3-pro-image}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3-pro-image-preview}', + '"gemini-3.1-flash-image"'::jsonb, + true ) WHERE platform = 'antigravity' AND deleted_at IS NULL - AND credentials->'model_mapping' IS NOT NULL; \ No newline at end of file + AND credentials->'model_mapping' IS NOT NULL; 
diff --git a/backend/migrations/064_add_billing_usage_entry_retry_fields.sql b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql new file mode 100644 index 000000000..aebb39295 --- /dev/null +++ b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql @@ -0,0 +1,27 @@ +-- 064_add_billing_usage_entry_retry_fields.sql +-- Add retry-state columns for billing_usage_entries compensation worker. + +ALTER TABLE billing_usage_entries + ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS attempt_count INTEGER NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS next_retry_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + ADD COLUMN IF NOT EXISTS last_error TEXT; + +-- Keep legacy rows aligned with applied flag. +UPDATE billing_usage_entries +SET status = CASE WHEN applied THEN 2 ELSE 0 END +WHERE status NOT IN (0, 1, 2) + OR (applied = TRUE AND status <> 2) + OR (applied = FALSE AND status = 2); + +ALTER TABLE billing_usage_entries + DROP CONSTRAINT IF EXISTS chk_billing_usage_entries_status; + +ALTER TABLE billing_usage_entries + ADD CONSTRAINT chk_billing_usage_entries_status + CHECK (status IN (0, 1, 2)); + +CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_retry + ON billing_usage_entries (status, next_retry_at, updated_at) + WHERE applied = FALSE; diff --git a/deploy/Caddyfile b/deploy/Caddyfile index b643fe9b8..cbd762b1c 100644 --- a/deploy/Caddyfile +++ b/deploy/Caddyfile @@ -30,6 +30,36 @@ api.sub2api.com { # ========================================================================= # 反向代理配置 # ========================================================================= + # OpenAI Responses(含 WebSocket/SSE)单独代理策略: + # 1) flush_interval -1:尽快转发流式分片,降低中间层缓冲导致的断流概率 + # 2) versions 1.1:确保上游走标准 HTTP/1.1 Upgrade,避免协议协商差异 + # 3) stream_timeout/stream_close_delay:为长连接提供更宽松生命周期 + @openai_responses { + path /openai/v1/responses* + } + reverse_proxy 
@openai_responses localhost:8080 { + # 长连接/流式场景建议关闭代理缓冲 + flush_interval -1 + # 长连接超时窗口(避免长会话被代理层过早回收) + stream_timeout 24h + # 配置热重载时,给现有流预留关闭缓冲期 + stream_close_delay 5m + + # 传递真实客户端信息 + header_up X-Real-IP {remote_host} + header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP} + + transport http { + # WebSocket Upgrade 对上游统一使用 HTTP/1.1,更稳妥 + versions 1.1 + keepalive 120s + keepalive_idle_conns 256 + read_buffer 32KB + write_buffer 32KB + compression off + } + } + reverse_proxy localhost:8080 { # 健康检查 health_uri /health @@ -45,9 +75,6 @@ api.sub2api.com { # 传递真实客户端信息 # 兼容 Cloudflare 和直连:后端应优先读取 CF-Connecting-IP,其次 X-Real-IP header_up X-Real-IP {remote_host} - header_up X-Forwarded-For {remote_host} - header_up X-Forwarded-Proto {scheme} - header_up X-Forwarded-Host {host} # 保留 Cloudflare 原始头(如果存在) # 后端获取 IP 的优先级建议: CF-Connecting-IP → X-Real-IP → X-Forwarded-For header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP} diff --git a/deploy/Dockerfile b/deploy/Dockerfile index b33203009..c9fcf3017 100644 --- a/deploy/Dockerfile +++ b/deploy/Dockerfile @@ -7,7 +7,7 @@ # ============================================================================= ARG NODE_IMAGE=node:24-alpine -ARG GOLANG_IMAGE=golang:1.25.5-alpine +ARG GOLANG_IMAGE=golang:1.25.7-alpine ARG ALPINE_IMAGE=alpine:3.20 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index e2eb3130c..da1a54a4c 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -134,12 +134,6 @@ security: # Allow skipping TLS verification for proxy probe (debug only) # 允许代理探测时跳过 TLS 证书验证(仅用于调试) insecure_skip_verify: false - proxy_fallback: - # Allow auxiliary services (update check, pricing data) to fallback to direct - # connection when proxy initialization fails. Does NOT affect AI gateway connections. 
- # 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。 - # 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。 - allow_direct_on_error: false # ============================================================================= # Gateway Configuration @@ -207,10 +201,12 @@ gateway: openai_passthrough_allow_timeout_headers: false # OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) openai_ws: - # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。 - mode_router_v2_enabled: false - # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效) - ingress_mode_default: shared + # 新版 WS mode 路由(默认开启)。关闭时保持当前 legacy 实现行为。 + mode_router_v2_enabled: true + # ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效) + # 建议:常规场景选 ctx_pool;需要“原样透传 + 不走连接池”时选 passthrough;紧急回滚选 off。 + # Recommendation: use ctx_pool for normal traffic, passthrough for raw relay without pool reuse, off for emergency rollback. + ingress_mode_default: ctx_pool # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由 enabled: true # 按账号类型细分开关 @@ -233,7 +229,7 @@ gateway: # 协议 feature 开关,v2 优先于 v1 responses_websockets: false responses_websockets_v2: true - # 连接池参数(按账号池化复用) + # 连接池参数(按账号池化复用,仅 ctx_pool 模式生效;passthrough 模式不使用连接池) max_conns_per_account: 128 min_idle_per_account: 4 max_idle_per_account: 12 @@ -248,6 +244,10 @@ gateway: write_timeout_seconds: 120 pool_target_utilization: 0.7 queue_limit_per_conn: 64 + # 上游 WebSocket 连接最大存活时间(秒)。 + # OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。 + # 默认 3300(55 分钟);设为 0 则禁用超龄轮换。 + upstream_conn_max_age_seconds: 3300 # 流式写出批量 flush 参数 event_flush_batch_size: 1 event_flush_interval_ms: 10 @@ -265,7 +265,7 @@ gateway: # payload_schema 日志采样率(0-1);降低热路径日志放大 payload_log_sample_rate: 0.2 # 调度与粘连参数 - lb_top_k: 7 + lb_top_k: 999 sticky_session_ttl_seconds: 3600 # 会话哈希迁移兼容开关:新 key 未命中时回退读取旧 SHA-256 key session_hash_read_old_fallback: true @@ -282,6 +282,24 @@ gateway: queue: 0.7 error_rate: 0.8 ttft: 0.5 + # OpenAI HTTP upstream protocol strategy + # OpenAI HTTP 上游协议策略 + openai_http2: + # Enable OpenAI HTTP/2 
preference (default on) + # 启用 OpenAI HTTP/2 优先策略(默认开启) + enabled: true + # Allow fallback to HTTP/1.1 for incompatible HTTP proxies + # 当 HTTP 代理不兼容时允许回退到 HTTP/1.1 + allow_proxy_fallback_to_http1: true + # Fallback triggers after N HTTP/2 compatibility errors within window + # 在窗口期内累计 N 次 HTTP/2 兼容错误后触发回退 + fallback_error_threshold: 2 + # Error counting window (seconds) + # 错误计数窗口(秒) + fallback_window_seconds: 60 + # How long to stay in HTTP/1.1 fallback mode (seconds) + # 进入 HTTP/1.1 回退态后的持续时间(秒) + fallback_ttl_seconds: 600 # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts @@ -316,6 +334,15 @@ gateway: # SSE max line size in bytes (default: 40MB) # SSE 单行最大字节数(默认 40MB) max_line_size: 41943040 + # Usage record worker pool (bounded queue) + # 使用量记录异步池(有界队列) + usage_record: + # queue overflow policy: drop/sample/sync + # 队列溢出策略:drop/sample/sync + overflow_policy: sync + # only used when overflow_policy=sample + # 仅在 overflow_policy=sample 时生效 + overflow_sample_percent: 10 # Log upstream error response body summary (safe/truncated; does not log request content) # 记录上游错误响应体摘要(安全/截断;不记录请求内容) log_upstream_error_body: true diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md deleted file mode 100644 index 4cc215948..000000000 --- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md +++ /dev/null @@ -1,241 +0,0 @@ -# ADMIN_PAYMENT_INTEGRATION_API - -> 单文件中英双语文档 / Single-file bilingual documentation (Chinese + English) - ---- - -## 中文 - -### 目标 -本文档用于对接外部支付系统(如 `sub2apipay`)与 Sub2API 的 Admin API,覆盖: -- 支付成功后充值 -- 用户查询 -- 人工余额修正 -- 前端购买页参数透传 - -### 基础地址 -- 生产:`https://` -- Beta:`http://:8084` - -### 认证 -推荐使用: -- `x-api-key: admin-<64hex>` -- `Content-Type: application/json` -- 幂等接口额外传:`Idempotency-Key` - -说明:管理员 JWT 也可访问 admin 路由,但服务间调用建议使用 Admin API Key。 - -### 1) 一步完成创建并兑换 -`POST /api/v1/admin/redeem-codes/create-and-redeem` - -用途:原子完成“创建兑换码 + 
兑换到指定用户”。 - -请求头: -- `x-api-key` -- `Idempotency-Key` - -请求体示例: -```json -{ - "code": "s2p_cm1234567890", - "type": "balance", - "value": 100.0, - "user_id": 123, - "notes": "sub2apipay order: cm1234567890" -} -``` - -幂等语义: -- 同 `code` 且 `used_by` 一致:`200` -- 同 `code` 但 `used_by` 不一致:`409` -- 缺少 `Idempotency-Key`:`400`(`IDEMPOTENCY_KEY_REQUIRED`) - -curl 示例: -```bash -curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: pay-cm1234567890-success" \ - -H "Content-Type: application/json" \ - -d '{ - "code":"s2p_cm1234567890", - "type":"balance", - "value":100.00, - "user_id":123, - "notes":"sub2apipay order: cm1234567890" - }' -``` - -### 2) 查询用户(可选前置校验) -`GET /api/v1/admin/users/:id` - -```bash -curl -s "${BASE}/api/v1/admin/users/123" \ - -H "x-api-key: ${KEY}" -``` - -### 3) 余额调整(已有接口) -`POST /api/v1/admin/users/:id/balance` - -用途:人工补偿 / 扣减,支持 `set` / `add` / `subtract`。 - -请求体示例(扣减): -```json -{ - "balance": 100.0, - "operation": "subtract", - "notes": "manual correction" -} -``` - -```bash -curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: balance-subtract-cm1234567890" \ - -H "Content-Type: application/json" \ - -d '{ - "balance":100.00, - "operation":"subtract", - "notes":"manual correction" - }' -``` - -### 4) 购买页 URL Query 透传(iframe / 新窗口一致) -当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加: -- `user_id` -- `token` -- `theme`(`light` / `dark`) -- `ui_mode`(固定 `embedded`) - -示例: -```text -https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded -``` - -### 5) 失败处理建议 -- 支付成功与充值成功分状态落库 -- 回调验签成功后立即标记“支付成功” -- 支付成功但充值失败的订单允许后续重试 -- 重试保持相同 `code`,并使用新的 `Idempotency-Key` - -### 6) `doc_url` 配置建议 -- 查看链接:`https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md` -- 下载链接:`https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md` - ---- - -## English - -### Purpose -This 
document describes the minimal Sub2API Admin API surface for external payment integrations (for example, `sub2apipay`), including: -- Recharge after payment success -- User lookup -- Manual balance correction -- Purchase page query parameter forwarding - -### Base URL -- Production: `https://` -- Beta: `http://:8084` - -### Authentication -Recommended headers: -- `x-api-key: admin-<64hex>` -- `Content-Type: application/json` -- `Idempotency-Key` for idempotent endpoints - -Note: Admin JWT can also access admin routes, but Admin API Key is recommended for server-to-server integration. - -### 1) Create and Redeem in one step -`POST /api/v1/admin/redeem-codes/create-and-redeem` - -Use case: atomically create a redeem code and redeem it to a target user. - -Headers: -- `x-api-key` -- `Idempotency-Key` - -Request body: -```json -{ - "code": "s2p_cm1234567890", - "type": "balance", - "value": 100.0, - "user_id": 123, - "notes": "sub2apipay order: cm1234567890" -} -``` - -Idempotency behavior: -- Same `code` and same `used_by`: `200` -- Same `code` but different `used_by`: `409` -- Missing `Idempotency-Key`: `400` (`IDEMPOTENCY_KEY_REQUIRED`) - -curl example: -```bash -curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: pay-cm1234567890-success" \ - -H "Content-Type: application/json" \ - -d '{ - "code":"s2p_cm1234567890", - "type":"balance", - "value":100.00, - "user_id":123, - "notes":"sub2apipay order: cm1234567890" - }' -``` - -### 2) Query User (optional pre-check) -`GET /api/v1/admin/users/:id` - -```bash -curl -s "${BASE}/api/v1/admin/users/123" \ - -H "x-api-key: ${KEY}" -``` - -### 3) Balance Adjustment (existing API) -`POST /api/v1/admin/users/:id/balance` - -Use case: manual correction with `set` / `add` / `subtract`. 
- -Request body example (`subtract`): -```json -{ - "balance": 100.0, - "operation": "subtract", - "notes": "manual correction" -} -``` - -```bash -curl -X POST "${BASE}/api/v1/admin/users/123/balance" \ - -H "x-api-key: ${KEY}" \ - -H "Idempotency-Key: balance-subtract-cm1234567890" \ - -H "Content-Type: application/json" \ - -d '{ - "balance":100.00, - "operation":"subtract", - "notes":"manual correction" - }' -``` - -### 4) Purchase URL query forwarding (iframe and new tab) -When Sub2API opens `purchase_subscription_url`, it appends: -- `user_id` -- `token` -- `theme` (`light` / `dark`) -- `ui_mode` (fixed: `embedded`) - -Example: -```text -https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded -``` - -### 5) Failure handling recommendations -- Persist payment success and recharge success as separate states -- Mark payment as successful immediately after verified callback -- Allow retry for orders with payment success but recharge failure -- Keep the same `code` for retry, and use a new `Idempotency-Key` - -### 6) Recommended `doc_url` -- View URL: `https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md` -- Download URL: `https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md` diff --git a/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts new file mode 100644 index 000000000..5baccfe8c --- /dev/null +++ b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts @@ -0,0 +1,184 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { + deleteBulkEditTemplate, + getBulkEditTemplates, + getBulkEditTemplateVersions, + rollbackBulkEditTemplate, + upsertBulkEditTemplate +} from '../admin/bulkEditTemplates' +import { apiClient } from '../client' + +vi.mock('../client', () => ({ + apiClient: { + get: vi.fn(), + post: vi.fn(), + delete: vi.fn() + } +})) + +describe('admin settings bulk-edit 
templates api', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + it('requests template list with expected query params', async () => { + (apiClient.get as any).mockResolvedValue({ + data: { + items: [ + { + id: 'tpl-1', + name: 'Template', + scope_platform: 'openai', + scope_type: 'oauth', + share_scope: 'team', + group_ids: [], + state: {}, + created_by: 1, + updated_by: 1, + created_at: 1, + updated_at: 2 + } + ] + } + }) + + const items = await getBulkEditTemplates({ + scope_platform: 'openai', + scope_type: 'oauth', + scope_group_ids: [3, 9] + }) + + expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', { + params: { + scope_platform: 'openai', + scope_type: 'oauth', + scope_group_ids: '3,9' + } + }) + expect(items).toHaveLength(1) + expect(items[0].id).toBe('tpl-1') + }) + + it('returns empty list when response items is invalid', async () => { + (apiClient.get as any).mockResolvedValue({ data: { items: null } }) + const items = await getBulkEditTemplates({}) + expect(items).toEqual([]) + }) + + it('posts upsert payload and returns saved template', async () => { + (apiClient.post as any).mockResolvedValue({ + data: { + id: 'tpl-2', + name: 'Shared', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'groups', + group_ids: [1], + state: { enableProxy: true }, + created_by: 2, + updated_by: 2, + created_at: 10, + updated_at: 11 + } + }) + + const saved = await upsertBulkEditTemplate({ + id: 'tpl-2', + name: 'Shared', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'groups', + group_ids: [1], + state: { enableProxy: true } + }) + + expect(apiClient.post).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', { + id: 'tpl-2', + name: 'Shared', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'groups', + group_ids: [1], + state: { enableProxy: true } + }) + expect(saved.id).toBe('tpl-2') + }) + + it('requests template versions with scope group params', async () => { + 
(apiClient.get as any).mockResolvedValue({ + data: { + items: [ + { + version_id: 'ver-1', + share_scope: 'team', + group_ids: [], + state: { enableOpenAIWSMode: true }, + updated_by: 11, + updated_at: 100 + } + ] + } + }) + + const items = await getBulkEditTemplateVersions('tpl-2', { scope_group_ids: [5, 8] }) + + expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-2/versions', { + params: { + scope_group_ids: '5,8' + } + }) + expect(items).toHaveLength(1) + expect(items[0].version_id).toBe('ver-1') + }) + + it('returns empty versions list when payload is invalid', async () => { + (apiClient.get as any).mockResolvedValue({ data: { items: undefined } }) + + const items = await getBulkEditTemplateVersions('tpl-any') + + expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-any/versions', { + params: {} + }) + expect(items).toEqual([]) + }) + + it('posts rollback request with optional query params', async () => { + (apiClient.post as any).mockResolvedValue({ + data: { + id: 'tpl-3', + name: 'Rollbacked', + scope_platform: 'openai', + scope_type: 'oauth', + share_scope: 'private', + group_ids: [], + state: { enableOpenAIPassthrough: false }, + created_by: 1, + updated_by: 2, + created_at: 10, + updated_at: 12 + } + }) + + const saved = await rollbackBulkEditTemplate( + 'tpl-3', + { version_id: 'ver-2' }, + { scope_group_ids: [2] } + ) + + expect(apiClient.post).toHaveBeenCalledWith( + '/admin/settings/bulk-edit-templates/tpl-3/rollback', + { version_id: 'ver-2' }, + { params: { scope_group_ids: '2' } } + ) + expect(saved.id).toBe('tpl-3') + }) + + it('calls delete endpoint for template removal', async () => { + (apiClient.delete as any).mockResolvedValue({ data: { deleted: true } }) + + const result = await deleteBulkEditTemplate('tpl-9') + + expect(apiClient.delete).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-9') + expect(result).toEqual({ deleted: true }) + }) +}) diff --git 
a/frontend/src/api/admin/bulkEditTemplates.ts b/frontend/src/api/admin/bulkEditTemplates.ts new file mode 100644 index 000000000..45c5e8ac1 --- /dev/null +++ b/frontend/src/api/admin/bulkEditTemplates.ts @@ -0,0 +1,129 @@ +import { apiClient } from '../client' +import type { AccountPlatform, AccountType } from '@/types' + +export type BulkEditTemplateShareScope = 'private' | 'team' | 'groups' + +export interface BulkEditTemplateRecord> { + id: string + name: string + scope_platform: AccountPlatform | '' + scope_type: AccountType | '' + share_scope: BulkEditTemplateShareScope + group_ids: number[] + state: TState + created_by: number + updated_by: number + created_at: number + updated_at: number +} + +export interface BulkEditTemplateVersionRecord> { + version_id: string + share_scope: BulkEditTemplateShareScope + group_ids: number[] + state: TState + updated_by: number + updated_at: number +} + +export interface GetBulkEditTemplatesParams { + scope_platform?: AccountPlatform | '' + scope_type?: AccountType | '' + scope_group_ids?: number[] +} + +export interface GetBulkEditTemplateVersionsParams { + scope_group_ids?: number[] +} + +export interface UpsertBulkEditTemplateRequest> { + id?: string + name: string + scope_platform: AccountPlatform | '' + scope_type: AccountType | '' + share_scope: BulkEditTemplateShareScope + group_ids: number[] + state: TState +} + +export interface RollbackBulkEditTemplateRequest { + version_id: string +} + +export async function getBulkEditTemplates>( + params: GetBulkEditTemplatesParams +): Promise[]> { + const query: Record = {} + if (params.scope_platform) query.scope_platform = params.scope_platform + if (params.scope_type) query.scope_type = params.scope_type + if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) { + query.scope_group_ids = params.scope_group_ids.join(',') + } + + const { data } = await apiClient.get<{ items: BulkEditTemplateRecord[] }>( + '/admin/settings/bulk-edit-templates', + { 
params: query } + ) + return Array.isArray(data.items) ? data.items : [] +} + +export async function getBulkEditTemplateVersions>( + templateID: string, + params: GetBulkEditTemplateVersionsParams = {} +): Promise[]> { + const query: Record = {} + if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) { + query.scope_group_ids = params.scope_group_ids.join(',') + } + + const { data } = await apiClient.get<{ items: BulkEditTemplateVersionRecord[] }>( + `/admin/settings/bulk-edit-templates/${templateID}/versions`, + { params: query } + ) + return Array.isArray(data.items) ? data.items : [] +} + +export async function upsertBulkEditTemplate>( + request: UpsertBulkEditTemplateRequest +): Promise> { + const { data } = await apiClient.post>( + '/admin/settings/bulk-edit-templates', + request + ) + return data +} + +export async function rollbackBulkEditTemplate>( + templateID: string, + request: RollbackBulkEditTemplateRequest, + params: GetBulkEditTemplateVersionsParams = {} +): Promise> { + const query: Record = {} + if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) { + query.scope_group_ids = params.scope_group_ids.join(',') + } + + const { data } = await apiClient.post>( + `/admin/settings/bulk-edit-templates/${templateID}/rollback`, + request, + { params: query } + ) + return data +} + +export async function deleteBulkEditTemplate(templateID: string): Promise<{ deleted: boolean }> { + const { data } = await apiClient.delete<{ deleted: boolean }>( + `/admin/settings/bulk-edit-templates/${templateID}` + ) + return data +} + +const bulkEditTemplatesAPI = { + getBulkEditTemplates, + getBulkEditTemplateVersions, + upsertBulkEditTemplate, + rollbackBulkEditTemplate, + deleteBulkEditTemplate +} + +export default bulkEditTemplatesAPI diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 54bd92a49..a5113dd1f 100644 --- a/frontend/src/api/admin/dashboard.ts +++ 
b/frontend/src/api/admin/dashboard.ts @@ -8,7 +8,6 @@ import type { DashboardStats, TrendDataPoint, ModelStat, - GroupStat, ApiKeyUsageTrendPoint, UserUsageTrendPoint, UsageRequestType @@ -102,34 +101,6 @@ export async function getModelStats(params?: ModelStatsParams): Promise { - const { data } = await apiClient.get('/admin/dashboard/groups', { params }) - return data -} - export interface ApiKeyTrendParams extends TrendParams { limit?: number } @@ -232,7 +203,6 @@ export const dashboardAPI = { getRealtimeMetrics, getUsageTrend, getModelStats, - getGroupStats, getApiKeyUsageTrend, getUserUsageTrend, getBatchUsersUsage, diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 5db998e57..a2c82ecbf 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -12,6 +12,7 @@ import redeemAPI from './redeem' import promoAPI from './promo' import announcementsAPI from './announcements' import settingsAPI from './settings' +import bulkEditTemplatesAPI from './bulkEditTemplates' import systemAPI from './system' import subscriptionsAPI from './subscriptions' import usageAPI from './usage' @@ -21,7 +22,6 @@ import userAttributesAPI from './userAttributes' import opsAPI from './ops' import errorPassthroughAPI from './errorPassthrough' import dataManagementAPI from './dataManagement' -import apiKeysAPI from './apiKeys' /** * Unified admin API object for convenient access @@ -36,6 +36,7 @@ export const adminAPI = { promo: promoAPI, announcements: announcementsAPI, settings: settingsAPI, + bulkEditTemplates: bulkEditTemplatesAPI, system: systemAPI, subscriptions: subscriptionsAPI, usage: usageAPI, @@ -44,8 +45,7 @@ export const adminAPI = { userAttributes: userAttributesAPI, ops: opsAPI, errorPassthrough: errorPassthroughAPI, - dataManagement: dataManagementAPI, - apiKeys: apiKeysAPI + dataManagement: dataManagementAPI } export { @@ -58,6 +58,7 @@ export { promoAPI, announcementsAPI, settingsAPI, + bulkEditTemplatesAPI, 
systemAPI, subscriptionsAPI, usageAPI, @@ -66,8 +67,7 @@ export { userAttributesAPI, opsAPI, errorPassthroughAPI, - dataManagementAPI, - apiKeysAPI + dataManagementAPI } export default adminAPI diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 52855a040..a3ed6c6ff 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -6,11 +6,6 @@ import { apiClient } from '../client' import type { CustomMenuItem } from '@/types' -export interface DefaultSubscriptionSetting { - group_id: number - validity_days: number -} - /** * System settings interface */ @@ -26,7 +21,6 @@ export interface SystemSettings { // Default settings default_balance: number default_concurrency: number - default_subscriptions: DefaultSubscriptionSetting[] // OEM settings site_name: string site_logo: string @@ -75,9 +69,6 @@ export interface SystemSettings { ops_realtime_monitoring_enabled: boolean ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string ops_metrics_interval_seconds: number - - // Claude Code version check - min_claude_code_version: string } export interface UpdateSettingsRequest { @@ -89,7 +80,6 @@ export interface UpdateSettingsRequest { totp_enabled?: boolean // TOTP 双因素认证 default_balance?: number default_concurrency?: number - default_subscriptions?: DefaultSubscriptionSetting[] site_name?: string site_logo?: string site_subtitle?: string @@ -127,7 +117,6 @@ export interface UpdateSettingsRequest { ops_realtime_monitoring_enabled?: boolean ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string ops_metrics_interval_seconds?: number - min_claude_code_version?: string } /** diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index d36a2a5a9..287aef967 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types' 
+import type { AdminUser, UpdateUserRequest, PaginatedResponse } from '@/types' /** * List all users with pagination @@ -145,8 +145,8 @@ export async function toggleStatus(id: number, status: 'active' | 'disabled'): P * @param id - User ID * @returns List of user's API keys */ -export async function getUserApiKeys(id: number): Promise> { - const { data } = await apiClient.get>(`/admin/users/${id}/api-keys`) +export async function getUserApiKeys(id: number): Promise> { + const { data } = await apiClient.get>(`/admin/users/${id}/api-keys`) return data } diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 95f9ff318..22db5a44a 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -267,7 +267,6 @@ apiClient.interceptors.response.use( return Promise.reject({ status, code: apiData.code, - error: apiData.error, message: apiData.message || apiData.detail || error.message }) } diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue index 2a4babf20..ae338aca1 100644 --- a/frontend/src/components/account/AccountCapacityCell.vue +++ b/frontend/src/components/account/AccountCapacityCell.vue @@ -52,25 +52,6 @@ {{ account.max_sessions }} - - -
- - - - - {{ currentRPM }} - / - {{ account.base_rpm }} - {{ rpmStrategyTag }} - -
@@ -144,15 +125,19 @@ const windowCostClass = computed(() => { const limit = props.account.window_cost_limit || 0 const reserve = props.account.window_cost_sticky_reserve || 10 + // >= 阈值+预留: 完全不可调度 (红色) if (current >= limit + reserve) { return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400' } + // >= 阈值: 仅粘性会话 (橙色) if (current >= limit) { return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' } + // >= 80% 阈值: 警告 (黄色) if (current >= limit * 0.8) { return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400' } + // 正常 (绿色) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' }) @@ -180,12 +165,15 @@ const sessionLimitClass = computed(() => { const current = activeSessions.value const max = props.account.max_sessions || 0 + // >= 最大: 完全占满 (红色) if (current >= max) { return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400' } + // >= 80%: 警告 (黄色) if (current >= max * 0.8) { return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400' } + // 正常 (绿色) return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' }) @@ -203,89 +191,6 @@ const sessionLimitTooltip = computed(() => { return t('admin.accounts.capacity.sessions.normal', { idle }) }) -// 是否显示 RPM 限制 -const showRpmLimit = computed(() => { - return ( - isAnthropicOAuthOrSetupToken.value && - props.account.base_rpm !== undefined && - props.account.base_rpm !== null && - props.account.base_rpm > 0 - ) -}) - -// 当前 RPM 计数 -const currentRPM = computed(() => props.account.current_rpm ?? 0) - -// RPM 策略 -const rpmStrategy = computed(() => props.account.rpm_strategy || 'tiered') - -// RPM 策略标签 -const rpmStrategyTag = computed(() => { - return rpmStrategy.value === 'sticky_exempt' ? '[S]' : '[T]' -}) - -// RPM buffer 计算(与后端一致:base <= 0 时 buffer 为 0) -const rpmBuffer = computed(() => { - const base = props.account.base_rpm || 0 - return props.account.rpm_sticky_buffer ?? 
(base > 0 ? Math.max(1, Math.floor(base / 5)) : 0) -}) - -// RPM 状态样式 -const rpmClass = computed(() => { - if (!showRpmLimit.value) return '' - - const current = currentRPM.value - const base = props.account.base_rpm ?? 0 - const buffer = rpmBuffer.value - - if (rpmStrategy.value === 'tiered') { - if (current >= base + buffer) { - return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400' - } - if (current >= base) { - return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' - } - } else { - if (current >= base) { - return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400' - } - } - if (current >= base * 0.8) { - return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400' - } - return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400' -}) - -// RPM 提示文字(增强版:显示策略、区域、缓冲区) -const rpmTooltip = computed(() => { - if (!showRpmLimit.value) return '' - - const current = currentRPM.value - const base = props.account.base_rpm ?? 
0 - const buffer = rpmBuffer.value - - if (rpmStrategy.value === 'tiered') { - if (current >= base + buffer) { - return t('admin.accounts.capacity.rpm.tieredBlocked', { buffer }) - } - if (current >= base) { - return t('admin.accounts.capacity.rpm.tieredStickyOnly', { buffer }) - } - if (current >= base * 0.8) { - return t('admin.accounts.capacity.rpm.tieredWarning') - } - return t('admin.accounts.capacity.rpm.tieredNormal') - } else { - if (current >= base) { - return t('admin.accounts.capacity.rpm.stickyExemptOver') - } - if (current >= base * 0.8) { - return t('admin.accounts.capacity.rpm.stickyExemptWarning') - } - return t('admin.accounts.capacity.rpm.stickyExemptNormal') - } -}) - // 格式化费用显示 const formatCost = (value: number | null | undefined) => { if (value === null || value === undefined) return '0' diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index 1c83e6583..163bb391b 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -1,7 +1,7 @@ - - diff --git a/frontend/src/components/account/BulkEditAccountScopedModal.vue b/frontend/src/components/account/BulkEditAccountScopedModal.vue new file mode 100644 index 000000000..f67cc06af --- /dev/null +++ b/frontend/src/components/account/BulkEditAccountScopedModal.vue @@ -0,0 +1,84 @@ + + + diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 75f040815..ffccda057 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1536,119 +1536,6 @@ - -
-
-
- -

- {{ t('admin.accounts.quotaControl.rpmLimit.hint') }} -

-
- -
- -
-
- - -

{{ t('admin.accounts.quotaControl.rpmLimit.baseRpmHint') }}

-
- -
- -
- - -
-
- -
- - -

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

-
- -
- - -
- -

- {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} -

-
- -
-
-
-
@@ -1807,7 +1694,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -2341,10 +2228,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_DEDICATED, + OPENAI_WS_MODE_PASSTHROUGH, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, isOpenAIWSModeEnabled, + resolveOpenAIWSModeConcurrencyHintKey, type OpenAIWSMode } from '@/utils/openaiWsMode' import OAuthAuthorizationFlow from './OAuthAuthorizationFlow.vue' @@ -2506,16 +2394,6 @@ const windowCostStickyReserve = ref(null) const sessionLimitEnabled = ref(false) const maxSessions = ref(null) const sessionIdleTimeout = ref(null) -const rpmLimitEnabled = ref(false) -const baseRpm = ref(null) -const rpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') -const rpmStickyBuffer = ref(null) -const userMsgQueueMode = ref('') -const umqModeOptions = computed(() => [ - { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, - { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, - { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, -]) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) @@ -2541,8 +2419,8 @@ const geminiSelectedTier = computed(() => { const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ @@ -2561,6 +2439,10 @@ const 
openaiResponsesWebSocketV2Mode = computed({ } }) +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) + const isOpenAIModelRestrictionDisabled = computed(() => form.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -3140,11 +3022,6 @@ const resetForm = () => { sessionLimitEnabled.value = false maxSessions.value = null sessionIdleTimeout.value = null - rpmLimitEnabled.value = false - baseRpm.value = null - rpmStrategy.value = 'tiered' - rpmStickyBuffer.value = null - userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false @@ -3180,10 +3057,14 @@ const buildOpenAIExtra = (base?: Record): Record = { ...(base || {}) } - extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + // 按账号类型分流写入,避免 oauth 账号写入 apikey 的字段(反之亦然) + if (accountCategory.value === 'oauth-based') { + extra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (accountCategory.value === 'apikey') { + extra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } // 清理兼容旧键,统一改用分类型开关。 delete extra.responses_websockets_v2_enabled delete extra.openai_ws_enabled @@ -4054,20 +3935,6 @@ const handleAnthropicExchange = async (authCode: string) => { 
extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 } - // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value - extra.rpm_strategy = rpmStrategy.value - if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { - extra.rpm_sticky_buffer = rpmStickyBuffer.value - } - } - - // UMQ mode(独立于 RPM) - if (userMsgQueueMode.value) { - extra.user_msg_queue_mode = userMsgQueueMode.value - } - // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true @@ -4166,20 +4033,6 @@ const handleCookieAuth = async (sessionKey: string) => { extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 } - // Add RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - extra.base_rpm = baseRpm.value - extra.rpm_strategy = rpmStrategy.value - if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { - extra.rpm_sticky_buffer = rpmStickyBuffer.value - } - } - - // UMQ mode(独立于 RPM) - if (userMsgQueueMode.value) { - extra.user_msg_queue_mode = userMsgQueueMode.value - } - // Add TLS fingerprint settings if (tlsFingerprintEnabled.value) { extra.enable_tls_fingerprint = true diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 24166a5ca..d3524b644 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -708,7 +708,7 @@
- +

- {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} + {{ t(openAIWSModeConcurrencyHintKey) }}

@@ -946,119 +946,6 @@
- -
-
-
- -

- {{ t('admin.accounts.quotaControl.rpmLimit.hint') }} -

-
- -
- -
-
- - -

{{ t('admin.accounts.quotaControl.rpmLimit.baseRpmHint') }}

-
- -
- -
- - -
-
- -
- - -

{{ t('admin.accounts.quotaControl.rpmLimit.stickyBufferHint') }}

-
- -
- - -
- -

- {{ t('admin.accounts.quotaControl.rpmLimit.userMsgQueueHint') }} -

-
- -
-
-
-
@@ -1273,10 +1160,11 @@ import { applyInterceptWarmup } from '@/components/account/credentialsBuilder' import { formatDateTimeLocalInput, parseDateTimeLocalInput } from '@/utils/format' import { createStableObjectKeyResolver } from '@/utils/stableObjectKey' import { - OPENAI_WS_MODE_DEDICATED, + OPENAI_WS_MODE_PASSTHROUGH, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, isOpenAIWSModeEnabled, + resolveOpenAIWSModeConcurrencyHintKey, type OpenAIWSMode, resolveOpenAIWSModeFromExtra } from '@/utils/openaiWsMode' @@ -1364,16 +1252,6 @@ const windowCostStickyReserve = ref(null) const sessionLimitEnabled = ref(false) const maxSessions = ref(null) const sessionIdleTimeout = ref(null) -const rpmLimitEnabled = ref(false) -const baseRpm = ref(null) -const rpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') -const rpmStickyBuffer = ref(null) -const userMsgQueueMode = ref('') -const umqModeOptions = computed(() => [ - { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, - { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, - { value: 'serialize', label: t('admin.accounts.quotaControl.rpmLimit.umqModeSerialize') }, -]) const tlsFingerprintEnabled = ref(false) const sessionIdMaskingEnabled = ref(false) const cacheTTLOverrideEnabled = ref(false) @@ -1387,8 +1265,8 @@ const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) const openAIWSModeOptions = computed(() => [ { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, - { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, - { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } + { value: OPENAI_WS_MODE_CTX_POOL, label: t('admin.accounts.openai.wsModeCtxPool') }, + { value: OPENAI_WS_MODE_PASSTHROUGH, label: t('admin.accounts.openai.wsModePassthrough') } ]) const openaiResponsesWebSocketV2Mode = computed({ get: () => { @@ -1405,6 
+1283,11 @@ const openaiResponsesWebSocketV2Mode = computed({ openaiOAuthResponsesWebSocketV2Mode.value = mode } }) + +const openAIWSModeConcurrencyHintKey = computed(() => + resolveOpenAIWSModeConcurrencyHintKey(openaiResponsesWebSocketV2Mode.value) +) + const isOpenAIModelRestrictionDisabled = computed(() => props.account?.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -1833,11 +1716,6 @@ function loadQuotaControlSettings(account: Account) { sessionLimitEnabled.value = false maxSessions.value = null sessionIdleTimeout.value = null - rpmLimitEnabled.value = false - baseRpm.value = null - rpmStrategy.value = 'tiered' - rpmStickyBuffer.value = null - userMsgQueueMode.value = '' tlsFingerprintEnabled.value = false sessionIdMaskingEnabled.value = false cacheTTLOverrideEnabled.value = false @@ -1861,17 +1739,6 @@ function loadQuotaControlSettings(account: Account) { sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5 } - // RPM limit - if (account.base_rpm != null && account.base_rpm > 0) { - rpmLimitEnabled.value = true - baseRpm.value = account.base_rpm - rpmStrategy.value = (account.rpm_strategy as 'tiered' | 'sticky_exempt') || 'tiered' - rpmStickyBuffer.value = account.rpm_sticky_buffer ?? null - } - - // UMQ mode(独立于 RPM 加载,防止编辑无 RPM 账号时丢失已有配置) - userMsgQueueMode.value = account.user_msg_queue_mode ?? 
'' - // Load TLS fingerprint setting if (account.enable_tls_fingerprint === true) { tlsFingerprintEnabled.value = true @@ -1992,7 +1859,7 @@ const ensureAntigravityMixedChannelConfirmed = async (onConfirm: () => Promise { antigravityMixedChannelConfirmed.value = true await submitUpdateAccount(accountID, updatePayload) @@ -2025,7 +1892,7 @@ const submitUpdateAccount = async (accountID: number, updatePayload: Record { delete newExtra.session_idle_timeout_minutes } - // RPM limit settings - if (rpmLimitEnabled.value && baseRpm.value != null && baseRpm.value > 0) { - newExtra.base_rpm = baseRpm.value - newExtra.rpm_strategy = rpmStrategy.value - if (rpmStickyBuffer.value != null && rpmStickyBuffer.value > 0) { - newExtra.rpm_sticky_buffer = rpmStickyBuffer.value - } else { - delete newExtra.rpm_sticky_buffer - } - } else { - delete newExtra.base_rpm - delete newExtra.rpm_strategy - delete newExtra.rpm_sticky_buffer - } - - // UMQ mode(独立于 RPM 保存) - if (userMsgQueueMode.value) { - newExtra.user_msg_queue_mode = userMsgQueueMode.value - } else { - delete newExtra.user_msg_queue_mode - } - delete newExtra.user_msg_queue_enabled // 清理旧字段 - // TLS fingerprint setting if (tlsFingerprintEnabled.value) { newExtra.enable_tls_fingerprint = true @@ -2248,10 +2092,14 @@ const handleSubmit = async () => { const currentExtra = (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true - newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value - newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value - newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) - newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + // 按账号类型分流写入,避免 oauth 账号写入 apikey 的字段(反之亦然) + if (props.account.type 
=== 'oauth') { + newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + } else if (props.account.type === 'apikey') { + newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + } delete newExtra.responses_websockets_v2_enabled delete newExtra.openai_ws_enabled if (openaiPassthroughEnabled.value) { @@ -2284,7 +2132,7 @@ const handleSubmit = async () => { await submitUpdateAccount(accountID, updatePayload) } catch (error: any) { - appStore.showError(error.message || t('admin.accounts.failedToUpdate')) + appStore.showError(error.response?.data?.message || error.response?.data?.detail || t('admin.accounts.failedToUpdate')) } } diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts index 28ac61ecc..0e700ac6b 100644 --- a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts +++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts @@ -1,7 +1,11 @@ -import { describe, expect, it, vi } from 'vitest' +import { beforeEach, describe, expect, it, vi } from 'vitest' import { mount } from '@vue/test-utils' import BulkEditAccountModal from '../BulkEditAccountModal.vue' +const { bulkUpdateMock } = vi.hoisted(() => ({ + bulkUpdateMock: vi.fn() +})) + vi.mock('@/stores/app', () => ({ useAppStore: () => ({ showError: vi.fn(), @@ -13,7 +17,7 @@ vi.mock('@/stores/app', () => ({ vi.mock('@/api/admin', () => ({ adminAPI: { accounts: { - bulkEdit: vi.fn() + bulkUpdate: bulkUpdateMock } } })) @@ -28,19 +32,26 @@ vi.mock('vue-i18n', async () => { } }) -function mountModal() { +function mountModal( + scope: { scopePlatform: 
string; scopeType: string } = { + scopePlatform: 'gemini', + scopeType: 'apikey' + }, + selectStub: boolean | Record = true +) { return mount(BulkEditAccountModal, { props: { show: true, accountIds: [1, 2], - selectedPlatforms: ['antigravity'], + scopePlatform: scope.scopePlatform, + scopeType: scope.scopeType, proxies: [], groups: [] } as any, global: { stubs: { BaseDialog: { template: '
' }, - Select: true, + Select: selectStub, ProxySelector: true, GroupSelector: true, Icon: true @@ -50,7 +61,52 @@ function mountModal() { } describe('BulkEditAccountModal', () => { - it('antigravity 白名单包含 Gemini 图片模型且过滤掉普通 GPT 模型', () => { + beforeEach(() => { + bulkUpdateMock.mockReset() + }) + + it('OpenAI OAuth 选择 WS mode 后会写入 bulkUpdate payload', async () => { + bulkUpdateMock.mockResolvedValue({ success: 2, failed: 0 }) + const wrapper = mountModal( + { scopePlatform: 'openai', scopeType: 'oauth' }, + { + props: ['options', 'modelValue'], + template: ` + + ` + } + ) + + await wrapper.get('#bulk-edit-openai-ws-mode-enabled').setValue(true) + const selects = wrapper.findAll('[data-testid="select-stub"]') + const wsModeSelect = selects.find( + (select) => select.text().includes('off') && select.text().includes('ctx_pool') && select.text().includes('passthrough') + ) + expect(wsModeSelect).toBeTruthy() + await wsModeSelect!.setValue('passthrough') + + await wrapper.get('form#bulk-edit-account-form').trigger('submit.prevent') + + expect(bulkUpdateMock).toHaveBeenCalledTimes(1) + expect(bulkUpdateMock).toHaveBeenCalledWith( + [1, 2], + expect.objectContaining({ + extra: expect.objectContaining({ + openai_oauth_responses_websockets_v2_mode: 'passthrough' + }) + }) + ) + }) + + it('Gemini 范围白名单包含图片模型并过滤 GPT 模型', () => { const wrapper = mountModal() expect(wrapper.text()).toContain('Gemini 3.1 Flash Image') @@ -58,7 +114,7 @@ describe('BulkEditAccountModal', () => { expect(wrapper.text()).not.toContain('GPT-5.3 Codex') }) - it('antigravity 映射预设包含图片映射并过滤 OpenAI 预设', async () => { + it('Gemini 范围映射预设包含图片映射并过滤 OpenAI 预设', async () => { const wrapper = mountModal() const mappingTab = wrapper.findAll('button').find((btn) => btn.text().includes('admin.accounts.modelMapping')) @@ -69,4 +125,21 @@ describe('BulkEditAccountModal', () => { expect(wrapper.text()).toContain('G3 Image→3.1') expect(wrapper.text()).not.toContain('GPT-5.3 Codex') }) + + it('OpenAI OAuth 范围 WS mode 
选项仅保留 off、ctx_pool 与 passthrough', () => { + const wrapper = mountModal( + { scopePlatform: 'openai', scopeType: 'oauth' }, + { + props: ['options'], + template: + '
{{ option.value }}
' + } + ) + + expect(wrapper.text()).toContain('off') + expect(wrapper.text()).toContain('ctx_pool') + expect(wrapper.text()).toContain('passthrough') + expect(wrapper.text()).not.toContain('shared') + expect(wrapper.text()).not.toContain('dedicated') + }) }) diff --git a/frontend/src/components/account/__tests__/bulkEditPayload.spec.ts b/frontend/src/components/account/__tests__/bulkEditPayload.spec.ts new file mode 100644 index 000000000..c09c69ba1 --- /dev/null +++ b/frontend/src/components/account/__tests__/bulkEditPayload.spec.ts @@ -0,0 +1,247 @@ +import { describe, expect, it, vi } from 'vitest' +import { + OPENAI_WS_MODE_OFF, + OPENAI_WS_MODE_CTX_POOL +} from '@/utils/openaiWsMode' +import { + buildBulkEditUpdatePayload, + hasAnyBulkEditFieldEnabled, + type BulkEditPayloadInput +} from '../bulkEditPayload' + +const createInput = (overrides: Partial = {}): BulkEditPayloadInput => ({ + scopeType: 'oauth', + enableBaseUrl: false, + enableModelRestriction: false, + enableCustomErrorCodes: false, + enableInterceptWarmup: false, + enableOpenAIPassthrough: false, + enableOpenAIWSMode: false, + enableCodexCLIOnly: false, + enableAnthropicPassthrough: false, + enableProxy: false, + enableConcurrency: false, + enablePriority: false, + enableRateMultiplier: false, + enableStatus: false, + enableGroups: false, + baseUrl: '', + modelRestrictionMode: 'whitelist', + allowedModels: [], + modelMappings: [], + selectedErrorCodes: [], + interceptWarmupRequests: false, + openAIPassthroughEnabled: false, + openAIWSMode: OPENAI_WS_MODE_OFF, + codexCLIOnlyEnabled: false, + anthropicPassthroughEnabled: false, + proxyId: null, + concurrency: 1, + priority: 1, + rateMultiplier: 1, + status: 'active', + groupIds: [], + ...overrides +}) + +describe('hasAnyBulkEditFieldEnabled', () => { + it('returns false when all toggles are disabled', () => { + expect(hasAnyBulkEditFieldEnabled(createInput())).toBe(false) + }) + + it('returns true when at least one toggle is enabled', () => { + 
expect(hasAnyBulkEditFieldEnabled(createInput({ enableOpenAIWSMode: true }))).toBe(true) + }) +}) + +describe('buildBulkEditUpdatePayload', () => { + it('returns null when no field is enabled', () => { + expect(buildBulkEditUpdatePayload(createInput())).toBeNull() + }) + + it('builds base fields and supports clearing proxy with 0', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + enableProxy: true, + proxyId: null, + enableConcurrency: true, + concurrency: 8, + enablePriority: true, + priority: 9, + enableRateMultiplier: true, + rateMultiplier: 1.25, + enableStatus: true, + status: 'inactive', + enableGroups: true, + groupIds: [2, 3] + }) + ) + + expect(payload).toEqual({ + proxy_id: 0, + concurrency: 8, + priority: 9, + rate_multiplier: 1.25, + status: 'inactive', + group_ids: [2, 3] + }) + }) + + it('trims base_url and ignores empty input', () => { + const withValue = buildBulkEditUpdatePayload( + createInput({ + enableBaseUrl: true, + baseUrl: ' https://api.example.com/v1 ' + }) + ) + expect(withValue).toEqual({ + credentials: { + base_url: 'https://api.example.com/v1' + } + }) + + const withEmpty = buildBulkEditUpdatePayload( + createInput({ + enableBaseUrl: true, + baseUrl: ' ' + }) + ) + expect(withEmpty).toBeNull() + }) + + it('builds model whitelist mapping', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + enableModelRestriction: true, + modelRestrictionMode: 'whitelist', + allowedModels: ['claude-sonnet-4-6', 'gpt-5.2-codex'] + }) + ) + + expect(payload).toEqual({ + credentials: { + model_mapping: { + 'claude-sonnet-4-6': 'claude-sonnet-4-6', + 'gpt-5.2-codex': 'gpt-5.2-codex' + } + } + }) + }) + + it('builds model mapping mode and filters invalid rules', () => { + const warnSpy = vi.spyOn(console, 'warn').mockImplementation(() => {}) + try { + const payload = buildBulkEditUpdatePayload( + createInput({ + enableModelRestriction: true, + modelRestrictionMode: 'mapping', + modelMappings: [ + { from: '', to: 
'ignored' }, + { from: 'bad*wild', to: 'x' }, + { from: 'claude-*', to: 'claude-sonnet-4-6' }, + { from: 'gpt-5.2-codex', to: 'gpt-*' }, + { from: 'gpt-5.1-codex', to: 'gpt-5.2-codex' } + ] + }) + ) + + expect(payload).toEqual({ + credentials: { + model_mapping: { + 'claude-*': 'claude-sonnet-4-6', + 'gpt-5.1-codex': 'gpt-5.2-codex' + } + } + }) + } finally { + warnSpy.mockRestore() + } + }) + + it('writes custom error codes and warmup interception into credentials', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + enableCustomErrorCodes: true, + selectedErrorCodes: [429, 503], + enableInterceptWarmup: true, + interceptWarmupRequests: true + }) + ) + + expect(payload).toEqual({ + credentials: { + custom_error_codes_enabled: true, + custom_error_codes: [429, 503], + intercept_warmup_requests: true + } + }) + }) + + it('writes OpenAI passthrough and OAuth ws mode keys', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + scopeType: 'oauth', + enableOpenAIPassthrough: true, + openAIPassthroughEnabled: true, + enableOpenAIWSMode: true, + openAIWSMode: OPENAI_WS_MODE_CTX_POOL + }) + ) + + expect(payload).toEqual({ + extra: { + openai_passthrough: true, + openai_oauth_passthrough: true, + openai_oauth_responses_websockets_v2_mode: OPENAI_WS_MODE_CTX_POOL, + openai_oauth_responses_websockets_v2_enabled: true + } + }) + }) + + it('writes API key ws mode keys and off-mode disabled flag', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + scopeType: 'apikey', + enableOpenAIWSMode: true, + openAIWSMode: OPENAI_WS_MODE_OFF + }) + ) + + expect(payload).toEqual({ + extra: { + openai_apikey_responses_websockets_v2_mode: OPENAI_WS_MODE_OFF, + openai_apikey_responses_websockets_v2_enabled: false + } + }) + }) + + it('ignores ws mode when scope type is not oauth/apikey', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + scopeType: 'setup-token', + enableOpenAIWSMode: true, + openAIWSMode: 
OPENAI_WS_MODE_CTX_POOL + }) + ) + + expect(payload).toBeNull() + }) + + it('writes codex and anthropic passthrough flags', () => { + const payload = buildBulkEditUpdatePayload( + createInput({ + enableCodexCLIOnly: true, + codexCLIOnlyEnabled: true, + enableAnthropicPassthrough: true, + anthropicPassthroughEnabled: false + }) + ) + + expect(payload).toEqual({ + extra: { + codex_cli_only: true, + anthropic_passthrough: false + } + }) + }) +}) diff --git a/frontend/src/components/account/__tests__/bulkEditScopeProfile.spec.ts b/frontend/src/components/account/__tests__/bulkEditScopeProfile.spec.ts new file mode 100644 index 000000000..b3f79f273 --- /dev/null +++ b/frontend/src/components/account/__tests__/bulkEditScopeProfile.spec.ts @@ -0,0 +1,78 @@ +import { describe, expect, it } from 'vitest' +import { + resolveBulkEditScopeCapabilities, + resolveBulkEditScopeEditorKey +} from '../bulkEditScopeProfile' + +describe('resolveBulkEditScopeCapabilities', () => { + it('returns OpenAI OAuth capabilities', () => { + const profile = resolveBulkEditScopeCapabilities('openai', 'oauth') + + expect(profile.supportsBaseUrl).toBe(false) + expect(profile.supportsModelRestriction).toBe(false) + expect(profile.supportsCustomErrorCodes).toBe(false) + expect(profile.supportsInterceptWarmup).toBe(false) + expect(profile.supportsOpenAIPassthrough).toBe(true) + expect(profile.supportsOpenAIWSMode).toBe(true) + expect(profile.supportsCodexCLIOnly).toBe(true) + expect(profile.supportsAnthropicPassthrough).toBe(false) + }) + + it('returns OpenAI API Key capabilities', () => { + const profile = resolveBulkEditScopeCapabilities('openai', 'apikey') + + expect(profile.supportsBaseUrl).toBe(true) + expect(profile.supportsModelRestriction).toBe(true) + expect(profile.supportsCustomErrorCodes).toBe(true) + expect(profile.supportsInterceptWarmup).toBe(false) + expect(profile.supportsOpenAIPassthrough).toBe(true) + expect(profile.supportsOpenAIWSMode).toBe(true) + 
expect(profile.supportsCodexCLIOnly).toBe(false) + expect(profile.supportsAnthropicPassthrough).toBe(false) + }) + + it('returns Anthropic API Key capabilities', () => { + const profile = resolveBulkEditScopeCapabilities('anthropic', 'apikey') + + expect(profile.supportsBaseUrl).toBe(true) + expect(profile.supportsModelRestriction).toBe(true) + expect(profile.supportsCustomErrorCodes).toBe(true) + expect(profile.supportsInterceptWarmup).toBe(true) + expect(profile.supportsOpenAIPassthrough).toBe(false) + expect(profile.supportsOpenAIWSMode).toBe(false) + expect(profile.supportsCodexCLIOnly).toBe(false) + expect(profile.supportsAnthropicPassthrough).toBe(true) + }) + + it('returns Antigravity Upstream capabilities', () => { + const profile = resolveBulkEditScopeCapabilities('antigravity', 'upstream') + + expect(profile.supportsBaseUrl).toBe(true) + expect(profile.supportsModelRestriction).toBe(false) + expect(profile.supportsCustomErrorCodes).toBe(false) + expect(profile.supportsInterceptWarmup).toBe(true) + expect(profile.supportsOpenAIPassthrough).toBe(false) + expect(profile.supportsOpenAIWSMode).toBe(false) + expect(profile.supportsCodexCLIOnly).toBe(false) + expect(profile.supportsAnthropicPassthrough).toBe(false) + }) +}) + +describe('resolveBulkEditScopeEditorKey', () => { + it('resolves known scope keys', () => { + expect(resolveBulkEditScopeEditorKey('openai', 'oauth')).toBe('openai:oauth') + expect(resolveBulkEditScopeEditorKey('anthropic', 'setup-token')).toBe( + 'anthropic:setup-token' + ) + expect(resolveBulkEditScopeEditorKey('antigravity', 'upstream')).toBe( + 'antigravity:upstream' + ) + }) + + it('returns null for unsupported or incomplete scope', () => { + expect(resolveBulkEditScopeEditorKey('openai', 'upstream')).toBeNull() + expect(resolveBulkEditScopeEditorKey('sora', 'setup-token')).toBeNull() + expect(resolveBulkEditScopeEditorKey('', 'oauth')).toBeNull() + expect(resolveBulkEditScopeEditorKey('gemini', '')).toBeNull() + }) +}) diff --git 
a/frontend/src/components/account/__tests__/bulkEditTemplateRemoteMapper.spec.ts b/frontend/src/components/account/__tests__/bulkEditTemplateRemoteMapper.spec.ts new file mode 100644 index 000000000..8a257cd54 --- /dev/null +++ b/frontend/src/components/account/__tests__/bulkEditTemplateRemoteMapper.spec.ts @@ -0,0 +1,79 @@ +import { describe, expect, it } from 'vitest' +import { + mapBulkEditTemplateFromRemote, + mapBulkEditTemplateToUpsertRequest +} from '../bulkEditTemplateRemoteMapper' + +describe('bulkEditTemplateRemoteMapper', () => { + it('maps remote template to local template record', () => { + const local = mapBulkEditTemplateFromRemote({ + id: 'tpl-1', + name: 'OpenAI OAuth Shared', + scope_platform: 'openai', + scope_type: 'oauth', + share_scope: 'groups', + group_ids: [9, 3, 3], + state: { enableBaseUrl: true }, + created_by: 8, + updated_by: 8, + created_at: 100, + updated_at: 200 + }) + + expect(local).toEqual({ + id: 'tpl-1', + name: 'OpenAI OAuth Shared', + scopePlatform: 'openai', + scopeType: 'oauth', + shareScope: 'groups', + groupIds: [3, 9], + state: { enableBaseUrl: true }, + updatedAt: 200, + ownerUserId: 8 + }) + }) + + it('normalizes malformed remote payload', () => { + const local = mapBulkEditTemplateFromRemote({ + id: 'tpl-2', + name: 'Broken', + scope_platform: 'openai', + scope_type: 'oauth', + share_scope: 'bad' as any, + group_ids: [0, 2, 2, 1], + state: undefined as any, + created_by: 0, + updated_by: 0, + created_at: 0, + updated_at: 0 + }) + + expect(local.shareScope).toBe('private') + expect(local.groupIds).toEqual([1, 2]) + expect(local.ownerUserId).toBeNull() + expect(typeof local.updatedAt).toBe('number') + expect(local.updatedAt).toBeGreaterThan(0) + }) + + it('maps local model to upsert request', () => { + const request = mapBulkEditTemplateToUpsertRequest({ + id: 'tpl-3', + name: 'Team Template', + scopePlatform: 'openai', + scopeType: 'apikey', + shareScope: 'team', + groupIds: [8, 2, 8], + state: { enableGroups: true } + 
}) + + expect(request).toEqual({ + id: 'tpl-3', + name: 'Team Template', + scope_platform: 'openai', + scope_type: 'apikey', + share_scope: 'team', + group_ids: [2, 8], + state: { enableGroups: true } + }) + }) +}) diff --git a/frontend/src/components/account/__tests__/bulkEditTemplateState.spec.ts b/frontend/src/components/account/__tests__/bulkEditTemplateState.spec.ts new file mode 100644 index 000000000..2866efb6f --- /dev/null +++ b/frontend/src/components/account/__tests__/bulkEditTemplateState.spec.ts @@ -0,0 +1,81 @@ +import { describe, expect, it } from 'vitest' +import { OPENAI_WS_MODE_CTX_POOL } from '@/utils/openaiWsMode' +import { + createBulkEditTemplateStateSnapshot, + createDefaultBulkEditTemplateState, + normalizeBulkEditTemplateState +} from '../bulkEditTemplateState' + +describe('bulkEditTemplateState', () => { + it('builds default state', () => { + const state = createDefaultBulkEditTemplateState() + expect(state.enableBaseUrl).toBe(false) + expect(state.openAIWSMode).toBe('off') + expect(state.modelMappings).toEqual([]) + expect(state.groupIds).toEqual([]) + }) + + it('normalizes invalid input to defaults', () => { + const state = normalizeBulkEditTemplateState(null) + expect(state).toEqual(createDefaultBulkEditTemplateState()) + }) + + it('normalizes and sanitizes mixed payload', () => { + const state = normalizeBulkEditTemplateState({ + enableBaseUrl: true, + baseUrl: 'https://api.example.com', + modelRestrictionMode: 'mapping', + allowedModels: ['a', 1, 'b'], + modelMappings: [{ from: 'x', to: 'y' }, { from: 'bad' }, 'bad-item'], + selectedErrorCodes: [429, '503', 529.8], + openAIWSMode: 'ctx_pool', + proxyId: 18.9, + concurrency: 0, + priority: 9.4, + rateMultiplier: -2, + status: 'inactive', + groupIds: [1, 2.7, '3'] + }) + + expect(state.enableBaseUrl).toBe(true) + expect(state.baseUrl).toBe('https://api.example.com') + expect(state.modelRestrictionMode).toBe('mapping') + expect(state.allowedModels).toEqual(['a', 'b']) + 
expect(state.modelMappings).toEqual([{ from: 'x', to: 'y' }]) + expect(state.selectedErrorCodes).toEqual([429, 529]) + expect(state.openAIWSMode).toBe('ctx_pool') + expect(state.proxyId).toBe(18) + expect(state.concurrency).toBe(1) + expect(state.priority).toBe(9) + expect(state.rateMultiplier).toBe(0) + expect(state.status).toBe('inactive') + expect(state.groupIds).toEqual([1, 2]) + }) + + it('falls back for invalid ws mode and proxy id', () => { + const state = normalizeBulkEditTemplateState({ + openAIWSMode: 'invalid-mode', + proxyId: 0 + }) + expect(state.openAIWSMode).toBe('off') + expect(state.proxyId).toBeNull() + }) + + it('creates snapshot as deep-normalized clone', () => { + const source = createDefaultBulkEditTemplateState() + source.openAIWSMode = OPENAI_WS_MODE_CTX_POOL + source.allowedModels.push('gpt-5.2-codex') + source.modelMappings.push({ from: 'a', to: 'b' }) + source.groupIds.push(9) + + const snapshot = createBulkEditTemplateStateSnapshot(source) + source.allowedModels[0] = 'mutated' + source.modelMappings[0].to = 'changed' + source.groupIds[0] = 0 + + expect(snapshot.openAIWSMode).toBe('ctx_pool') + expect(snapshot.allowedModels).toEqual(['gpt-5.2-codex']) + expect(snapshot.modelMappings).toEqual([{ from: 'a', to: 'b' }]) + expect(snapshot.groupIds).toEqual([9]) + }) +}) diff --git a/frontend/src/components/account/__tests__/bulkEditTemplateStore.spec.ts b/frontend/src/components/account/__tests__/bulkEditTemplateStore.spec.ts new file mode 100644 index 000000000..05b3d0d7d --- /dev/null +++ b/frontend/src/components/account/__tests__/bulkEditTemplateStore.spec.ts @@ -0,0 +1,115 @@ +import { describe, expect, it } from 'vitest' +import { + filterBulkEditTemplateRecordsByScope, + normalizeBulkEditTemplateGroupIDs, + normalizeBulkEditTemplateShareScope, + parseBulkEditTemplateRecords, + removeBulkEditTemplateRecord, + serializeBulkEditTemplateRecords, + upsertBulkEditTemplateRecord +} from '../bulkEditTemplateStore' + +const templates = [ + { + 
id: 'a', + name: 'OpenAI OAuth Default', + scopePlatform: 'openai', + scopeType: 'oauth', + shareScope: 'private', + groupIds: [], + state: { foo: 1 }, + updatedAt: 10 + }, + { + id: 'b', + name: 'OpenAI OAuth Latest', + scopePlatform: 'openai', + scopeType: 'oauth', + shareScope: 'team', + groupIds: [], + state: { foo: 2 }, + updatedAt: 20 + }, + { + id: 'c', + name: 'OpenAI APIKey', + scopePlatform: 'openai', + scopeType: 'apikey', + shareScope: 'groups', + groupIds: [9], + state: { foo: 3 }, + updatedAt: 5 + } +] + +describe('bulkEditTemplateStore', () => { + it('parses valid templates and ignores invalid payload', () => { + const parsed = parseBulkEditTemplateRecords(JSON.stringify(templates)) + expect(parsed).toHaveLength(3) + expect(parsed[0].shareScope).toBe('private') + expect(parsed[2].groupIds).toEqual([9]) + expect(parseBulkEditTemplateRecords('')).toEqual([]) + expect(parseBulkEditTemplateRecords('invalid-json')).toEqual([]) + expect(parseBulkEditTemplateRecords(JSON.stringify({ foo: 'bar' }))).toEqual([]) + }) + + it('normalizes legacy payload without share metadata', () => { + const parsed = parseBulkEditTemplateRecords( + JSON.stringify([ + { + id: 'legacy', + name: 'Legacy', + scopePlatform: 'openai', + scopeType: 'oauth', + state: { foo: 1 }, + updatedAt: 1 + } + ]) + ) + expect(parsed).toHaveLength(1) + expect(parsed[0].shareScope).toBe('private') + expect(parsed[0].groupIds).toEqual([]) + }) + + it('serializes templates', () => { + const raw = serializeBulkEditTemplateRecords(templates) + expect(typeof raw).toBe('string') + expect(parseBulkEditTemplateRecords(raw)).toHaveLength(3) + }) + + it('upserts by same scope + same name (case-insensitive)', () => { + const next = upsertBulkEditTemplateRecord(templates, { + id: 'd', + name: 'openai oauth default', + scopePlatform: 'openai', + scopeType: 'oauth', + shareScope: 'private', + groupIds: [], + state: { foo: 9 }, + updatedAt: 99 + }) + expect(next).toHaveLength(3) + expect(next.find((item) => 
item.id === 'd')).toBeTruthy() + expect(next.find((item) => item.id === 'a')).toBeFalsy() + }) + + it('removes template by id', () => { + const next = removeBulkEditTemplateRecord(templates, 'b') + expect(next).toHaveLength(2) + expect(next.find((item) => item.id === 'b')).toBeFalsy() + }) + + it('filters and sorts by scope', () => { + const scoped = filterBulkEditTemplateRecordsByScope(templates, 'openai', 'oauth') + expect(scoped.map((item) => item.id)).toEqual(['b', 'a']) + expect(filterBulkEditTemplateRecordsByScope(templates, '', 'oauth')).toEqual([]) + expect(filterBulkEditTemplateRecordsByScope(templates, 'openai', '')).toEqual([]) + }) + + it('normalizes share scope and group ids', () => { + expect(normalizeBulkEditTemplateShareScope('team')).toBe('team') + expect(normalizeBulkEditTemplateShareScope('groups')).toBe('groups') + expect(normalizeBulkEditTemplateShareScope('invalid')).toBe('private') + expect(normalizeBulkEditTemplateGroupIDs([3, 1, 3, 2.8, -1] as any)).toEqual([1, 2, 3]) + }) +}) diff --git a/frontend/src/components/account/bulkEditPayload.ts b/frontend/src/components/account/bulkEditPayload.ts new file mode 100644 index 000000000..e9ef20401 --- /dev/null +++ b/frontend/src/components/account/bulkEditPayload.ts @@ -0,0 +1,199 @@ +import type { AccountType } from '@/types' +import { buildModelMappingObject as buildModelMappingPayload } from '@/composables/useModelWhitelist' +import { isOpenAIWSModeEnabled, type OpenAIWSMode } from '@/utils/openaiWsMode' + +export interface BulkEditModelMapping { + from: string + to: string +} + +export interface BulkEditPayloadInput { + scopeType?: AccountType | '' | null + enableBaseUrl: boolean + enableModelRestriction: boolean + enableCustomErrorCodes: boolean + enableInterceptWarmup: boolean + enableOpenAIPassthrough: boolean + enableOpenAIWSMode: boolean + enableCodexCLIOnly: boolean + enableAnthropicPassthrough: boolean + enableProxy: boolean + enableConcurrency: boolean + enablePriority: boolean + 
enableRateMultiplier: boolean + enableStatus: boolean + enableGroups: boolean + baseUrl: string + modelRestrictionMode: 'whitelist' | 'mapping' + allowedModels: string[] + modelMappings: BulkEditModelMapping[] + selectedErrorCodes: number[] + interceptWarmupRequests: boolean + openAIPassthroughEnabled: boolean + openAIWSMode: OpenAIWSMode + codexCLIOnlyEnabled: boolean + anthropicPassthroughEnabled: boolean + proxyId: number | null + concurrency: number + priority: number + rateMultiplier: number + status: 'active' | 'inactive' + groupIds: number[] +} + +type BulkEditEnabledFlags = Pick< + BulkEditPayloadInput, + | 'enableBaseUrl' + | 'enableModelRestriction' + | 'enableCustomErrorCodes' + | 'enableInterceptWarmup' + | 'enableOpenAIPassthrough' + | 'enableOpenAIWSMode' + | 'enableCodexCLIOnly' + | 'enableAnthropicPassthrough' + | 'enableProxy' + | 'enableConcurrency' + | 'enablePriority' + | 'enableRateMultiplier' + | 'enableStatus' + | 'enableGroups' +> + +export const hasAnyBulkEditFieldEnabled = (flags: BulkEditEnabledFlags): boolean => { + return ( + flags.enableBaseUrl || + flags.enableModelRestriction || + flags.enableCustomErrorCodes || + flags.enableInterceptWarmup || + flags.enableOpenAIPassthrough || + flags.enableOpenAIWSMode || + flags.enableCodexCLIOnly || + flags.enableAnthropicPassthrough || + flags.enableProxy || + flags.enableConcurrency || + flags.enablePriority || + flags.enableRateMultiplier || + flags.enableStatus || + flags.enableGroups + ) +} + +export const buildBulkEditUpdatePayload = ( + input: BulkEditPayloadInput +): Record | null => { + const updates: Record = {} + const credentials: Record = {} + const extra: Record = {} + let credentialsChanged = false + let extraChanged = false + + if (input.enableProxy) { + // Backend expects `proxy_id: 0` to clear proxy. + updates.proxy_id = input.proxyId === null ? 
0 : input.proxyId + } + + if (input.enableConcurrency) { + updates.concurrency = input.concurrency + } + + if (input.enablePriority) { + updates.priority = input.priority + } + + if (input.enableRateMultiplier) { + updates.rate_multiplier = input.rateMultiplier + } + + if (input.enableStatus) { + updates.status = input.status + } + + if (input.enableGroups) { + updates.group_ids = input.groupIds + } + + if (input.enableBaseUrl) { + const baseUrlValue = input.baseUrl.trim() + if (baseUrlValue) { + credentials.base_url = baseUrlValue + credentialsChanged = true + } + } + + if (input.enableModelRestriction) { + if (input.modelRestrictionMode === 'whitelist') { + if (input.allowedModels.length > 0) { + const mapping: Record = {} + for (const model of input.allowedModels) { + mapping[model] = model + } + credentials.model_mapping = mapping + credentialsChanged = true + } + } else { + const modelMapping = buildModelMappingPayload( + input.modelRestrictionMode, + input.allowedModels, + input.modelMappings + ) + if (modelMapping) { + credentials.model_mapping = modelMapping + credentialsChanged = true + } + } + } + + if (input.enableCustomErrorCodes) { + credentials.custom_error_codes_enabled = true + credentials.custom_error_codes = [...input.selectedErrorCodes] + credentialsChanged = true + } + + if (input.enableInterceptWarmup) { + credentials.intercept_warmup_requests = input.interceptWarmupRequests + credentialsChanged = true + } + + if (input.enableOpenAIPassthrough) { + extra.openai_passthrough = input.openAIPassthroughEnabled + // Keep backward compatibility key aligned. 
+ extra.openai_oauth_passthrough = input.openAIPassthroughEnabled + extraChanged = true + } + + if (input.enableOpenAIWSMode) { + if (input.scopeType === 'oauth') { + extra.openai_oauth_responses_websockets_v2_mode = input.openAIWSMode + extra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled( + input.openAIWSMode + ) + extraChanged = true + } else if (input.scopeType === 'apikey') { + extra.openai_apikey_responses_websockets_v2_mode = input.openAIWSMode + extra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled( + input.openAIWSMode + ) + extraChanged = true + } + } + + if (input.enableCodexCLIOnly) { + extra.codex_cli_only = input.codexCLIOnlyEnabled + extraChanged = true + } + + if (input.enableAnthropicPassthrough) { + extra.anthropic_passthrough = input.anthropicPassthroughEnabled + extraChanged = true + } + + if (credentialsChanged) { + updates.credentials = credentials + } + + if (extraChanged) { + updates.extra = extra + } + + return Object.keys(updates).length > 0 ? 
updates : null +} diff --git a/frontend/src/components/account/bulkEditScopeProfile.ts b/frontend/src/components/account/bulkEditScopeProfile.ts new file mode 100644 index 000000000..03991fc73 --- /dev/null +++ b/frontend/src/components/account/bulkEditScopeProfile.ts @@ -0,0 +1,69 @@ +import type { AccountPlatform, AccountType } from '@/types' + +export interface BulkEditScopeCapabilities { + supportsBaseUrl: boolean + supportsModelRestriction: boolean + supportsCustomErrorCodes: boolean + supportsInterceptWarmup: boolean + supportsOpenAIPassthrough: boolean + supportsOpenAIWSMode: boolean + supportsCodexCLIOnly: boolean + supportsAnthropicPassthrough: boolean +} + +export const BULK_EDIT_SCOPE_EDITOR_KEYS = [ + 'anthropic:oauth', + 'anthropic:setup-token', + 'anthropic:apikey', + 'openai:oauth', + 'openai:apikey', + 'gemini:oauth', + 'gemini:apikey', + 'antigravity:oauth', + 'antigravity:upstream', + 'sora:oauth', + 'sora:apikey' +] as const + +export type BulkEditScopeEditorKey = (typeof BULK_EDIT_SCOPE_EDITOR_KEYS)[number] + +const bulkEditScopeEditorKeySet = new Set(BULK_EDIT_SCOPE_EDITOR_KEYS) + +const isOpenAIScope = (platform?: AccountPlatform | '' | null) => platform === 'openai' +const isAnthropicScope = (platform?: AccountPlatform | '' | null) => platform === 'anthropic' +const isAnthropicOrAntigravityScope = (platform?: AccountPlatform | '' | null) => + platform === 'anthropic' || platform === 'antigravity' +const isAPIKeyScope = (type?: AccountType | '' | null) => type === 'apikey' +const isOpenAIOAuthScope = ( + platform?: AccountPlatform | '' | null, + type?: AccountType | '' | null +) => isOpenAIScope(platform) && type === 'oauth' +const isOpenAIOAuthOrAPIKeyScope = ( + platform?: AccountPlatform | '' | null, + type?: AccountType | '' | null +) => isOpenAIScope(platform) && (type === 'oauth' || type === 'apikey') + +export const resolveBulkEditScopeEditorKey = ( + platform?: AccountPlatform | '' | null, + type?: AccountType | '' | null +): 
BulkEditScopeEditorKey | null => { + if (!platform || !type) return null + const key = `${platform}:${type}` as BulkEditScopeEditorKey + return bulkEditScopeEditorKeySet.has(key) ? key : null +} + +export const resolveBulkEditScopeCapabilities = ( + platform?: AccountPlatform | '' | null, + type?: AccountType | '' | null +): BulkEditScopeCapabilities => { + return { + supportsBaseUrl: type === 'apikey' || type === 'upstream', + supportsModelRestriction: isAPIKeyScope(type) && platform !== 'antigravity', + supportsCustomErrorCodes: isAPIKeyScope(type), + supportsInterceptWarmup: isAnthropicOrAntigravityScope(platform), + supportsOpenAIPassthrough: isOpenAIOAuthOrAPIKeyScope(platform, type), + supportsOpenAIWSMode: isOpenAIOAuthOrAPIKeyScope(platform, type), + supportsCodexCLIOnly: isOpenAIOAuthScope(platform, type), + supportsAnthropicPassthrough: isAnthropicScope(platform) && isAPIKeyScope(type) + } +} diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicApiKeyModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicApiKeyModal.vue new file mode 100644 index 000000000..12aac17ef --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicApiKeyModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicOAuthModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicOAuthModal.vue new file mode 100644 index 000000000..acc11de47 --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicOAuthModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicSetupTokenModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicSetupTokenModal.vue new file mode 100644 index 000000000..688eaa9c0 --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditAnthropicSetupTokenModal.vue @@ -0,0 +1,32 @@ + + + diff --git 
a/frontend/src/components/account/bulkEditScoped/BulkEditAntigravityOAuthModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditAntigravityOAuthModal.vue new file mode 100644 index 000000000..844153110 --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditAntigravityOAuthModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditAntigravityUpstreamModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditAntigravityUpstreamModal.vue new file mode 100644 index 000000000..8d8d75c67 --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditAntigravityUpstreamModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditGeminiApiKeyModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditGeminiApiKeyModal.vue new file mode 100644 index 000000000..a6e265c3b --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditGeminiApiKeyModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditGeminiOAuthModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditGeminiOAuthModal.vue new file mode 100644 index 000000000..7f3320fcf --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditGeminiOAuthModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditOpenAIApiKeyModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditOpenAIApiKeyModal.vue new file mode 100644 index 000000000..7695c7f4a --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditOpenAIApiKeyModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditOpenAIOAuthModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditOpenAIOAuthModal.vue new file mode 100644 index 000000000..26ef70b54 --- /dev/null +++ 
b/frontend/src/components/account/bulkEditScoped/BulkEditOpenAIOAuthModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditSoraApiKeyModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditSoraApiKeyModal.vue new file mode 100644 index 000000000..c1f4eb2c1 --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditSoraApiKeyModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditScoped/BulkEditSoraOAuthModal.vue b/frontend/src/components/account/bulkEditScoped/BulkEditSoraOAuthModal.vue new file mode 100644 index 000000000..58b3b0468 --- /dev/null +++ b/frontend/src/components/account/bulkEditScoped/BulkEditSoraOAuthModal.vue @@ -0,0 +1,32 @@ + + + diff --git a/frontend/src/components/account/bulkEditTemplateRemoteMapper.ts b/frontend/src/components/account/bulkEditTemplateRemoteMapper.ts new file mode 100644 index 000000000..cbee45c0b --- /dev/null +++ b/frontend/src/components/account/bulkEditTemplateRemoteMapper.ts @@ -0,0 +1,55 @@ +import type { AccountPlatform, AccountType } from '@/types' +import type { + BulkEditTemplateRecord as BulkEditTemplateRemoteRecord, + BulkEditTemplateShareScope, + UpsertBulkEditTemplateRequest +} from '@/api/admin/bulkEditTemplates' +import { + normalizeBulkEditTemplateGroupIDs, + normalizeBulkEditTemplateShareScope, + type BulkEditTemplateRecord +} from './bulkEditTemplateStore' + +const normalizeTimestamp = (value: unknown): number => + typeof value === 'number' && Number.isFinite(value) && value > 0 + ? Math.floor(value) + : Date.now() + +const normalizeOwnerUserID = (value: unknown): number | null => + typeof value === 'number' && Number.isFinite(value) && value > 0 ? Math.floor(value) : null + +export const mapBulkEditTemplateFromRemote = >( + record: BulkEditTemplateRemoteRecord +): BulkEditTemplateRecord => ({ + id: typeof record.id === 'string' ? record.id : '', + name: typeof record.name === 'string' ? 
record.name : '', + scopePlatform: (record.scope_platform ?? '') as AccountPlatform | '', + scopeType: (record.scope_type ?? '') as AccountType | '', + shareScope: normalizeBulkEditTemplateShareScope(record.share_scope), + groupIds: normalizeBulkEditTemplateGroupIDs(record.group_ids), + state: (record.state ?? ({} as TState)) as TState, + updatedAt: normalizeTimestamp(record.updated_at), + ownerUserId: normalizeOwnerUserID(record.created_by) +}) + +export interface BulkEditTemplateUpsertModel> { + id?: string + name: string + scopePlatform: AccountPlatform | '' + scopeType: AccountType | '' + shareScope: BulkEditTemplateShareScope + groupIds: number[] + state: TState +} + +export const mapBulkEditTemplateToUpsertRequest = >( + model: BulkEditTemplateUpsertModel +): UpsertBulkEditTemplateRequest => ({ + ...(model.id ? { id: model.id } : {}), + name: model.name, + scope_platform: model.scopePlatform, + scope_type: model.scopeType, + share_scope: normalizeBulkEditTemplateShareScope(model.shareScope), + group_ids: normalizeBulkEditTemplateGroupIDs(model.groupIds), + state: model.state +}) diff --git a/frontend/src/components/account/bulkEditTemplateState.ts b/frontend/src/components/account/bulkEditTemplateState.ts new file mode 100644 index 000000000..f2c4e00dd --- /dev/null +++ b/frontend/src/components/account/bulkEditTemplateState.ts @@ -0,0 +1,196 @@ +import { + OPENAI_WS_MODE_OFF, + normalizeOpenAIWSMode, + type OpenAIWSMode +} from '@/utils/openaiWsMode' +import type { BulkEditModelMapping } from './bulkEditPayload' + +export interface BulkEditTemplateState { + enableBaseUrl: boolean + enableModelRestriction: boolean + enableCustomErrorCodes: boolean + enableInterceptWarmup: boolean + enableOpenAIPassthrough: boolean + enableOpenAIWSMode: boolean + enableCodexCLIOnly: boolean + enableAnthropicPassthrough: boolean + enableProxy: boolean + enableConcurrency: boolean + enablePriority: boolean + enableRateMultiplier: boolean + enableStatus: boolean + enableGroups: 
boolean + baseUrl: string + modelRestrictionMode: 'whitelist' | 'mapping' + allowedModels: string[] + modelMappings: BulkEditModelMapping[] + selectedErrorCodes: number[] + interceptWarmupRequests: boolean + openAIPassthroughEnabled: boolean + openAIWSMode: OpenAIWSMode + codexCLIOnlyEnabled: boolean + anthropicPassthroughEnabled: boolean + proxyId: number | null + concurrency: number + priority: number + rateMultiplier: number + status: 'active' | 'inactive' + groupIds: number[] +} + +const isRecord = (value: unknown): value is Record => + typeof value === 'object' && value !== null + +const toBoolean = (value: unknown, fallback: boolean): boolean => + typeof value === 'boolean' ? value : fallback + +const toStringValue = (value: unknown, fallback: string): string => + typeof value === 'string' ? value : fallback + +const toPositiveInteger = (value: unknown, fallback: number): number => { + if (typeof value !== 'number' || !Number.isFinite(value)) return fallback + return Math.max(1, Math.floor(value)) +} + +const toRateMultiplier = (value: unknown, fallback: number): number => { + if (typeof value !== 'number' || !Number.isFinite(value)) return fallback + return Math.max(0, value) +} + +const toNumberList = (value: unknown): number[] => { + if (!Array.isArray(value)) return [] + return value + .filter((item): item is number => typeof item === 'number' && Number.isFinite(item)) + .map((item) => Math.floor(item)) +} + +const toStringList = (value: unknown): string[] => { + if (!Array.isArray(value)) return [] + return value.filter((item): item is string => typeof item === 'string') +} + +const toModelMappings = (value: unknown): BulkEditModelMapping[] => { + if (!Array.isArray(value)) return [] + const next: BulkEditModelMapping[] = [] + for (const item of value) { + if (!isRecord(item)) continue + if (typeof item.from !== 'string' || typeof item.to !== 'string') continue + next.push({ from: item.from, to: item.to }) + } + return next +} + +const toStatus = (value: 
unknown): 'active' | 'inactive' => { + return value === 'inactive' ? 'inactive' : 'active' +} + +const toModelRestrictionMode = (value: unknown): 'whitelist' | 'mapping' => { + return value === 'mapping' ? 'mapping' : 'whitelist' +} + +const toProxyID = (value: unknown): number | null => { + if (value === null) return null + if (typeof value !== 'number' || !Number.isFinite(value)) return null + return value > 0 ? Math.floor(value) : null +} + +export const createDefaultBulkEditTemplateState = (): BulkEditTemplateState => ({ + enableBaseUrl: false, + enableModelRestriction: false, + enableCustomErrorCodes: false, + enableInterceptWarmup: false, + enableOpenAIPassthrough: false, + enableOpenAIWSMode: false, + enableCodexCLIOnly: false, + enableAnthropicPassthrough: false, + enableProxy: false, + enableConcurrency: false, + enablePriority: false, + enableRateMultiplier: false, + enableStatus: false, + enableGroups: false, + baseUrl: '', + modelRestrictionMode: 'whitelist', + allowedModels: [], + modelMappings: [], + selectedErrorCodes: [], + interceptWarmupRequests: false, + openAIPassthroughEnabled: false, + openAIWSMode: OPENAI_WS_MODE_OFF, + codexCLIOnlyEnabled: false, + anthropicPassthroughEnabled: false, + proxyId: null, + concurrency: 1, + priority: 1, + rateMultiplier: 1, + status: 'active', + groupIds: [] +}) + +export const normalizeBulkEditTemplateState = (value: unknown): BulkEditTemplateState => { + const defaults = createDefaultBulkEditTemplateState() + if (!isRecord(value)) return defaults + + const normalizedWSMode = normalizeOpenAIWSMode(value.openAIWSMode) + + return { + enableBaseUrl: toBoolean(value.enableBaseUrl, defaults.enableBaseUrl), + enableModelRestriction: toBoolean( + value.enableModelRestriction, + defaults.enableModelRestriction + ), + enableCustomErrorCodes: toBoolean( + value.enableCustomErrorCodes, + defaults.enableCustomErrorCodes + ), + enableInterceptWarmup: toBoolean( + value.enableInterceptWarmup, + defaults.enableInterceptWarmup 
+ ), + enableOpenAIPassthrough: toBoolean( + value.enableOpenAIPassthrough, + defaults.enableOpenAIPassthrough + ), + enableOpenAIWSMode: toBoolean(value.enableOpenAIWSMode, defaults.enableOpenAIWSMode), + enableCodexCLIOnly: toBoolean(value.enableCodexCLIOnly, defaults.enableCodexCLIOnly), + enableAnthropicPassthrough: toBoolean( + value.enableAnthropicPassthrough, + defaults.enableAnthropicPassthrough + ), + enableProxy: toBoolean(value.enableProxy, defaults.enableProxy), + enableConcurrency: toBoolean(value.enableConcurrency, defaults.enableConcurrency), + enablePriority: toBoolean(value.enablePriority, defaults.enablePriority), + enableRateMultiplier: toBoolean(value.enableRateMultiplier, defaults.enableRateMultiplier), + enableStatus: toBoolean(value.enableStatus, defaults.enableStatus), + enableGroups: toBoolean(value.enableGroups, defaults.enableGroups), + baseUrl: toStringValue(value.baseUrl, defaults.baseUrl), + modelRestrictionMode: toModelRestrictionMode(value.modelRestrictionMode), + allowedModels: toStringList(value.allowedModels), + modelMappings: toModelMappings(value.modelMappings), + selectedErrorCodes: toNumberList(value.selectedErrorCodes), + interceptWarmupRequests: toBoolean( + value.interceptWarmupRequests, + defaults.interceptWarmupRequests + ), + openAIPassthroughEnabled: toBoolean( + value.openAIPassthroughEnabled, + defaults.openAIPassthroughEnabled + ), + openAIWSMode: normalizedWSMode ?? 
defaults.openAIWSMode, + codexCLIOnlyEnabled: toBoolean(value.codexCLIOnlyEnabled, defaults.codexCLIOnlyEnabled), + anthropicPassthroughEnabled: toBoolean( + value.anthropicPassthroughEnabled, + defaults.anthropicPassthroughEnabled + ), + proxyId: toProxyID(value.proxyId), + concurrency: toPositiveInteger(value.concurrency, defaults.concurrency), + priority: toPositiveInteger(value.priority, defaults.priority), + rateMultiplier: toRateMultiplier(value.rateMultiplier, defaults.rateMultiplier), + status: toStatus(value.status), + groupIds: toNumberList(value.groupIds) + } +} + +export const createBulkEditTemplateStateSnapshot = ( + state: BulkEditTemplateState +): BulkEditTemplateState => normalizeBulkEditTemplateState(state) diff --git a/frontend/src/components/account/bulkEditTemplateStore.ts b/frontend/src/components/account/bulkEditTemplateStore.ts new file mode 100644 index 000000000..9d6405150 --- /dev/null +++ b/frontend/src/components/account/bulkEditTemplateStore.ts @@ -0,0 +1,117 @@ +import type { AccountPlatform, AccountType } from '@/types' + +export const BULK_EDIT_TEMPLATES_STORAGE_KEY = 'admin.bulk_edit_templates.v1' +export type BulkEditTemplateShareScope = 'private' | 'team' | 'groups' + +export const normalizeBulkEditTemplateShareScope = ( + value: unknown +): BulkEditTemplateShareScope => { + if (value === 'team') return 'team' + if (value === 'groups') return 'groups' + return 'private' +} + +export const normalizeBulkEditTemplateGroupIDs = (value: unknown): number[] => { + if (!Array.isArray(value)) return [] + const seen = new Set() + const next: number[] = [] + for (const item of value) { + if (typeof item !== 'number' || !Number.isFinite(item) || item <= 0) continue + const normalized = Math.floor(item) + if (seen.has(normalized)) continue + seen.add(normalized) + next.push(normalized) + } + return next.sort((a, b) => a - b) +} + +export interface BulkEditTemplateRecord> { + id: string + name: string + scopePlatform: AccountPlatform | '' + 
scopeType: AccountType | '' + shareScope: BulkEditTemplateShareScope + groupIds: number[] + state: TState + updatedAt: number + ownerUserId?: number | null +} + +export const parseBulkEditTemplateRecords = >( + raw: string | null | undefined +): BulkEditTemplateRecord[] => { + if (!raw) return [] + + try { + const parsed = JSON.parse(raw) + if (!Array.isArray(parsed)) return [] + const records: BulkEditTemplateRecord[] = [] + for (const item of parsed) { + if (!item || typeof item !== 'object') continue + if (typeof item.id !== 'string' || typeof item.name !== 'string') continue + if (typeof item.scopePlatform !== 'string' || typeof item.scopeType !== 'string') continue + if (typeof item.updatedAt !== 'number' || !Number.isFinite(item.updatedAt)) continue + + const ownerUserId = + typeof item.ownerUserId === 'number' && Number.isFinite(item.ownerUserId) + ? Math.floor(item.ownerUserId) + : null + + records.push({ + id: item.id, + name: item.name, + scopePlatform: item.scopePlatform, + scopeType: item.scopeType, + shareScope: normalizeBulkEditTemplateShareScope(item.shareScope), + groupIds: normalizeBulkEditTemplateGroupIDs(item.groupIds), + state: (item.state as TState) ?? 
({} as TState), + updatedAt: item.updatedAt, + ownerUserId + }) + } + return records + } catch { + return [] + } +} + +export const serializeBulkEditTemplateRecords = >( + templates: BulkEditTemplateRecord[] +): string => JSON.stringify(templates) + +export const upsertBulkEditTemplateRecord = >( + templates: BulkEditTemplateRecord[], + template: BulkEditTemplateRecord +): BulkEditTemplateRecord[] => { + const next = [...templates] + const existingIdx = next.findIndex( + (item) => + item.scopePlatform === template.scopePlatform && + item.scopeType === template.scopeType && + item.name.trim().toLowerCase() === template.name.trim().toLowerCase() + ) + if (existingIdx >= 0) { + next[existingIdx] = template + } else { + next.push(template) + } + return next +} + +export const removeBulkEditTemplateRecord = >( + templates: BulkEditTemplateRecord[], + templateID: string +): BulkEditTemplateRecord[] => templates.filter((item) => item.id !== templateID) + +export const filterBulkEditTemplateRecordsByScope = >( + templates: BulkEditTemplateRecord[], + scopePlatform?: AccountPlatform | '' | null, + scopeType?: AccountType | '' | null +): BulkEditTemplateRecord[] => { + if (!scopePlatform || !scopeType) return [] + return templates + .filter( + (item) => item.scopePlatform === scopePlatform && item.scopeType === scopeType + ) + .sort((a, b) => b.updatedAt - a.updatedAt) +} diff --git a/frontend/src/components/account/index.ts b/frontend/src/components/account/index.ts index 0010e62c5..ca44d7c0b 100644 --- a/frontend/src/components/account/index.ts +++ b/frontend/src/components/account/index.ts @@ -1,6 +1,7 @@ export { default as CreateAccountModal } from './CreateAccountModal.vue' export { default as EditAccountModal } from './EditAccountModal.vue' export { default as BulkEditAccountModal } from './BulkEditAccountModal.vue' +export { default as BulkEditAccountScopedModal } from './BulkEditAccountScopedModal.vue' export { default as ReAuthAccountModal } from 
'./ReAuthAccountModal.vue' export { default as OAuthAuthorizationFlow } from './OAuthAuthorizationFlow.vue' export { default as AccountStatusIndicator } from './AccountStatusIndicator.vue' diff --git a/frontend/src/components/admin/user/UserApiKeysModal.vue b/frontend/src/components/admin/user/UserApiKeysModal.vue index 7e3c8c258..c2159ff4e 100644 --- a/frontend/src/components/admin/user/UserApiKeysModal.vue +++ b/frontend/src/components/admin/user/UserApiKeysModal.vue @@ -1,5 +1,5 @@ diff --git a/frontend/src/components/keys/UseKeyModal.vue b/frontend/src/components/keys/UseKeyModal.vue index 4dd7ff0cb..4f6064102 100644 --- a/frontend/src/components/keys/UseKeyModal.vue +++ b/frontend/src/components/keys/UseKeyModal.vue @@ -268,7 +268,6 @@ const clientTabs = computed((): TabConfig[] => { case 'openai': return [ { id: 'codex', label: t('keys.useKeyModal.cliTabs.codexCli'), icon: TerminalIcon }, - { id: 'codex-ws', label: t('keys.useKeyModal.cliTabs.codexCliWs'), icon: TerminalIcon }, { id: 'opencode', label: t('keys.useKeyModal.cliTabs.opencode'), icon: TerminalIcon } ] case 'gemini': @@ -307,7 +306,7 @@ const showShellTabs = computed(() => activeClientTab.value !== 'opencode') const currentTabs = computed(() => { if (!showShellTabs.value) return [] - if (activeClientTab.value === 'codex' || activeClientTab.value === 'codex-ws') { + if (props.platform === 'openai') { return openaiTabs } return shellTabs @@ -402,9 +401,6 @@ const currentFiles = computed((): FileConfig[] => { switch (props.platform) { case 'openai': - if (activeClientTab.value === 'codex-ws') { - return generateOpenAIWsFiles(baseUrl, apiKey) - } return generateOpenAIFiles(baseUrl, apiKey) case 'gemini': return [generateGeminiCliContent(baseUrl, apiKey)] @@ -528,47 +524,6 @@ requires_openai_auth = true` ] } -function generateOpenAIWsFiles(baseUrl: string, apiKey: string): FileConfig[] { - const isWindows = activeTab.value === 'windows' - const configDir = isWindows ? 
'%userprofile%\\.codex' : '~/.codex' - - // config.toml content with WebSocket v2 - const configContent = `model_provider = "sub2api" -model = "gpt-5.3-codex" -model_reasoning_effort = "high" -network_access = "enabled" -disable_response_storage = true -windows_wsl_setup_acknowledged = true -model_verbosity = "high" - -[model_providers.sub2api] -name = "sub2api" -base_url = "${baseUrl}" -wire_api = "responses" -supports_websockets = true -requires_openai_auth = true - -[features] -responses_websockets_v2 = true` - - // auth.json content - const authContent = `{ - "OPENAI_API_KEY": "${apiKey}" -}` - - return [ - { - path: `${configDir}/config.toml`, - content: configContent, - hint: t('keys.useKeyModal.openai.configTomlHint') - }, - { - path: `${configDir}/auth.json`, - content: authContent - } - ] -} - function generateOpenCodeConfig(platform: string, baseUrl: string, apiKey: string, pathLabel?: string): FileConfig { const provider: Record = { [platform]: { diff --git a/frontend/src/components/layout/AppSidebar.vue b/frontend/src/components/layout/AppSidebar.vue index dcfc60bbb..afb7e2495 100644 --- a/frontend/src/components/layout/AppSidebar.vue +++ b/frontend/src/components/layout/AppSidebar.vue @@ -302,26 +302,6 @@ const CreditCardIcon = { ) } -const RechargeSubscriptionIcon = { - render: () => - h( - 'svg', - { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, - [ - h('path', { - 'stroke-linecap': 'round', - 'stroke-linejoin': 'round', - d: 'M2.25 7.5A2.25 2.25 0 014.5 5.25h15A2.25 2.25 0 0121.75 7.5v9A2.25 2.25 0 0119.5 18.75h-15A2.25 2.25 0 012.25 16.5v-9z' - }), - h('path', { - 'stroke-linecap': 'round', - 'stroke-linejoin': 'round', - d: 'M6.75 12h3m4.5 0h3m-3-3v6' - }) - ] - ) -} - const GlobeIcon = { render: () => h( @@ -352,36 +332,6 @@ const ServerIcon = { ) } -const DatabaseIcon = { - render: () => - h( - 'svg', - { fill: 'none', viewBox: '0 0 24 24', stroke: 'currentColor', 'stroke-width': '1.5' }, - [ - h('path', { 
- 'stroke-linecap': 'round', - 'stroke-linejoin': 'round', - d: 'M3.75 5.25C3.75 4.007 7.443 3 12 3s8.25 1.007 8.25 2.25S16.557 7.5 12 7.5 3.75 6.493 3.75 5.25z' - }), - h('path', { - 'stroke-linecap': 'round', - 'stroke-linejoin': 'round', - d: 'M3.75 5.25v4.5C3.75 10.993 7.443 12 12 12s8.25-1.007 8.25-2.25v-4.5' - }), - h('path', { - 'stroke-linecap': 'round', - 'stroke-linejoin': 'round', - d: 'M3.75 9.75v4.5c0 1.243 3.693 2.25 8.25 2.25s8.25-1.007 8.25-2.25v-4.5' - }), - h('path', { - 'stroke-linecap': 'round', - 'stroke-linejoin': 'round', - d: 'M3.75 14.25v4.5C3.75 19.993 7.443 21 12 21s8.25-1.007 8.25-2.25v-4.5' - }) - ] - ) -} - const BellIcon = { render: () => h( @@ -522,7 +472,7 @@ const userNavItems = computed((): NavItem[] => { { path: '/purchase', label: t('nav.buySubscription'), - icon: RechargeSubscriptionIcon, + icon: CreditCardIcon, hideInSimpleMode: true } ] @@ -553,7 +503,7 @@ const personalNavItems = computed((): NavItem[] => { { path: '/purchase', label: t('nav.buySubscription'), - icon: RechargeSubscriptionIcon, + icon: CreditCardIcon, hideInSimpleMode: true } ] @@ -607,7 +557,6 @@ const adminNavItems = computed((): NavItem[] => { if (authStore.isSimpleMode) { const filtered = baseItems.filter(item => !item.hideInSimpleMode) filtered.push({ path: '/keys', label: t('nav.apiKeys'), icon: KeyIcon }) - filtered.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) filtered.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) // Add admin custom menu items after settings for (const cm of customMenuItemsForAdmin.value) { @@ -616,7 +565,6 @@ const adminNavItems = computed((): NavItem[] => { return filtered } - baseItems.push({ path: '/admin/data-management', label: t('nav.dataManagement'), icon: DatabaseIcon }) baseItems.push({ path: '/admin/settings', label: t('nav.settings'), icon: CogIcon }) // Add admin custom menu items after settings for (const cm of customMenuItemsForAdmin.value) { 
diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 41edeb6a0..6e387b29d 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -280,7 +280,7 @@ export default { logout: 'Logout', github: 'GitHub', mySubscriptions: 'My Subscriptions', - buySubscription: 'Recharge / Subscription', + buySubscription: 'Purchase Subscription', docs: 'Docs', sora: 'Sora Studio' }, @@ -408,12 +408,9 @@ export default { day: 'Day', hour: 'Hour', modelDistribution: 'Model Distribution', - groupDistribution: 'Group Usage Distribution', tokenUsageTrend: 'Token Usage Trend', noDataAvailable: 'No data available', model: 'Model', - group: 'Group', - noGroup: 'No Group', requests: 'Requests', tokens: 'Tokens', actual: 'Actual', @@ -504,7 +501,6 @@ export default { claudeCode: 'Claude Code', geminiCli: 'Gemini CLI', codexCli: 'Codex CLI', - codexCliWs: 'Codex CLI (WebSocket)', opencode: 'OpenCode', }, antigravity: { @@ -560,19 +556,6 @@ export default { resetQuotaConfirmMessage: 'Are you sure you want to reset the used quota (${used}) for key "{name}" to 0? This action cannot be undone.', quotaResetSuccess: 'Quota reset successfully', failedToResetQuota: 'Failed to reset quota', - rateLimitColumn: 'Rate Limit', - rateLimitSection: 'Rate Limit', - resetUsage: 'Reset', - rateLimit5h: '5-Hour Limit (USD)', - rateLimit1d: 'Daily Limit (USD)', - rateLimit7d: '7-Day Limit (USD)', - rateLimitHint: 'Set the maximum spending for this key within each time window. 0 = unlimited.', - rateLimitUsage: 'Rate Limit Usage', - resetRateLimitUsage: 'Reset Rate Limit Usage', - resetRateLimitTitle: 'Confirm Reset Rate Limit', - resetRateLimitConfirmMessage: 'Are you sure you want to reset the rate limit usage for key "{name}"? All time window usage will be reset to zero. 
This action cannot be undone.', - rateLimitResetSuccess: 'Rate limit usage reset successfully', - failedToResetRateLimit: 'Failed to reset rate limit usage', expiration: 'Expiration', expiresInDays: '{days} days', extendDays: '+{days} days', @@ -848,12 +831,9 @@ export default { day: 'Day', hour: 'Hour', modelDistribution: 'Model Distribution', - groupDistribution: 'Group Usage Distribution', tokenUsageTrend: 'Token Usage Trend', userUsageTrend: 'User Usage Trend (Top 12)', model: 'Model', - group: 'Group', - noGroup: 'No Group', requests: 'Requests', tokens: 'Tokens', actual: 'Actual', @@ -1096,9 +1076,6 @@ export default { noApiKeys: 'This user has no API keys', group: 'Group', none: 'None', - groupChangedSuccess: 'Group updated successfully', - groupChangedWithGrant: 'Group updated. User auto-granted access to "{group}"', - groupChangeFailed: 'Failed to update group', noUsersYet: 'No users yet', createFirstUser: 'Create your first user to get started.', userCreated: 'User created successfully', @@ -1636,19 +1613,7 @@ export default { sessions: { full: 'Active sessions full, new sessions must wait (idle timeout: {idle} min)', normal: 'Active sessions normal (idle timeout: {idle} min)' - }, - rpm: { - full: 'RPM limit reached', - warning: 'RPM approaching limit', - normal: 'RPM normal', - tieredNormal: 'RPM limit (Tiered) - Normal', - tieredWarning: 'RPM limit (Tiered) - Approaching limit', - tieredStickyOnly: 'RPM limit (Tiered) - Sticky only | Buffer: {buffer}', - tieredBlocked: 'RPM limit (Tiered) - Blocked | Buffer: {buffer}', - stickyExemptNormal: 'RPM limit (Sticky Exempt) - Normal', - stickyExemptWarning: 'RPM limit (Sticky Exempt) - Approaching limit', - stickyExemptOver: 'RPM limit (Sticky Exempt) - Over limit, sticky only' - }, + } }, tempUnschedulable: { title: 'Temp Unschedulable', @@ -1719,6 +1684,20 @@ export default { title: 'Bulk Edit Accounts', selectionInfo: '{count} account(s) selected. 
Only checked or filled fields will be updated; others stay unchanged.', + scopeTitle: 'Bulk Edit Scope', + scopeInfo: '{count} account(s) selected. Choose platform and account type first.', + onlySameTypeHint: + 'Only accounts in the same platform + type scope will be edited; other selected accounts stay unchanged.', + choosePlatform: 'Please choose a platform', + chooseType: 'Please choose an account type', + scopeMatched: 'Matched accounts: {count}', + scopeSummaryTitle: 'Selected Scope Summary', + scopeTargetPreview: 'This update will apply to {matched} of {selected} selected account(s).', + scopeExcludedHint: '{count} selected account(s) are out of current scope and will not be edited.', + unsupportedScope: 'Unsupported scope', + openScopedEditor: 'Open Bulk Editor', + noScopedMatch: 'No editable accounts match the selected scope', + loadSelectionFailed: 'Failed to load selected accounts ({count}). Please retry.', baseUrlPlaceholder: 'https://api.anthropic.com or https://api.openai.com', baseUrlNotice: 'Applies to API Key accounts only; leave empty to keep existing value', submit: 'Update Accounts', @@ -1728,6 +1707,43 @@ export default { failed: 'Bulk update failed', noSelection: 'Please select accounts to edit', noFieldsSelected: 'Select at least one field to update', + templateTitle: 'Edit Templates', + templateScopeHint: 'Templates are scoped to {platform} / {type}.', + templateLoading: 'Loading templates...', + templateSelectLabel: 'Saved templates', + templateEmpty: 'No templates yet', + templateApply: 'Apply', + templateDelete: 'Delete', + templateNameLabel: 'Template name', + templateNamePlaceholder: 'e.g. 
OpenAI OAuth baseline', + templateShareScopeLabel: 'Share scope', + templateShareScopePrivate: 'Only me', + templateShareScopeTeam: 'All admins', + templateShareScopeGroups: 'Scoped groups', + templateShareGroupsLabel: 'Visible groups', + templateShareGroupsHint: 'Only appears when current bulk-edit scope intersects selected groups.', + templateShareGroupsRequired: 'Choose at least one group for group-shared template', + templateSave: 'Save as template', + templateNameRequired: 'Enter a template name first', + templateSaved: 'Template "{name}" saved', + templateApplied: 'Template "{name}" applied', + templateDeleted: 'Template "{name}" deleted', + templateDeleteConfirm: 'Delete template "{name}"? This action cannot be undone.', + templateStorageFailed: 'Failed to save template to local storage', + templateLoadFailed: 'Failed to load templates', + templateSaveFailed: 'Failed to save template', + templateDeleteFailed: 'Failed to delete template', + templateVersionTitle: 'Version history', + templateVersionHint: 'Each update snapshots previous template state.', + templateVersionLoading: 'Loading version history...', + templateVersionEmpty: 'No historical versions yet', + templateVersionUnknownTime: 'Unknown time', + templateVersionLoadFailed: 'Failed to load template version history', + templateRollback: 'Rollback', + templateRollbacking: 'Rolling back...', + templateRollbackConfirm: 'Rollback template "{name}" to {updatedAt}?', + templateRollbackSuccess: 'Template "{name}" rollback succeeded', + templateRollbackFailed: 'Failed to rollback template', mixedPlatformWarning: 'Selected accounts span multiple platforms ({platforms}). Model mapping presets shown are combined — ensure mappings are appropriate for each platform.' 
}, bulkDeleteTitle: 'Bulk Delete Accounts', @@ -1771,10 +1787,11 @@ export default { wsMode: 'WS mode', wsModeDesc: 'Only applies to the current OpenAI account type.', wsModeOff: 'Off (off)', - wsModeShared: 'Shared (shared)', - wsModeDedicated: 'Dedicated (dedicated)', + wsModeCtxPool: 'Context Pool (ctx_pool)', + wsModePassthrough: 'Passthrough (passthrough)', wsModeConcurrencyHint: 'When WS mode is enabled, account concurrency becomes the WS connection pool limit for this account.', + wsModePassthroughHint: 'Passthrough mode does not use the WS connection pool.', oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode', oauthResponsesWebsocketsV2Desc: 'Only applies to OpenAI OAuth. This account can use OpenAI WebSocket Mode only when enabled.', @@ -1863,27 +1880,6 @@ export default { idleTimeoutPlaceholder: '5', idleTimeoutHint: 'Sessions will be released after idle timeout' }, - rpmLimit: { - label: 'RPM Limit', - hint: 'Limit requests per minute to protect upstream accounts', - baseRpm: 'Base RPM', - baseRpmPlaceholder: '15', - baseRpmHint: 'Max requests per minute, 0 or empty means no limit', - strategy: 'RPM Strategy', - strategyTiered: 'Tiered Model', - strategyStickyExempt: 'Sticky Exempt', - strategyTieredHint: 'Green → Yellow → Sticky only → Blocked, progressive throttling', - strategyStickyExemptHint: 'Only sticky sessions allowed when over limit', - strategyHint: 'Tiered: gradually restrict when exceeded; Sticky Exempt: existing sessions unrestricted', - stickyBuffer: 'Sticky Buffer', - stickyBufferPlaceholder: 'Default: 20% of base RPM', - stickyBufferHint: 'Extra requests allowed for sticky sessions after exceeding base RPM. 
Leave empty to use default (20% of base RPM, min 1)', - userMsgQueue: 'User Message Rate Control', - userMsgQueueHint: 'Rate-limit user messages to avoid triggering upstream RPM limits', - umqModeOff: 'Off', - umqModeThrottle: 'Throttle', - umqModeSerialize: 'Serialize', - }, tlsFingerprint: { label: 'TLS Fingerprint Simulation', hint: 'Simulate Node.js/Claude Code client TLS fingerprint' @@ -2363,8 +2359,6 @@ export default { dataExportConfirm: 'Confirm Export', dataExported: 'Data exported successfully', dataExportFailed: 'Failed to export data', - copyProxyUrl: 'Copy Proxy URL', - urlCopied: 'Proxy URL copied', searchProxies: 'Search proxies...', allProtocols: 'All Protocols', allStatus: 'All Status', @@ -2378,7 +2372,6 @@ export default { name: 'Name', protocol: 'Protocol', address: 'Address', - auth: 'Auth', location: 'Location', status: 'Status', accounts: 'Accounts', @@ -3573,23 +3566,7 @@ export default { defaultBalance: 'Default Balance', defaultBalanceHint: 'Initial balance for new users', defaultConcurrency: 'Default Concurrency', - defaultConcurrencyHint: 'Maximum concurrent requests for new users', - defaultSubscriptions: 'Default Subscriptions', - defaultSubscriptionsHint: 'Auto-assign these subscriptions when a new user is created or registered', - addDefaultSubscription: 'Add Default Subscription', - defaultSubscriptionsEmpty: 'No default subscriptions configured.', - defaultSubscriptionsDuplicate: - 'Duplicate subscription group: {groupId}. Each group can only appear once.', - subscriptionGroup: 'Subscription Group', - subscriptionValidityDays: 'Validity (days)' - }, - claudeCode: { - title: 'Claude Code Settings', - description: 'Control Claude Code client access requirements', - minVersion: 'Minimum Version', - minVersionPlaceholder: 'e.g. 2.1.63', - minVersionHint: - 'Reject Claude Code clients below this version (semver format). Leave empty to disable version check.' 
+ defaultConcurrencyHint: 'Maximum concurrent requests for new users' }, site: { title: 'Site Settings', @@ -3625,17 +3602,15 @@ export default { hideCcsImportButtonHint: 'When enabled, the "Import to CCS" button will be hidden on the API Keys page' }, purchase: { - title: 'Recharge / Subscription Page', - description: 'Show a "Recharge / Subscription" entry in the sidebar and open the configured URL in an iframe', - enabled: 'Show Recharge / Subscription Entry', + title: 'Purchase Page', + description: 'Show a "Purchase Subscription" entry in the sidebar and open the configured URL in an iframe', + enabled: 'Show Purchase Entry', enabledHint: 'Only shown in standard mode (not simple mode)', - url: 'Recharge / Subscription URL', + url: 'Purchase URL', urlPlaceholder: 'https://example.com/purchase', urlHint: 'Must be an absolute http(s) URL', iframeWarning: - '⚠️ iframe note: Some websites block embedding via X-Frame-Options or CSP (frame-ancestors). If the page is blank, provide an "Open in new tab" alternative.', - integrationDoc: 'Payment Integration Docs', - integrationDocHint: 'Covers endpoint specs, idempotency semantics, and code samples' + '⚠️ iframe note: Some websites block embedding via X-Frame-Options or CSP (frame-ancestors). If the page is blank, provide an "Open in new tab" alternative.' }, soraClient: { title: 'Sora Client', @@ -3643,27 +3618,6 @@ export default { enabled: 'Enable Sora Client', enabledHint: 'When enabled, the Sora entry will be shown in the sidebar for users to access Sora features' }, - customMenu: { - title: 'Custom Menu Pages', - description: 'Add custom iframe pages to the sidebar navigation. Each page can be visible to regular users or administrators.', - itemLabel: 'Menu Item #{n}', - name: 'Menu Name', - namePlaceholder: 'e.g. 
Help Center', - url: 'Page URL', - urlPlaceholder: 'https://example.com/page', - iconSvg: 'SVG Icon', - iconSvgPlaceholder: '...', - iconPreview: 'Icon Preview', - uploadSvg: 'Upload SVG', - removeSvg: 'Remove', - visibility: 'Visible To', - visibilityUser: 'Regular Users', - visibilityAdmin: 'Administrators', - add: 'Add Menu Item', - remove: 'Remove', - moveUp: 'Move Up', - moveDown: 'Move Down', - }, smtp: { title: 'SMTP Settings', description: 'Configure email sending for verification codes', @@ -3940,26 +3894,16 @@ export default { retry: 'Retry' }, - // Recharge / Subscription Page + // Purchase Subscription Page purchase: { - title: 'Recharge / Subscription', - description: 'Recharge balance or purchase subscription via the embedded page', + title: 'Purchase Subscription', + description: 'Purchase a subscription via the embedded page', openInNewTab: 'Open in new tab', notEnabledTitle: 'Feature not enabled', - notEnabledDesc: 'The administrator has not enabled the recharge/subscription entry. Please contact admin.', - notConfiguredTitle: 'Recharge / Subscription URL not configured', + notEnabledDesc: 'The administrator has not enabled the purchase page. Please contact admin.', + notConfiguredTitle: 'Purchase URL not configured', notConfiguredDesc: - 'The administrator enabled the entry but has not configured a recharge/subscription URL. Please contact admin.' - }, - - // Custom Page (iframe embed) - customPage: { - title: 'Custom Page', - openInNewTab: 'Open in new tab', - notFoundTitle: 'Page not found', - notFoundDesc: 'This custom page does not exist or has been removed.', - notConfiguredTitle: 'Page URL not configured', - notConfiguredDesc: 'The URL for this custom page has not been properly configured.', + 'The administrator enabled the entry but has not configured a purchase URL. Please contact admin.' 
}, // Announcements Page diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 397ecbb2f..5448a6c61 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -280,7 +280,7 @@ export default { logout: '退出登录', github: 'GitHub', mySubscriptions: '我的订阅', - buySubscription: '充值/订阅', + buySubscription: '购买订阅', docs: '文档', sora: 'Sora 创作' }, @@ -409,12 +409,9 @@ export default { day: '按天', hour: '按小时', modelDistribution: '模型分布', - groupDistribution: '分组使用分布', tokenUsageTrend: 'Token 使用趋势', noDataAvailable: '暂无数据', model: '模型', - group: '分组', - noGroup: '无分组', requests: '请求', tokens: 'Token', actual: '实际', @@ -506,7 +503,6 @@ export default { claudeCode: 'Claude Code', geminiCli: 'Gemini CLI', codexCli: 'Codex CLI', - codexCliWs: 'Codex CLI (WebSocket)', opencode: 'OpenCode' }, antigravity: { @@ -566,19 +562,6 @@ export default { resetQuotaConfirmMessage: '确定要将密钥 "{name}" 的已用额度(${used})重置为 0 吗?此操作不可撤销。', quotaResetSuccess: '额度重置成功', failedToResetQuota: '重置额度失败', - rateLimitColumn: '速率限制', - rateLimitSection: '速率限制', - resetUsage: '重置', - rateLimit5h: '5小时限额 (USD)', - rateLimit1d: '日限额 (USD)', - rateLimit7d: '7天限额 (USD)', - rateLimitHint: '设置此密钥在指定时间窗口内的最大消费额。0 = 无限制。', - rateLimitUsage: '速率限制用量', - resetRateLimitUsage: '重置速率限制用量', - resetRateLimitTitle: '确认重置速率限制', - resetRateLimitConfirmMessage: '确定要重置密钥 "{name}" 的速率限制用量吗?所有时间窗口的已用额度将归零。此操作不可撤销。', - rateLimitResetSuccess: '速率限制已重置', - failedToResetRateLimit: '重置速率限制失败', expiration: '密钥有效期', expiresInDays: '{days} 天', extendDays: '+{days} 天', @@ -862,12 +845,9 @@ export default { day: '按天', hour: '按小时', modelDistribution: '模型分布', - groupDistribution: '分组使用分布', tokenUsageTrend: 'Token 使用趋势', noDataAvailable: '暂无数据', model: '模型', - group: '分组', - noGroup: '无分组', requests: '请求', tokens: 'Token', cache: '缓存', @@ -1124,9 +1104,6 @@ export default { noApiKeys: '此用户暂无 API 密钥', group: '分组', none: '无', - groupChangedSuccess: '分组修改成功', - groupChangedWithGrant: 
'分组修改成功,已自动为用户添加「{group}」分组权限', - groupChangeFailed: '分组修改失败', noUsersYet: '暂无用户', createFirstUser: '创建您的第一个用户以开始使用系统', userCreated: '用户创建成功', @@ -1687,19 +1664,7 @@ export default { sessions: { full: '活跃会话已满,新会话需等待(空闲超时:{idle}分钟)', normal: '活跃会话正常(空闲超时:{idle}分钟)' - }, - rpm: { - full: '已达 RPM 上限', - warning: 'RPM 接近上限', - normal: 'RPM 正常', - tieredNormal: 'RPM 限制 (三区模型) - 正常', - tieredWarning: 'RPM 限制 (三区模型) - 接近阈值', - tieredStickyOnly: 'RPM 限制 (三区模型) - 仅粘性会话 | 缓冲区: {buffer}', - tieredBlocked: 'RPM 限制 (三区模型) - 已阻塞 | 缓冲区: {buffer}', - stickyExemptNormal: 'RPM 限制 (粘性豁免) - 正常', - stickyExemptWarning: 'RPM 限制 (粘性豁免) - 接近阈值', - stickyExemptOver: 'RPM 限制 (粘性豁免) - 超限,仅粘性会话' - }, + } }, clearRateLimit: '清除速率限制', testConnection: '测试连接', @@ -1866,6 +1831,19 @@ export default { bulkEdit: { title: '批量编辑账号', selectionInfo: '已选择 {count} 个账号。只更新您勾选或填写的字段,未勾选的字段保持不变。', + scopeTitle: '批量编辑范围', + scopeInfo: '已选择 {count} 个账号,请先选择平台和账号类型。', + onlySameTypeHint: '仅会批量编辑“同平台 + 同账号类型”的账号,其他已选账号不会被修改。', + choosePlatform: '请选择平台', + chooseType: '请选择账号类型', + scopeMatched: '符合条件账号:{count} 个', + scopeSummaryTitle: '选中范围统计', + scopeTargetPreview: '本次将编辑 {matched}/{selected} 个已选账号。', + scopeExcludedHint: '有 {count} 个已选账号不在当前范围内,不会被编辑。', + unsupportedScope: '当前范围暂不支持批量编辑', + openScopedEditor: '打开批量编辑器', + noScopedMatch: '当前筛选下没有可编辑账号,请重新选择', + loadSelectionFailed: '读取选中账号失败({count} 个),请重试', baseUrlPlaceholder: 'https://api.anthropic.com 或 https://api.openai.com', baseUrlNotice: '仅适用于 API Key 账号,留空则不修改', submit: '批量更新', @@ -1875,6 +1853,43 @@ export default { failed: '批量更新失败', noSelection: '请选择要编辑的账号', noFieldsSelected: '请至少选择一个要更新的字段', + templateTitle: '批量模板', + templateScopeHint: '模板仅作用于当前范围:{platform} / {type}', + templateLoading: '正在加载模板...', + templateSelectLabel: '已保存模板', + templateEmpty: '暂无模板', + templateApply: '应用', + templateDelete: '删除', + templateNameLabel: '模板名称', + templateNamePlaceholder: '例如:OpenAI OAuth 默认配置', + templateShareScopeLabel: '共享范围', + templateShareScopePrivate: 
'仅自己可见', + templateShareScopeTeam: '团队管理员可见', + templateShareScopeGroups: '按分组可见', + templateShareGroupsLabel: '可见分组', + templateShareGroupsHint: '仅当当前批量编辑命中分组与所选分组有交集时可见。', + templateShareGroupsRequired: '分组共享模板至少需要选择一个分组', + templateSave: '保存模板', + templateNameRequired: '请先输入模板名称', + templateSaved: '模板“{name}”已保存', + templateApplied: '已应用模板“{name}”', + templateDeleted: '模板“{name}”已删除', + templateDeleteConfirm: '确定删除模板“{name}”?该操作不可撤销。', + templateStorageFailed: '保存到本地模板失败', + templateLoadFailed: '加载模板失败', + templateSaveFailed: '保存模板失败', + templateDeleteFailed: '删除模板失败', + templateVersionTitle: '版本历史', + templateVersionHint: '每次更新都会自动快照当前模板配置。', + templateVersionLoading: '正在加载版本历史...', + templateVersionEmpty: '暂无历史版本', + templateVersionUnknownTime: '未知时间', + templateVersionLoadFailed: '加载模板版本历史失败', + templateRollback: '回滚', + templateRollbacking: '回滚中...', + templateRollbackConfirm: '确定将模板“{name}”回滚到 {updatedAt} 吗?', + templateRollbackSuccess: '模板“{name}”回滚成功', + templateRollbackFailed: '模板回滚失败', mixedPlatformWarning: '所选账号跨越多个平台({platforms})。显示的模型映射预设为合并结果——请确保映射对每个平台都适用。' }, bulkDeleteTitle: '批量删除账号', @@ -1920,9 +1935,10 @@ export default { wsMode: 'WS mode', wsModeDesc: '仅对当前 OpenAI 账号类型生效。', wsModeOff: '关闭(off)', - wsModeShared: '共享(shared)', - wsModeDedicated: '独享(dedicated)', + wsModeCtxPool: '上下文池(ctx_pool)', + wsModePassthrough: '透传(passthrough)', wsModeConcurrencyHint: '启用 WS mode 后,该账号并发数将作为该账号 WS 连接池上限。', + wsModePassthroughHint: 'passthrough 模式不使用 WS 连接池。', oauthResponsesWebsocketsV2: 'OAuth WebSocket Mode', oauthResponsesWebsocketsV2Desc: '仅对 OpenAI OAuth 生效。开启后该账号才允许使用 OpenAI WebSocket Mode 协议。', @@ -2006,27 +2022,6 @@ export default { idleTimeoutPlaceholder: '5', idleTimeoutHint: '会话空闲超时后自动释放' }, - rpmLimit: { - label: 'RPM 限制', - hint: '限制每分钟请求数量,保护上游账号', - baseRpm: '基础 RPM', - baseRpmPlaceholder: '15', - baseRpmHint: '每分钟最大请求数,0 或留空表示不限制', - strategy: 'RPM 策略', - strategyTiered: '三区模型', - strategyStickyExempt: '粘性豁免', - strategyTieredHint: 
'绿区→黄区→仅粘性→阻塞,逐步限流', - strategyStickyExemptHint: '超限后仅允许粘性会话', - strategyHint: '三区模型: 超限后逐步限制; 粘性豁免: 已有会话不受限', - stickyBuffer: '粘性缓冲区', - stickyBufferPlaceholder: '默认: base RPM 的 20%', - stickyBufferHint: '超过 base RPM 后,粘性会话额外允许的请求数。为空则使用默认值(base RPM 的 20%,最小为 1)', - userMsgQueue: '用户消息限速', - userMsgQueueHint: '对用户消息施加发送限制,避免触发上游 RPM 限制', - umqModeOff: '关闭', - umqModeThrottle: '软性限速', - umqModeSerialize: '串行队列', - }, tlsFingerprint: { label: 'TLS 指纹模拟', hint: '模拟 Node.js/Claude Code 客户端的 TLS 指纹' @@ -2477,7 +2472,6 @@ export default { name: '名称', protocol: '协议', address: '地址', - auth: '认证', location: '地理位置', status: '状态', accounts: '账号数', @@ -2505,8 +2499,6 @@ export default { allStatuses: '全部状态' }, // Additional keys used in ProxiesView - copyProxyUrl: '复制代理 URL', - urlCopied: '代理 URL 已复制', allProtocols: '全部协议', allStatus: '全部状态', searchProxies: '搜索代理...', @@ -3743,21 +3735,7 @@ export default { defaultBalance: '默认余额', defaultBalanceHint: '新用户的初始余额', defaultConcurrency: '默认并发数', - defaultConcurrencyHint: '新用户的最大并发请求数', - defaultSubscriptions: '默认订阅列表', - defaultSubscriptionsHint: '新用户创建或注册时自动分配这些订阅', - addDefaultSubscription: '添加默认订阅', - defaultSubscriptionsEmpty: '未配置默认订阅。新用户不会自动获得订阅套餐。', - defaultSubscriptionsDuplicate: '默认订阅存在重复分组:{groupId}。每个分组只能出现一次。', - subscriptionGroup: '订阅分组', - subscriptionValidityDays: '有效期(天)' - }, - claudeCode: { - title: 'Claude Code 设置', - description: '控制 Claude Code 客户端访问要求', - minVersion: '最低版本号', - minVersionPlaceholder: '例如 2.1.63', - minVersionHint: '拒绝低于此版本的 Claude Code 客户端请求(semver 格式)。留空则不检查版本。' + defaultConcurrencyHint: '新用户的最大并发请求数' }, site: { title: '站点设置', @@ -3795,17 +3773,15 @@ export default { hideCcsImportButtonHint: '启用后将在 API Keys 页面隐藏"导入 CCS"按钮' }, purchase: { - title: '充值/订阅页面', - description: '在侧边栏展示“充值/订阅”入口,并在页面内通过 iframe 打开指定链接', - enabled: '显示充值/订阅入口', + title: '购买订阅页面', + description: '在侧边栏展示”购买订阅”入口,并在页面内通过 iframe 打开指定链接', + enabled: '显示购买订阅入口', enabledHint: '仅在标准模式(非简单模式)下展示', - url: '充值/订阅页面 URL', + url: 
'购买页面 URL', urlPlaceholder: 'https://example.com/purchase', urlHint: '必须是完整的 http(s) 链接', iframeWarning: - '⚠️ iframe 提示:部分网站会通过 X-Frame-Options 或 CSP(frame-ancestors)禁止被 iframe 嵌入,出现空白时可引导用户使用”新窗口打开”。', - integrationDoc: '支付集成文档', - integrationDocHint: '包含接口说明、幂等语义及示例代码' + '⚠️ iframe 提示:部分网站会通过 X-Frame-Options 或 CSP(frame-ancestors)禁止被 iframe 嵌入,出现空白时可引导用户使用”新窗口打开”。' }, soraClient: { title: 'Sora 客户端', @@ -3813,27 +3789,6 @@ export default { enabled: '启用 Sora 客户端', enabledHint: '开启后,侧边栏将显示 Sora 入口,用户可访问 Sora 功能' }, - customMenu: { - title: '自定义菜单页面', - description: '添加自定义 iframe 页面到侧边栏导航。每个页面可以设置为普通用户或管理员可见。', - itemLabel: '菜单项 #{n}', - name: '菜单名称', - namePlaceholder: '如:帮助中心', - url: '页面 URL', - urlPlaceholder: 'https://example.com/page', - iconSvg: 'SVG 图标', - iconSvgPlaceholder: '...', - iconPreview: '图标预览', - uploadSvg: '上传 SVG', - removeSvg: '清除', - visibility: '可见角色', - visibilityUser: '普通用户', - visibilityAdmin: '管理员', - add: '添加菜单项', - remove: '删除', - moveUp: '上移', - moveDown: '下移', - }, smtp: { title: 'SMTP 设置', description: '配置用于发送验证码的邮件服务', @@ -4109,25 +4064,15 @@ export default { retry: '重试' }, - // Recharge / Subscription Page + // Purchase Subscription Page purchase: { - title: '充值/订阅', - description: '通过内嵌页面完成充值/订阅', + title: '购买订阅', + description: '通过内嵌页面完成订阅购买', openInNewTab: '新窗口打开', notEnabledTitle: '该功能未开启', - notEnabledDesc: '管理员暂未开启充值/订阅入口,请联系管理员。', - notConfiguredTitle: '充值/订阅链接未配置', - notConfiguredDesc: '管理员已开启入口,但尚未配置充值/订阅链接,请联系管理员。' - }, - - // Custom Page (iframe embed) - customPage: { - title: '自定义页面', - openInNewTab: '新窗口打开', - notFoundTitle: '页面不存在', - notFoundDesc: '该自定义页面不存在或已被删除。', - notConfiguredTitle: '页面链接未配置', - notConfiguredDesc: '该自定义页面的 URL 未正确配置。', + notEnabledDesc: '管理员暂未开启购买订阅入口,请联系管理员。', + notConfiguredTitle: '购买链接未配置', + notConfiguredDesc: '管理员已开启入口,但尚未配置购买订阅链接,请联系管理员。' }, // Announcements Page diff --git a/frontend/src/router/index.ts b/frontend/src/router/index.ts index 08f492d4d..882b31baa 100644 --- 
a/frontend/src/router/index.ts +++ b/frontend/src/router/index.ts @@ -340,18 +340,6 @@ const routes: RouteRecordRaw[] = [ descriptionKey: 'admin.promo.description' } }, - { - path: '/admin/data-management', - name: 'AdminDataManagement', - component: () => import('@/views/admin/DataManagementView.vue'), - meta: { - requiresAuth: true, - requiresAdmin: true, - title: 'Data Management', - titleKey: 'admin.dataManagement.title', - descriptionKey: 'admin.dataManagement.description' - } - }, { path: '/admin/settings', name: 'AdminSettings', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 6e5aa3020..d77b67ee6 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -75,15 +75,6 @@ export interface SendVerifyCodeResponse { countdown: number } -export interface CustomMenuItem { - id: string - label: string - icon_svg: string - url: string - visibility: 'user' | 'admin' - sort_order: number -} - export interface PublicSettings { registration_enabled: boolean email_verify_enabled: boolean @@ -102,7 +93,6 @@ export interface PublicSettings { hide_ccs_import_button: boolean purchase_subscription_enabled: boolean purchase_subscription_url: string - custom_menu_items: CustomMenuItem[] linuxdo_oauth_enabled: boolean sora_client_enabled: boolean version: string @@ -421,15 +411,6 @@ export interface ApiKey { created_at: string updated_at: string group?: Group - rate_limit_5h: number - rate_limit_1d: number - rate_limit_7d: number - usage_5h: number - usage_1d: number - usage_7d: number - window_5h_start: string | null - window_1d_start: string | null - window_7d_start: string | null } export interface CreateApiKeyRequest { @@ -440,9 +421,6 @@ export interface CreateApiKeyRequest { ip_blacklist?: string[] quota?: number // Quota limit in USD (0 = unlimited) expires_in_days?: number // Days until expiry (null = never expires) - rate_limit_5h?: number - rate_limit_1d?: number - rate_limit_7d?: number } export interface 
UpdateApiKeyRequest { @@ -454,10 +432,6 @@ export interface UpdateApiKeyRequest { quota?: number // Quota limit in USD (null = no change, 0 = unlimited) expires_at?: string | null // Expiration time (null = no change) reset_quota?: boolean // Reset quota_used to 0 - rate_limit_5h?: number - rate_limit_1d?: number - rate_limit_7d?: number - reset_rate_limit_usage?: boolean } export interface CreateGroupRequest { @@ -687,12 +661,6 @@ export interface Account { max_sessions?: number | null session_idle_timeout_minutes?: number | null - // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效) - base_rpm?: number | null - rpm_strategy?: string | null - rpm_sticky_buffer?: number | null - user_msg_queue_mode?: string | null // "serialize" | "throttle" | null - // TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效) enable_tls_fingerprint?: boolean | null @@ -707,7 +675,6 @@ export interface Account { // 运行时状态(仅当启用对应限制时返回) current_window_cost?: number | null // 当前窗口费用 active_sessions?: number | null // 当前活跃会话数 - current_rpm?: number | null // 当前分钟 RPM 计数 } // Account Usage types @@ -1115,7 +1082,7 @@ export interface ModelStat { export interface GroupStat { group_id: number - group_name: string + group_name: string | null requests: number total_tokens: number cost: number // 标准计费 diff --git a/frontend/src/utils/__tests__/openaiWsMode.spec.ts b/frontend/src/utils/__tests__/openaiWsMode.spec.ts index 39f21aef6..6895332ef 100644 --- a/frontend/src/utils/__tests__/openaiWsMode.spec.ts +++ b/frontend/src/utils/__tests__/openaiWsMode.spec.ts @@ -1,31 +1,34 @@ import { describe, expect, it } from 'vitest' import { - OPENAI_WS_MODE_DEDICATED, + OPENAI_WS_MODE_CTX_POOL, OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, + OPENAI_WS_MODE_PASSTHROUGH, isOpenAIWSModeEnabled, normalizeOpenAIWSMode, openAIWSModeFromEnabled, + resolveOpenAIWSModeConcurrencyHintKey, resolveOpenAIWSModeFromExtra } from '@/utils/openaiWsMode' describe('openaiWsMode utils', () => { it('normalizes mode values', () => { 
expect(normalizeOpenAIWSMode('off')).toBe(OPENAI_WS_MODE_OFF) - expect(normalizeOpenAIWSMode(' Shared ')).toBe(OPENAI_WS_MODE_SHARED) - expect(normalizeOpenAIWSMode('DEDICATED')).toBe(OPENAI_WS_MODE_DEDICATED) + expect(normalizeOpenAIWSMode('CTX_POOL')).toBe(OPENAI_WS_MODE_CTX_POOL) + expect(normalizeOpenAIWSMode(' passthrough ')).toBe(OPENAI_WS_MODE_PASSTHROUGH) + expect(normalizeOpenAIWSMode(' Shared ')).toBeNull() + expect(normalizeOpenAIWSMode('DEDICATED')).toBeNull() expect(normalizeOpenAIWSMode('invalid')).toBeNull() }) it('maps legacy enabled flag to mode', () => { - expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_SHARED) + expect(openAIWSModeFromEnabled(true)).toBe(OPENAI_WS_MODE_CTX_POOL) expect(openAIWSModeFromEnabled(false)).toBe(OPENAI_WS_MODE_OFF) expect(openAIWSModeFromEnabled('true')).toBeNull() }) it('resolves by mode key first, then enabled, then fallback enabled keys', () => { const extra = { - openai_oauth_responses_websockets_v2_mode: 'dedicated', + openai_oauth_responses_websockets_v2_mode: 'ctx_pool', openai_oauth_responses_websockets_v2_enabled: false, responses_websockets_v2_enabled: false } @@ -34,7 +37,7 @@ describe('openaiWsMode utils', () => { enabledKey: 'openai_oauth_responses_websockets_v2_enabled', fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'] }) - expect(mode).toBe(OPENAI_WS_MODE_DEDICATED) + expect(mode).toBe(OPENAI_WS_MODE_CTX_POOL) }) it('falls back to default when nothing is present', () => { @@ -47,9 +50,62 @@ describe('openaiWsMode utils', () => { expect(mode).toBe(OPENAI_WS_MODE_OFF) }) - it('treats off as disabled and shared/dedicated as enabled', () => { + it('treats off as disabled and non-off modes as enabled', () => { expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_OFF)).toBe(false) - expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_SHARED)).toBe(true) - expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_DEDICATED)).toBe(true) + expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_CTX_POOL)).toBe(true) + 
expect(isOpenAIWSModeEnabled(OPENAI_WS_MODE_PASSTHROUGH)).toBe(true) + }) + + it('resolves concurrency hint key by mode', () => { + expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_OFF)).toBe('admin.accounts.openai.wsModeConcurrencyHint') + expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_CTX_POOL)).toBe('admin.accounts.openai.wsModeConcurrencyHint') + expect(resolveOpenAIWSModeConcurrencyHintKey(OPENAI_WS_MODE_PASSTHROUGH)).toBe( + 'admin.accounts.openai.wsModePassthroughHint' + ) + }) + + it('resolves passthrough from modeKey', () => { + const extra = { + openai_oauth_responses_websockets_v2_mode: 'passthrough', + openai_oauth_responses_websockets_v2_enabled: true + } + const mode = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_oauth_responses_websockets_v2_mode', + enabledKey: 'openai_oauth_responses_websockets_v2_enabled' + }) + expect(mode).toBe(OPENAI_WS_MODE_PASSTHROUGH) + }) + + it('resolves passthrough from apikey modeKey', () => { + const extra = { + openai_apikey_responses_websockets_v2_mode: 'passthrough' + } + const mode = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_apikey_responses_websockets_v2_mode', + enabledKey: 'openai_apikey_responses_websockets_v2_enabled' + }) + expect(mode).toBe(OPENAI_WS_MODE_PASSTHROUGH) + }) + + it('resolves enabled fallback when mode key is absent', () => { + const extra = { + openai_oauth_responses_websockets_v2_enabled: true + } + const mode = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_oauth_responses_websockets_v2_mode', + enabledKey: 'openai_oauth_responses_websockets_v2_enabled' + }) + expect(mode).toBe(OPENAI_WS_MODE_CTX_POOL) + }) + + it('resolves disabled fallback when enabled is false', () => { + const extra = { + openai_oauth_responses_websockets_v2_enabled: false + } + const mode = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_oauth_responses_websockets_v2_mode', + enabledKey: 'openai_oauth_responses_websockets_v2_enabled' + }) + 
expect(mode).toBe(OPENAI_WS_MODE_OFF) }) }) diff --git a/frontend/src/utils/openaiWsMode.ts b/frontend/src/utils/openaiWsMode.ts index b3e9cc00a..2a8c4ee67 100644 --- a/frontend/src/utils/openaiWsMode.ts +++ b/frontend/src/utils/openaiWsMode.ts @@ -1,16 +1,16 @@ export const OPENAI_WS_MODE_OFF = 'off' -export const OPENAI_WS_MODE_SHARED = 'shared' -export const OPENAI_WS_MODE_DEDICATED = 'dedicated' +export const OPENAI_WS_MODE_CTX_POOL = 'ctx_pool' +export const OPENAI_WS_MODE_PASSTHROUGH = 'passthrough' export type OpenAIWSMode = | typeof OPENAI_WS_MODE_OFF - | typeof OPENAI_WS_MODE_SHARED - | typeof OPENAI_WS_MODE_DEDICATED + | typeof OPENAI_WS_MODE_CTX_POOL + | typeof OPENAI_WS_MODE_PASSTHROUGH const OPENAI_WS_MODES = new Set([ OPENAI_WS_MODE_OFF, - OPENAI_WS_MODE_SHARED, - OPENAI_WS_MODE_DEDICATED + OPENAI_WS_MODE_CTX_POOL, + OPENAI_WS_MODE_PASSTHROUGH ]) export interface ResolveOpenAIWSModeOptions { @@ -31,13 +31,22 @@ export const normalizeOpenAIWSMode = (mode: unknown): OpenAIWSMode | null => { export const openAIWSModeFromEnabled = (enabled: unknown): OpenAIWSMode | null => { if (typeof enabled !== 'boolean') return null - return enabled ? OPENAI_WS_MODE_SHARED : OPENAI_WS_MODE_OFF + return enabled ? 
OPENAI_WS_MODE_CTX_POOL : OPENAI_WS_MODE_OFF } export const isOpenAIWSModeEnabled = (mode: OpenAIWSMode): boolean => { return mode !== OPENAI_WS_MODE_OFF } +export const resolveOpenAIWSModeConcurrencyHintKey = ( + mode: OpenAIWSMode +): 'admin.accounts.openai.wsModeConcurrencyHint' | 'admin.accounts.openai.wsModePassthroughHint' => { + if (mode === OPENAI_WS_MODE_PASSTHROUGH) { + return 'admin.accounts.openai.wsModePassthroughHint' + } + return 'admin.accounts.openai.wsModeConcurrencyHint' +} + export const resolveOpenAIWSModeFromExtra = ( extra: Record | null | undefined, options: ResolveOpenAIWSModeOptions diff --git a/frontend/src/views/__tests__/accountsBulkEditScope.spec.ts b/frontend/src/views/__tests__/accountsBulkEditScope.spec.ts new file mode 100644 index 000000000..514ffccb2 --- /dev/null +++ b/frontend/src/views/__tests__/accountsBulkEditScope.spec.ts @@ -0,0 +1,100 @@ +import { describe, expect, it } from 'vitest' +import { + buildBulkEditPlatformOptions, + buildBulkEditScopeGroupedStats, + buildBulkEditTypeOptions, + countBulkEditScopedAccounts, + matchBulkEditScopedAccountIds +} from '../admin/accountsBulkEditScope' + +const candidates = [ + { id: 1, platform: 'openai', type: 'oauth' }, + { id: 2, platform: 'openai', type: 'apikey' }, + { id: 3, platform: 'openai', type: 'oauth' }, + { id: 4, platform: 'anthropic', type: 'apikey' }, + { id: 5, platform: 'gemini', type: 'oauth' } +] + +describe('accountsBulkEditScope helpers', () => { + it('builds platform options with default item and unique platforms', () => { + const options = buildBulkEditPlatformOptions( + candidates, + '请选择平台', + (platform) => `label:${platform}` + ) + + expect(options).toEqual([ + { value: '', label: '请选择平台' }, + { value: 'openai', label: 'label:openai' }, + { value: 'anthropic', label: 'label:anthropic' }, + { value: 'gemini', label: 'label:gemini' } + ]) + }) + + it('returns only default type option when platform is empty', () => { + const options = 
buildBulkEditTypeOptions(candidates, '', '请选择类型', (type) => type) + expect(options).toEqual([{ value: '', label: '请选择类型' }]) + }) + + it('builds type options for selected platform with deduplication', () => { + const options = buildBulkEditTypeOptions( + candidates, + 'openai', + '请选择类型', + (type) => `type:${type}` + ) + + expect(options).toEqual([ + { value: '', label: '请选择类型' }, + { value: 'oauth', label: 'type:oauth' }, + { value: 'apikey', label: 'type:apikey' } + ]) + }) + + it('supports option meta builders for platform/type options', () => { + const platformOptions = buildBulkEditPlatformOptions( + candidates, + '请选择平台', + (platform) => platform, + (_platform, count) => ({ label: `count:${count}` }) + ) + expect(platformOptions).toEqual([ + { value: '', label: '请选择平台' }, + { value: 'openai', label: 'count:3' }, + { value: 'anthropic', label: 'count:1' }, + { value: 'gemini', label: 'count:1' } + ]) + + const typeOptions = buildBulkEditTypeOptions( + candidates, + 'openai', + '请选择类型', + (type) => type, + (type, count) => ({ + label: `${type} (${count})`, + disabled: type === 'apikey' + }) + ) + expect(typeOptions).toEqual([ + { value: '', label: '请选择类型' }, + { value: 'oauth', label: 'oauth (2)', disabled: false }, + { value: 'apikey', label: 'apikey (1)', disabled: true } + ]) + }) + + it('matches scoped account IDs and count', () => { + expect(matchBulkEditScopedAccountIds(candidates, 'openai', 'oauth')).toEqual([1, 3]) + expect(countBulkEditScopedAccounts(candidates, 'openai', 'oauth')).toBe(2) + expect(matchBulkEditScopedAccountIds(candidates, 'sora', 'apikey')).toEqual([]) + expect(countBulkEditScopedAccounts(candidates, 'sora', 'apikey')).toBe(0) + }) + + it('builds grouped stats for scope preview', () => { + expect(buildBulkEditScopeGroupedStats(candidates)).toEqual([ + { key: 'anthropic:apikey', platform: 'anthropic', type: 'apikey', count: 1 }, + { key: 'gemini:oauth', platform: 'gemini', type: 'oauth', count: 1 }, + { key: 'openai:apikey', platform: 
'openai', type: 'apikey', count: 1 }, + { key: 'openai:oauth', platform: 'openai', type: 'oauth', count: 2 } + ]) + }) +}) diff --git a/frontend/src/views/admin/AccountsView.vue b/frontend/src/views/admin/AccountsView.vue index defcd4346..1cd223bd6 100644 --- a/frontend/src/views/admin/AccountsView.vue +++ b/frontend/src/views/admin/AccountsView.vue @@ -115,6 +115,14 @@ {{ selIds.length ? t('admin.accounts.dataExportSelected') : t('admin.accounts.dataExport') }} +
- -