diff --git a/README.md b/README.md
index 1e2f22907..8804ec306 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,34 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
## Documentation
- Dependency Security: `docs/dependency-security.md`
-- Admin Payment Integration API: `docs/ADMIN_PAYMENT_INTEGRATION_API.md`
+
+---
+
+## Codex CLI WebSocket v2 Example
+
+To enable OpenAI WebSocket Mode v2 in Codex CLI with Sub2API, add the following to `~/.codex/config.toml`:
+
+```toml
+model_provider = "aicodx2api"
+model = "gpt-5.3-codex"
+review_model = "gpt-5.3-codex"
+model_reasoning_effort = "xhigh"
+disable_response_storage = true
+network_access = "enabled"
+windows_wsl_setup_acknowledged = true
+
+[model_providers.aicodx2api]
+name = "aicodx2api"
+base_url = "https://api.sub2api.ai"
+wire_api = "responses"
+supports_websockets = true
+requires_openai_auth = true
+
+[features]
+responses_websockets_v2 = true
+```
+
+After updating the config, restart Codex CLI.
---
diff --git a/README_CN.md b/README_CN.md
index 9da089b74..22a772b5a 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -62,6 +62,34 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
+## Codex CLI 开启 OpenAI WebSocket Mode v2 示例配置
+
+如需在 Codex CLI 中通过 Sub2API 启用 OpenAI WebSocket Mode v2,可将以下配置写入 `~/.codex/config.toml`:
+
+```toml
+model_provider = "aicodx2api"
+model = "gpt-5.3-codex"
+review_model = "gpt-5.3-codex"
+model_reasoning_effort = "xhigh"
+disable_response_storage = true
+network_access = "enabled"
+windows_wsl_setup_acknowledged = true
+
+[model_providers.aicodx2api]
+name = "aicodx2api"
+base_url = "https://api.sub2api.ai"
+wire_api = "responses"
+supports_websockets = true
+requires_openai_auth = true
+
+[features]
+responses_websockets_v2 = true
+```
+
+配置更新后,重启 Codex CLI 使其生效。
+
+---
+
## 部署方式
### 方式一:脚本安装(推荐)
diff --git a/backend/.golangci.yml b/backend/.golangci.yml
index 68b76751f..3ec692a84 100644
--- a/backend/.golangci.yml
+++ b/backend/.golangci.yml
@@ -5,7 +5,6 @@ linters:
enable:
- depguard
- errcheck
- - gosec
- govet
- ineffassign
- staticcheck
@@ -43,22 +42,6 @@ linters:
desc: "handler must not import gorm"
- pkg: github.com/redis/go-redis/v9
desc: "handler must not import redis"
- gosec:
- excludes:
- - G101
- - G103
- - G104
- - G109
- - G115
- - G201
- - G202
- - G301
- - G302
- - G304
- - G306
- - G404
- severity: high
- confidence: high
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
diff --git a/backend/.gosec.json b/backend/.gosec.json
new file mode 100644
index 000000000..7a8ccb6a1
--- /dev/null
+++ b/backend/.gosec.json
@@ -0,0 +1,5 @@
+{
+ "global": {
+ "exclude": "G704,G101,G103,G104,G109,G115,G201,G202,G301,G302,G304,G306,G404"
+ }
+}
diff --git a/backend/Dockerfile b/backend/Dockerfile
index aeb20fdb6..6db2b1756 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -1,4 +1,4 @@
-FROM golang:1.25.7-alpine
+FROM registry-1.docker.io/library/golang:1.25.7-alpine
WORKDIR /app
diff --git a/backend/THIRD_PARTY_NOTICES.md b/backend/THIRD_PARTY_NOTICES.md
new file mode 100644
index 000000000..0524259ad
--- /dev/null
+++ b/backend/THIRD_PARTY_NOTICES.md
@@ -0,0 +1,26 @@
+# Third-Party Notices
+
+## caddyserver/caddy
+
+- Project: https://github.com/caddyserver/caddy
+- License: Apache License 2.0
+- Copyright:
+ Copyright 2015 Matthew Holt and The Caddy Authors
+
+### Usage in this repository
+
+The OpenAI WS v2 passthrough relay adopts the Caddy reverse proxy streaming
+architecture (bidirectional tunnel + convergence shutdown), adapted with minimal
+changes to `coder/websocket` frame-based forwarding.
+
+- Referenced commit:
+ `f283062d37c50627d53ca682ebae2ce219b35515`
+- Referenced upstream files:
+ - `modules/caddyhttp/reverseproxy/streaming.go`
+ - `modules/caddyhttp/reverseproxy/reverseproxy.go`
+- Local adaptation files:
+ - `backend/internal/service/openai_ws_v2/caddy_adapter.go`
+ - `backend/internal/service/openai_ws_v2/passthrough_relay.go`
+
+The adaptation preserves Apache-2.0 license obligations.
+
diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go
index bc0016939..2ff7358b1 100644
--- a/backend/cmd/jwtgen/main.go
+++ b/backend/cmd/jwtgen/main.go
@@ -33,7 +33,7 @@ func main() {
}()
userRepo := repository.NewUserRepository(client, sqlDB)
- authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 32844913e..6ba552b88 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.88
\ No newline at end of file
+0.1.87.4
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index cbf89ba3b..5044f7ee0 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -210,9 +210,9 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
- {"OpenAIWSPool", func() error {
+ {"OpenAIWSCtxPool", func() error {
if openAIGateway != nil {
- openAIGateway.CloseOpenAIWSPool()
+ openAIGateway.CloseOpenAIWSCtxPool()
}
return nil
}},
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index cbeb9a693..5bc0d35c3 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -48,8 +48,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
redisClient := repository.ProvideRedis(configConfig)
refreshTokenCache := repository.NewRefreshTokenCache(redisClient)
settingRepository := repository.NewSettingRepository(client)
- groupRepository := repository.NewGroupRepository(client, db)
- settingService := service.ProvideSettingService(settingRepository, groupRepository, configConfig)
+ settingService := service.NewSettingService(settingRepository, configConfig)
emailCache := repository.NewEmailCache(redisClient)
emailService := service.NewEmailService(settingRepository, emailCache)
turnstileVerifier := repository.NewTurnstileVerifier()
@@ -60,15 +59,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
userSubscriptionRepository := repository.NewUserSubscriptionRepository(client)
apiKeyRepository := repository.NewAPIKeyRepository(client, db)
billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig)
+ groupRepository := repository.NewGroupRepository(client, db)
userGroupRateRepository := repository.NewUserGroupRateRepository(db)
apiKeyCache := repository.NewAPIKeyCache(redisClient)
apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig)
- apiKeyService.SetRateLimitCacheInvalidator(billingCache)
apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService)
promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator)
- subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
- authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService, subscriptionService)
+ authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService)
userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator, billingCache)
+ subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService, client, configConfig)
redeemCache := repository.NewRedeemCache(redisClient)
redeemService := service.NewRedeemService(redeemCodeRepository, userRepository, subscriptionService, redeemCache, billingCacheService, client, apiKeyAuthCacheInvalidator)
secretEncryptor, err := repository.NewAESEncryptor(configConfig)
@@ -104,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -139,8 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
- rpmCache := repository.NewRPMCache(redisClient)
- accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, rpmCache, compositeTokenCacheInvalidator)
+ accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService)
dataManagementService := service.NewDataManagementService()
dataManagementHandler := admin.NewDataManagementHandler(dataManagementService)
@@ -162,7 +160,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
digestSessionStore := service.NewDigestSessionStore()
- gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore)
+ gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, digestSessionStore)
openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService)
openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider)
geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig)
@@ -197,9 +195,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
- userMsgQueueCache := repository.NewUserMsgQueueCache(redisClient)
- userMessageQueueService := service.ProvideUserMessageQueueService(userMsgQueueCache, rpmCache, configConfig)
- gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, userMessageQueueService, configConfig, settingService)
+ gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, nil, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository)
soraMediaStorage := service.ProvideSoraMediaStorage(configConfig)
@@ -396,9 +392,9 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
- {"OpenAIWSPool", func() error {
+ {"OpenAIWSCtxPool", func() error {
if openAIGateway != nil {
- openAIGateway.CloseOpenAIWSPool()
+ openAIGateway.CloseOpenAIWSCtxPool()
}
return nil
}},
diff --git a/backend/go.mod b/backend/go.mod
index ab76258a2..08c4e26f0 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -109,7 +109,6 @@ require (
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
- github.com/google/subcommands v1.2.0 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/hashicorp/hcl/v2 v2.18.1 // indirect
@@ -178,7 +177,6 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
- golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
diff --git a/backend/go.sum b/backend/go.sum
index 32e389a76..98914a83e 100644
--- a/backend/go.sum
+++ b/backend/go.sum
@@ -182,8 +182,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
-github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
-github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index c1f54ab69..365cb0f7c 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -273,13 +273,8 @@ type CSPConfig struct {
}
type ProxyFallbackConfig struct {
- // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。
- // 仅影响以下非 AI 账号连接的辅助服务:
- // - GitHub Release 更新检查
- // - 定价数据拉取
- // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity),
- // 这些关键路径的代理失败始终返回错误,不会回退直连。
- // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。
+ // AllowDirectOnError 当代理初始化失败时是否允许回退直连。
+ // 默认 false:避免因代理配置错误导致 IP 泄露/关联。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
@@ -379,6 +374,8 @@ type GatewayConfig struct {
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
+ // OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1)
+ OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -511,12 +508,27 @@ func (c *UserMessageQueueConfig) GetEffectiveMode() string {
return ""
}
+// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。
+// 默认启用 HTTP/2,多路复用提升并发效率;在部分代理不兼容时按策略回退 HTTP/1.1。
+type GatewayOpenAIHTTP2Config struct {
+ // Enabled: 是否启用 OpenAI HTTP/2 优先策略
+ Enabled bool `mapstructure:"enabled"`
+ // AllowProxyFallbackToHTTP1: 代理不兼容 HTTP/2 时是否允许回退 HTTP/1.1
+ AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"`
+ // FallbackErrorThreshold: 在窗口期内触发回退所需的连续错误次数
+ FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"`
+ // FallbackWindowSeconds: 连续错误计数窗口(秒)
+ FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"`
+ // FallbackTTLSeconds: 进入 HTTP/1.1 回退态后的持续时间(秒)
+ FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"`
+}
+
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {
- // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
+ // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 true;关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
- // IngressModeDefault: ingress 默认模式(off/shared/dedicated)
+ // IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough)
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
Enabled bool `mapstructure:"enabled"`
@@ -554,12 +566,13 @@ type GatewayOpenAIWSConfig struct {
// OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor))
OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
// APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor))
- APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
- DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
- ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
- WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
- PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
- QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
+ APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
+ DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
+ ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
+ ClientReadIdleTimeoutSeconds int `mapstructure:"client_read_idle_timeout_seconds"`
+ WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
+ PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
+ QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
// EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
// EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发
@@ -579,6 +592,11 @@ type GatewayOpenAIWSConfig struct {
// PayloadLogSampleRate: payload_schema 日志采样率(0-1)
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
+ // UpstreamConnMaxAgeSeconds: 上游 WebSocket 连接最大存活时间(秒)。
+ // OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。
+ // 默认 3300(55 分钟);设为 0 则禁用超龄轮换。
+ UpstreamConnMaxAgeSeconds int `mapstructure:"upstream_conn_max_age_seconds"`
+
// 账号调度与粘连参数
LBTopK int `mapstructure:"lb_top_k"`
// StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
@@ -595,6 +613,32 @@ type GatewayOpenAIWSConfig struct {
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
+
+ // SchedulerP2CEnabled: 启用 P2C(Power-of-Two-Choices)选择算法替代 Top-K 加权采样
+ SchedulerP2CEnabled bool `mapstructure:"scheduler_p2c_enabled"`
+
+ // Softmax 温度采样:替代线性平移的概率选择策略
+ SchedulerSoftmaxEnabled bool `mapstructure:"scheduler_softmax_enabled"`
+ SchedulerSoftmaxTemperature float64 `mapstructure:"scheduler_softmax_temperature"`
+
+ // 账号级熔断器
+ SchedulerCircuitBreakerEnabled bool `mapstructure:"scheduler_circuit_breaker_enabled"`
+ SchedulerCircuitBreakerFailThreshold int `mapstructure:"scheduler_circuit_breaker_fail_threshold"`
+ SchedulerCircuitBreakerCooldownSec int `mapstructure:"scheduler_circuit_breaker_cooldown_sec"`
+ SchedulerCircuitBreakerHalfOpenMax int `mapstructure:"scheduler_circuit_breaker_half_open_max"`
+
+ // 条件性 Sticky Session 释放:当粘连账号不健康时主动释放,回退到负载均衡
+ StickyReleaseEnabled bool `mapstructure:"sticky_release_enabled"`
+ StickyReleaseErrorThreshold float64 `mapstructure:"sticky_release_error_threshold"`
+
+ // Per-model TTFT tracking
+ SchedulerPerModelTTFTEnabled bool `mapstructure:"scheduler_per_model_ttft_enabled"`
+ SchedulerPerModelTTFTMaxModels int `mapstructure:"scheduler_per_model_ttft_max_models"`
+
+ // SchedulerTrendEnabled: 启用负载趋势预测(线性回归外推),在打分时对 loadFactor 施加趋势修正
+ SchedulerTrendEnabled bool `mapstructure:"scheduler_trend_enabled"`
+ // SchedulerTrendMaxSlope: 趋势斜率归一化上限(每秒负载百分比变化率);0 或负数使用默认值 5.0
+ SchedulerTrendMaxSlope float64 `mapstructure:"scheduler_trend_max_slope"`
}
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
@@ -1172,9 +1216,6 @@ func setDefaults() {
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
- // Security - disable direct fallback on proxy error
- viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
-
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
@@ -1332,8 +1373,8 @@ func setDefaults() {
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
- viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
- viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
+ viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", true)
+ viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
@@ -1364,7 +1405,8 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
- viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
+ viper.SetDefault("gateway.openai_ws.upstream_conn_max_age_seconds", 3300)
+ viper.SetDefault("gateway.openai_ws.lb_top_k", 999)
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
@@ -1376,6 +1418,12 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
+ // OpenAI HTTP upstream protocol strategy
+ viper.SetDefault("gateway.openai_http2.enabled", true)
+ viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true)
+ viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2)
+ viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60)
+ viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
@@ -1420,7 +1468,7 @@ func setDefaults() {
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
- viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
+ viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySync)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
@@ -1493,6 +1541,9 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
+ // Security - proxy fallback
+ viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
+
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
@@ -1971,6 +2022,15 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
+ if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 {
+ return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative")
+ }
+ if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 {
+ return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative")
+ }
+ if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 {
+ return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative")
+ }
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
@@ -2039,11 +2099,16 @@ func (c *Config) Validate() error {
if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
}
+ if c.Gateway.OpenAIWS.ResponsesWebsockets && !c.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
+ return fmt.Errorf("gateway.openai_ws.responses_websockets (v1) is not supported; enable gateway.openai_ws.responses_websockets_v2")
+ }
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
- case "off", "shared", "dedicated":
+ case "off", "ctx_pool", "passthrough":
+ case "shared", "dedicated":
+ slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode)
default:
- return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
+ return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
@@ -2056,6 +2121,9 @@ func (c *Config) Validate() error {
if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
}
+ if c.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds < 0 {
+ return fmt.Errorf("gateway.openai_ws.upstream_conn_max_age_seconds must be non-negative")
+ }
if c.Gateway.OpenAIWS.LBTopK <= 0 {
return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
}
@@ -2083,6 +2151,22 @@ func (c *Config) Validate() error {
if weightSum <= 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
}
+ // Validate new scheduler/sticky-release config ranges.
+ if c.Gateway.OpenAIWS.SchedulerSoftmaxTemperature < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_softmax_temperature must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.StickyReleaseErrorThreshold < 0 || c.Gateway.OpenAIWS.StickyReleaseErrorThreshold > 1 {
+ return fmt.Errorf("gateway.openai_ws.sticky_release_error_threshold must be within [0,1]")
+ }
+ if c.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_fail_threshold must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_cooldown_sec must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_half_open_max must be non-negative")
+ }
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index e3b592e2c..480a17c16 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -150,11 +150,11 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
}
- if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
- t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
+ if !cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
+ t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = false, want true")
}
- if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
- t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
+ if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
+ t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
}
}
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
- name: "ingress_mode_default 必须为 off|shared|dedicated",
+ name: "ingress_mode_default 必须为 off|ctx_pool|passthrough",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},
@@ -1387,6 +1387,14 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
wantErr: "gateway.openai_ws.retry_total_budget_ms",
},
+ {
+ name: "responses_websockets v1-only 配置不允许",
+ mutate: func(c *Config) {
+ c.Gateway.OpenAIWS.ResponsesWebsockets = true
+ c.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
+ },
+ wantErr: "gateway.openai_ws.responses_websockets",
+ },
{
name: "lb_top_k 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
@@ -1637,8 +1645,8 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
}
- if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
- t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
+ if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySync {
+ t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySync)
}
if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
diff --git a/backend/internal/handler/admin/account_data_handler_test.go b/backend/internal/handler/admin/account_data_handler_test.go
index 285033a17..c8b04c2ae 100644
--- a/backend/internal/handler/admin/account_data_handler_test.go
+++ b/backend/internal/handler/admin/account_data_handler_test.go
@@ -64,7 +64,6 @@ func setupAccountDataRouter() (*gin.Engine, *stubAdminService) {
nil,
nil,
nil,
- nil,
)
router.GET("/api/v1/admin/accounts/data", h.ExportData)
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 98ead2841..e4a69032f 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -53,7 +53,6 @@ type AccountHandler struct {
concurrencyService *service.ConcurrencyService
crsSyncService *service.CRSSyncService
sessionLimitCache service.SessionLimitCache
- rpmCache service.RPMCache
tokenCacheInvalidator service.TokenCacheInvalidator
}
@@ -70,7 +69,6 @@ func NewAccountHandler(
concurrencyService *service.ConcurrencyService,
crsSyncService *service.CRSSyncService,
sessionLimitCache service.SessionLimitCache,
- rpmCache service.RPMCache,
tokenCacheInvalidator service.TokenCacheInvalidator,
) *AccountHandler {
return &AccountHandler{
@@ -85,7 +83,6 @@ func NewAccountHandler(
concurrencyService: concurrencyService,
crsSyncService: crsSyncService,
sessionLimitCache: sessionLimitCache,
- rpmCache: rpmCache,
tokenCacheInvalidator: tokenCacheInvalidator,
}
}
@@ -137,6 +134,7 @@ type BulkUpdateAccountsRequest struct {
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
+ AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
@@ -157,7 +155,6 @@ type AccountWithConcurrency struct {
// 以下字段仅对 Anthropic OAuth/SetupToken 账号有效,且仅在启用相应功能时返回
CurrentWindowCost *float64 `json:"current_window_cost,omitempty"` // 当前窗口费用
ActiveSessions *int `json:"active_sessions,omitempty"` // 当前活跃会话数
- CurrentRPM *int `json:"current_rpm,omitempty"` // 当前分钟 RPM 计数
}
func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, account *service.Account) AccountWithConcurrency {
@@ -193,12 +190,6 @@ func (h *AccountHandler) buildAccountResponseWithRuntime(ctx context.Context, ac
}
}
}
-
- if h.rpmCache != nil && account.GetBaseRPM() > 0 {
- if rpm, err := h.rpmCache.GetRPM(ctx, account.ID); err == nil {
- item.CurrentRPM = &rpm
- }
- }
}
return item
@@ -241,10 +232,9 @@ func (h *AccountHandler) List(c *gin.Context) {
concurrencyCounts = make(map[int64]int)
}
- // 识别需要查询窗口费用、会话数和 RPM 的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
+ // 识别需要查询窗口费用和会话数的账号(Anthropic OAuth/SetupToken 且启用了相应功能)
windowCostAccountIDs := make([]int64, 0)
sessionLimitAccountIDs := make([]int64, 0)
- rpmAccountIDs := make([]int64, 0)
sessionIdleTimeouts := make(map[int64]time.Duration) // 各账号的会话空闲超时配置
for i := range accounts {
acc := &accounts[i]
@@ -256,24 +246,12 @@ func (h *AccountHandler) List(c *gin.Context) {
sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
}
- if acc.GetBaseRPM() > 0 {
- rpmAccountIDs = append(rpmAccountIDs, acc.ID)
- }
}
}
- // 并行获取窗口费用、活跃会话数和 RPM 计数
+ // 并行获取窗口费用和活跃会话数
var windowCosts map[int64]float64
var activeSessions map[int64]int
- var rpmCounts map[int64]int
-
- // 获取 RPM 计数(批量查询)
- if len(rpmAccountIDs) > 0 && h.rpmCache != nil {
- rpmCounts, _ = h.rpmCache.GetRPMBatch(c.Request.Context(), rpmAccountIDs)
- if rpmCounts == nil {
- rpmCounts = make(map[int64]int)
- }
- }
// 获取活跃会话数(批量查询,传入各账号的 idleTimeout 配置)
if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
@@ -334,13 +312,6 @@ func (h *AccountHandler) List(c *gin.Context) {
}
}
- // 添加 RPM 计数(仅当启用时)
- if rpmCounts != nil {
- if rpm, ok := rpmCounts[acc.ID]; ok {
- item.CurrentRPM = &rpm
- }
- }
-
result[i] = item
}
@@ -483,8 +454,6 @@ func (h *AccountHandler) Create(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
- // base_rpm 输入校验:负值归零,超过 10000 截断
- sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -554,8 +523,6 @@ func (h *AccountHandler) Update(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
- // base_rpm 输入校验:负值归零,超过 10000 截断
- sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -938,9 +905,6 @@ func (h *AccountHandler) BatchCreate(c *gin.Context) {
continue
}
- // base_rpm 输入校验:负值归零,超过 10000 截断
- sanitizeExtraBaseRPM(item.Extra)
-
skipCheck := item.ConfirmMixedChannelRisk != nil && *item.ConfirmMixedChannelRisk
account, err := h.adminService.CreateAccount(ctx, &service.CreateAccountInput{
@@ -1085,8 +1049,6 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
response.BadRequest(c, "rate_multiplier must be >= 0")
return
}
- // base_rpm 输入校验:负值归零,超过 10000 截断
- sanitizeExtraBaseRPM(req.Extra)
// 确定是否跳过混合渠道检查
skipCheck := req.ConfirmMixedChannelRisk != nil && *req.ConfirmMixedChannelRisk
@@ -1098,6 +1060,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.RateMultiplier != nil ||
req.Status != "" ||
req.Schedulable != nil ||
+ req.AutoPauseOnExpired != nil ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
@@ -1116,20 +1079,13 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
RateMultiplier: req.RateMultiplier,
Status: req.Status,
Schedulable: req.Schedulable,
+ AutoPauseOnExpired: req.AutoPauseOnExpired,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
SkipMixedChannelCheck: skipCheck,
})
if err != nil {
- var mixedErr *service.MixedChannelError
- if errors.As(err, &mixedErr) {
- c.JSON(409, gin.H{
- "error": "mixed_channel_warning",
- "message": mixedErr.Error(),
- })
- return
- }
response.ErrorFrom(c, err)
return
}
@@ -1753,22 +1709,3 @@ func (h *AccountHandler) BatchRefreshTier(c *gin.Context) {
func (h *AccountHandler) GetAntigravityDefaultModelMapping(c *gin.Context) {
response.Success(c, domain.DefaultAntigravityModelMapping)
}
-
-// sanitizeExtraBaseRPM 对 extra map 中的 base_rpm 值进行范围校验和归一化。
-// 负值归零,超过 10000 截断为 10000。extra 为 nil 或不含 base_rpm 时无操作。
-func sanitizeExtraBaseRPM(extra map[string]any) {
- if extra == nil {
- return
- }
- raw, ok := extra["base_rpm"]
- if !ok {
- return
- }
- v := service.ParseExtraInt(raw)
- if v < 0 {
- v = 0
- } else if v > 10000 {
- v = 10000
- }
- extra["base_rpm"] = v
-}
diff --git a/backend/internal/handler/admin/account_handler_bulk_update_test.go b/backend/internal/handler/admin/account_handler_bulk_update_test.go
new file mode 100644
index 000000000..c2dfdf746
--- /dev/null
+++ b/backend/internal/handler/admin/account_handler_bulk_update_test.go
@@ -0,0 +1,62 @@
+package admin
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupAccountBulkUpdateRouter(adminSvc *stubAdminService) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
+ return router
+}
+
+func TestAccountHandlerBulkUpdate_ForwardsAutoPauseOnExpired(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountBulkUpdateRouter(adminSvc)
+
+ body, err := json.Marshal(map[string]any{
+ "account_ids": []int64{1},
+ "auto_pause_on_expired": true,
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, adminSvc.lastBulkUpdateInput)
+ require.NotNil(t, adminSvc.lastBulkUpdateInput.AutoPauseOnExpired)
+ require.True(t, *adminSvc.lastBulkUpdateInput.AutoPauseOnExpired)
+}
+
+func TestAccountHandlerBulkUpdate_RejectsEmptyUpdates(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountBulkUpdateRouter(adminSvc)
+
+ body, err := json.Marshal(map[string]any{
+ "account_ids": []int64{1},
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Contains(t, resp["message"], "No updates provided")
+}
diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
index 24ec5bcfe..ad004844d 100644
--- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go
+++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go
@@ -15,11 +15,10 @@ import (
func setupAccountMixedChannelRouter(adminSvc *stubAdminService) *gin.Engine {
gin.SetMode(gin.TestMode)
router := gin.New()
- accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/check-mixed-channel", accountHandler.CheckMixedChannel)
router.POST("/api/v1/admin/accounts", accountHandler.Create)
router.PUT("/api/v1/admin/accounts/:id", accountHandler.Update)
- router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
return router
}
@@ -146,53 +145,3 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T
require.False(t, hasDetails)
require.False(t, hasRequireConfirmation)
}
-
-func TestAccountHandlerBulkUpdateMixedChannelConflict(t *testing.T) {
- adminSvc := newStubAdminService()
- adminSvc.bulkUpdateAccountErr = &service.MixedChannelError{
- GroupID: 27,
- GroupName: "claude-max",
- CurrentPlatform: "Antigravity",
- OtherPlatform: "Anthropic",
- }
- router := setupAccountMixedChannelRouter(adminSvc)
-
- body, _ := json.Marshal(map[string]any{
- "account_ids": []int64{1, 2, 3},
- "group_ids": []int64{27},
- })
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(rec, req)
-
- require.Equal(t, http.StatusConflict, rec.Code)
- var resp map[string]any
- require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
- require.Equal(t, "mixed_channel_warning", resp["error"])
- require.Contains(t, resp["message"], "claude-max")
-}
-
-func TestAccountHandlerBulkUpdateMixedChannelConfirmSkips(t *testing.T) {
- adminSvc := newStubAdminService()
- router := setupAccountMixedChannelRouter(adminSvc)
-
- body, _ := json.Marshal(map[string]any{
- "account_ids": []int64{1, 2},
- "group_ids": []int64{27},
- "confirm_mixed_channel_risk": true,
- })
- rec := httptest.NewRecorder()
- req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
- req.Header.Set("Content-Type", "application/json")
- router.ServeHTTP(rec, req)
-
- require.Equal(t, http.StatusOK, rec.Code)
- var resp map[string]any
- require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
- require.Equal(t, float64(0), resp["code"])
- data, ok := resp["data"].(map[string]any)
- require.True(t, ok)
- require.Equal(t, float64(2), data["success"])
- require.Equal(t, float64(0), data["failed"])
-}
diff --git a/backend/internal/handler/admin/account_handler_passthrough_test.go b/backend/internal/handler/admin/account_handler_passthrough_test.go
index d86501c04..d09cccd6d 100644
--- a/backend/internal/handler/admin/account_handler_passthrough_test.go
+++ b/backend/internal/handler/admin/account_handler_passthrough_test.go
@@ -28,7 +28,6 @@ func TestAccountHandler_Create_AnthropicAPIKeyPassthroughExtraForwarded(t *testi
nil,
nil,
nil,
- nil,
)
router := gin.New()
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index f3b99ddbe..a84988617 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -31,7 +31,8 @@ type stubAdminService struct {
platform string
groupIDs []int64
}
- mu sync.Mutex
+ lastBulkUpdateInput *service.BulkUpdateAccountsInput
+ mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -239,6 +240,9 @@ func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *servic
if s.bulkUpdateAccountErr != nil {
return nil, s.bulkUpdateAccountErr
}
+ s.mu.Lock()
+ s.lastBulkUpdateInput = input
+ s.mu.Unlock()
return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go
index 8dd245a43..6b7b25158 100644
--- a/backend/internal/handler/admin/apikey_handler.go
+++ b/backend/internal/handler/admin/apikey_handler.go
@@ -35,6 +35,10 @@ func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
response.BadRequest(c, "Invalid API key ID")
return
}
+ if keyID <= 0 {
+ response.BadRequest(c, "Invalid API key ID")
+ return
+ }
var req AdminUpdateAPIKeyGroupRequest
if err := c.ShouldBindJSON(&req); err != nil {
diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go
index bf128b18a..be8ba03b5 100644
--- a/backend/internal/handler/admin/apikey_handler_test.go
+++ b/backend/internal/handler/admin/apikey_handler_test.go
@@ -36,6 +36,19 @@ func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
require.Contains(t, rec.Body.String(), "Invalid API key ID")
}
+func TestAdminAPIKeyHandler_UpdateGroup_InvalidNegativeID(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/-1", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "Invalid API key ID")
+}
+
func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
router := setupAPIKeyHandler(newStubAdminService())
diff --git a/backend/internal/handler/admin/batch_update_credentials_test.go b/backend/internal/handler/admin/batch_update_credentials_test.go
index 0b1b66917..c8185735c 100644
--- a/backend/internal/handler/admin/batch_update_credentials_test.go
+++ b/backend/internal/handler/admin/batch_update_credentials_test.go
@@ -36,7 +36,7 @@ func (f *failingAdminService) UpdateAccount(ctx context.Context, id int64, input
func setupAccountHandlerWithService(adminSvc service.AdminService) (*gin.Engine, *AccountHandler) {
gin.SetMode(gin.TestMode)
router := gin.New()
- handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ handler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
router.POST("/api/v1/admin/accounts/batch-update-credentials", handler.BatchUpdateCredentials)
return router, handler
}
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index 1d48c653c..ea2572fa4 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -333,71 +333,15 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
-// GetGroupStats handles getting group usage statistics
+// GetGroupStats handles getting group usage statistics.
// GET /api/v1/admin/dashboard/groups
-// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type
+//
+// NOTE: Group-level aggregation pipeline is not available in this branch baseline.
+// Keep this endpoint for API compatibility and return an empty dataset.
func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
startTime, endTime := parseTimeRange(c)
-
- var userID, apiKeyID, accountID, groupID int64
- var requestType *int16
- var stream *bool
- var billingType *int8
-
- if userIDStr := c.Query("user_id"); userIDStr != "" {
- if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil {
- userID = id
- }
- }
- if apiKeyIDStr := c.Query("api_key_id"); apiKeyIDStr != "" {
- if id, err := strconv.ParseInt(apiKeyIDStr, 10, 64); err == nil {
- apiKeyID = id
- }
- }
- if accountIDStr := c.Query("account_id"); accountIDStr != "" {
- if id, err := strconv.ParseInt(accountIDStr, 10, 64); err == nil {
- accountID = id
- }
- }
- if groupIDStr := c.Query("group_id"); groupIDStr != "" {
- if id, err := strconv.ParseInt(groupIDStr, 10, 64); err == nil {
- groupID = id
- }
- }
- if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" {
- parsed, err := service.ParseUsageRequestType(requestTypeStr)
- if err != nil {
- response.BadRequest(c, err.Error())
- return
- }
- value := int16(parsed)
- requestType = &value
- } else if streamStr := c.Query("stream"); streamStr != "" {
- if streamVal, err := strconv.ParseBool(streamStr); err == nil {
- stream = &streamVal
- } else {
- response.BadRequest(c, "Invalid stream value, use true or false")
- return
- }
- }
- if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" {
- if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil {
- bt := int8(v)
- billingType = &bt
- } else {
- response.BadRequest(c, "Invalid billing_type")
- return
- }
- }
-
- stats, err := h.dashboardService.GetGroupStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
- if err != nil {
- response.Error(c, 500, "Failed to get group statistics")
- return
- }
-
response.Success(c, gin.H{
- "groups": stats,
+ "groups": []any{},
"start_date": startTime.Format("2006-01-02"),
"end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
})
diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go
index 02fc766f9..2b3ceae72 100644
--- a/backend/internal/handler/admin/data_management_handler.go
+++ b/backend/internal/handler/admin/data_management_handler.go
@@ -1,7 +1,6 @@
package admin
import (
- "context"
"strconv"
"strings"
@@ -14,34 +13,13 @@ import (
)
type DataManagementHandler struct {
- dataManagementService dataManagementService
+ dataManagementService *service.DataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
-type dataManagementService interface {
- GetConfig(ctx context.Context) (service.DataManagementConfig, error)
- UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
- ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
- CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
- ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
- CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
- UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
- DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
- SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
- ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
- CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
- UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
- DeleteS3Profile(ctx context.Context, profileID string) error
- SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
- ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
- GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
- EnsureAgentEnabled(ctx context.Context) error
- GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
-}
-
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
@@ -123,12 +101,8 @@ func (h *DataManagementHandler) GetConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- cfg, err := h.dataManagementService.GetConfig(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, cfg)
+ _, err := h.dataManagementService.GetConfig(c.Request.Context())
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
@@ -141,12 +115,8 @@ func (h *DataManagementHandler) UpdateConfig(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, cfg)
+ _, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) TestS3(c *gin.Context) {
@@ -159,7 +129,7 @@ func (h *DataManagementHandler) TestS3(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
+ _, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{
Enabled: true,
Endpoint: req.Endpoint,
Region: req.Region,
@@ -170,11 +140,7 @@ func (h *DataManagementHandler) TestS3(c *gin.Context) {
ForcePathStyle: req.ForcePathStyle,
UseSSL: req.UseSSL,
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"ok": result.OK, "message": result.Message})
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
@@ -193,7 +159,7 @@ func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok {
triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10)
}
- job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
+ _, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{
BackupType: req.BackupType,
UploadToS3: req.UploadToS3,
S3ProfileID: req.S3ProfileID,
@@ -202,11 +168,7 @@ func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) {
TriggeredBy: triggeredBy,
IdempotencyKey: req.IdempotencyKey,
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status})
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
@@ -223,12 +185,8 @@ func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"items": items})
+ _, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
@@ -247,18 +205,14 @@ func (h *DataManagementHandler) CreateSourceProfile(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
+ _, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{
SourceType: sourceType,
ProfileID: req.ProfileID,
Name: req.Name,
Config: req.Config,
SetActive: req.SetActive,
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, profile)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
@@ -282,17 +236,13 @@ func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
+ _, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{
SourceType: sourceType,
ProfileID: profileID,
Name: req.Name,
Config: req.Config,
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, profile)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
@@ -310,11 +260,8 @@ func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"deleted": true})
+ err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
@@ -332,12 +279,8 @@ func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, profile)
+ _, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
@@ -345,12 +288,8 @@ func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) {
return
}
- items, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"items": items})
+ _, err := h.dataManagementService.ListS3Profiles(c.Request.Context())
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
@@ -364,7 +303,7 @@ func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
return
}
- profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
+ _, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{
ProfileID: req.ProfileID,
Name: req.Name,
SetActive: req.SetActive,
@@ -380,11 +319,7 @@ func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) {
UseSSL: req.UseSSL,
},
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, profile)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
@@ -404,7 +339,7 @@ func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
return
}
- profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
+ _, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{
ProfileID: profileID,
Name: req.Name,
S3: service.DataManagementS3Config{
@@ -419,11 +354,7 @@ func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) {
UseSSL: req.UseSSL,
},
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, profile)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
@@ -436,11 +367,8 @@ func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, gin.H{"deleted": true})
+ err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
@@ -453,12 +381,8 @@ func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) {
if !h.requireAgentEnabled(c) {
return
}
- profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, profile)
+ _, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), profileID)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
@@ -476,17 +400,13 @@ func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) {
pageSize = int32(v)
}
- result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
+ _, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{
PageSize: pageSize,
PageToken: c.Query("page_token"),
Status: c.Query("status"),
BackupType: c.Query("backup_type"),
})
- if err != nil {
- response.ErrorFrom(c, err)
- return
- }
- response.Success(c, result)
+ response.ErrorFrom(c, err)
}
func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
@@ -499,12 +419,12 @@ func (h *DataManagementHandler) GetBackupJob(c *gin.Context) {
 	if !h.requireAgentEnabled(c) {
 		return
 	}
-	job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
-	if err != nil {
-		response.ErrorFrom(c, err)
-		return
-	}
-	response.Success(c, job)
+	job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID)
+	if err != nil {
+		response.ErrorFrom(c, err)
+		return
+	}
+	response.Success(c, job)
 }
 
 func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
@@ -517,12 +433,12 @@ func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool {
 		return false
 	}
 
-	if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
-		response.ErrorFrom(c, err)
-		return false
-	}
-
-	return true
+	if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil {
+		response.ErrorFrom(c, err)
+		return false
+	}
+
+	return true
 }
 
 func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth {
diff --git a/backend/internal/handler/admin/ops_openai_ws_v2_handler.go b/backend/internal/handler/admin/ops_openai_ws_v2_handler.go
new file mode 100644
index 000000000..a77df4b0a
--- /dev/null
+++ b/backend/internal/handler/admin/ops_openai_ws_v2_handler.go
@@ -0,0 +1,28 @@
+package admin
+
+import (
+ "net/http"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
+ "github.com/gin-gonic/gin"
+)
+
+// GetOpenAIWSV2PassthroughMetrics returns OpenAI WS v2 passthrough runtime metrics.
+// GET /api/v1/admin/ops/openai-ws-v2/passthrough-metrics
+func (h *OpsHandler) GetOpenAIWSV2PassthroughMetrics(c *gin.Context) {
+ if h.opsService == nil {
+ response.Error(c, http.StatusServiceUnavailable, "Ops service not available")
+ return
+ }
+ if err := h.opsService.RequireMonitoringEnabled(c.Request.Context()); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{
+ "passthrough": openaiwsv2.SnapshotMetrics(),
+ "timestamp": time.Now().UTC(),
+ })
+}
diff --git a/backend/internal/handler/admin/ops_openai_ws_v2_handler_test.go b/backend/internal/handler/admin/ops_openai_ws_v2_handler_test.go
new file mode 100644
index 000000000..89b235466
--- /dev/null
+++ b/backend/internal/handler/admin/ops_openai_ws_v2_handler_test.go
@@ -0,0 +1,64 @@
+package admin
+
+import (
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+)
+
+func newOpsOpenAIWSV2TestRouter(handler *OpsHandler) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ r := gin.New()
+ r.GET("/metrics", handler.GetOpenAIWSV2PassthroughMetrics)
+ return r
+}
+
+func TestOpsOpenAIWSV2Handler_GetPassthroughMetrics_ServiceUnavailable(t *testing.T) {
+ r := newOpsOpenAIWSV2TestRouter(NewOpsHandler(nil))
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
+ r.ServeHTTP(w, req)
+ if w.Code != http.StatusServiceUnavailable {
+ t.Fatalf("status=%d, want %d", w.Code, http.StatusServiceUnavailable)
+ }
+}
+
+func TestOpsOpenAIWSV2Handler_GetPassthroughMetrics_Success(t *testing.T) {
+ r := newOpsOpenAIWSV2TestRouter(NewOpsHandler(newRuntimeOpsService(t)))
+
+ w := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/metrics", nil)
+ r.ServeHTTP(w, req)
+ if w.Code != http.StatusOK {
+ t.Fatalf("status=%d, want %d, body=%s", w.Code, http.StatusOK, w.Body.String())
+ }
+
+ var payload map[string]any
+ if err := json.Unmarshal(w.Body.Bytes(), &payload); err != nil {
+ t.Fatalf("unmarshal body: %v", err)
+ }
+ if code, _ := payload["code"].(float64); int(code) != 0 {
+ t.Fatalf("code=%v, want 0", payload["code"])
+ }
+ data, ok := payload["data"].(map[string]any)
+ if !ok {
+ t.Fatalf("missing data field: %v", payload)
+ }
+ passthrough, ok := data["passthrough"].(map[string]any)
+ if !ok {
+ t.Fatalf("missing passthrough field: %v", data)
+ }
+ if _, ok := passthrough["semantic_mutation_total"].(float64); !ok {
+ t.Fatalf("missing semantic_mutation_total: %v", passthrough)
+ }
+ if _, ok := passthrough["usage_parse_failure_total"].(float64); !ok {
+ t.Fatalf("missing usage_parse_failure_total: %v", passthrough)
+ }
+ if _, ok := data["timestamp"].(string); !ok {
+ t.Fatalf("missing timestamp: %v", data)
+ }
+}
diff --git a/backend/internal/handler/admin/proxy_handler.go b/backend/internal/handler/admin/proxy_handler.go
index e8ae0ce2d..9fd187fc5 100644
--- a/backend/internal/handler/admin/proxy_handler.go
+++ b/backend/internal/handler/admin/proxy_handler.go
@@ -64,9 +64,9 @@ func (h *ProxyHandler) List(c *gin.Context) {
return
}
- out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
+ out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
- out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
+ out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
}
response.Paginated(c, out, total, page, pageSize)
}
@@ -83,9 +83,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- out := make([]dto.AdminProxyWithAccountCount, 0, len(proxies))
+ out := make([]dto.ProxyWithAccountCount, 0, len(proxies))
for i := range proxies {
- out = append(out, *dto.ProxyWithAccountCountFromServiceAdmin(&proxies[i]))
+ out = append(out, *dto.ProxyWithAccountCountFromService(&proxies[i]))
}
response.Success(c, out)
return
@@ -97,9 +97,9 @@ func (h *ProxyHandler) GetAll(c *gin.Context) {
return
}
- out := make([]dto.AdminProxy, 0, len(proxies))
+ out := make([]dto.Proxy, 0, len(proxies))
for i := range proxies {
- out = append(out, *dto.ProxyFromServiceAdmin(&proxies[i]))
+ out = append(out, *dto.ProxyFromService(&proxies[i]))
}
response.Success(c, out)
}
@@ -119,7 +119,7 @@ func (h *ProxyHandler) GetByID(c *gin.Context) {
return
}
- response.Success(c, dto.ProxyFromServiceAdmin(proxy))
+ response.Success(c, dto.ProxyFromService(proxy))
}
// Create handles creating a new proxy
@@ -143,7 +143,7 @@ func (h *ProxyHandler) Create(c *gin.Context) {
if err != nil {
return nil, err
}
- return dto.ProxyFromServiceAdmin(proxy), nil
+ return dto.ProxyFromService(proxy), nil
})
}
@@ -176,7 +176,7 @@ func (h *ProxyHandler) Update(c *gin.Context) {
return
}
- response.Success(c, dto.ProxyFromServiceAdmin(proxy))
+ response.Success(c, dto.ProxyFromService(proxy))
}
// Delete handles deleting a proxy
diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go
index 0a932ee98..12706ee62 100644
--- a/backend/internal/handler/admin/redeem_handler.go
+++ b/backend/internal/handler/admin/redeem_handler.go
@@ -43,7 +43,7 @@ type GenerateRedeemCodesRequest struct {
// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
type CreateAndRedeemCodeRequest struct {
Code string `json:"code" binding:"required,min=3,max=128"`
- Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
+ Type string `json:"type" binding:"required,oneof=balance concurrency"`
Value float64 `json:"value" binding:"required,gt=0"`
UserID int64 `json:"user_id" binding:"required,gt=0"`
Notes string `json:"notes"`
@@ -136,6 +136,10 @@ func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
return
}
req.Code = strings.TrimSpace(req.Code)
+ if len(req.Code) < 3 || len(req.Code) > 128 {
+ response.BadRequest(c, "Invalid request: code length must be between 3 and 128")
+ return
+ }
executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
existing, err := h.redeemService.GetByCode(ctx, req.Code)
diff --git a/backend/internal/handler/admin/redeem_handler_create_and_redeem_test.go b/backend/internal/handler/admin/redeem_handler_create_and_redeem_test.go
new file mode 100644
index 000000000..1c06a366d
--- /dev/null
+++ b/backend/internal/handler/admin/redeem_handler_create_and_redeem_test.go
@@ -0,0 +1,46 @@
+package admin
+
+import (
+ "bytes"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupCreateAndRedeemRouter() *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ h := NewRedeemHandler(newStubAdminService(), &service.RedeemService{})
+ router.POST("/api/v1/admin/redeem-codes/create-and-redeem", h.CreateAndRedeem)
+ return router
+}
+
+func TestCreateAndRedeem_RejectsUnsupportedType(t *testing.T) {
+ router := setupCreateAndRedeemRouter()
+ body := `{"code":"ORDER-123","type":"subscription","value":100,"user_id":1}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "Invalid request")
+}
+
+func TestCreateAndRedeem_RejectsTrimmedEmptyCode(t *testing.T) {
+ router := setupCreateAndRedeemRouter()
+ body := `{"code":" ","type":"balance","value":100,"user_id":1}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/create-and-redeem", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "code length must be between 3 and 128")
+}
diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go
index e32c142f4..c7b93497b 100644
--- a/backend/internal/handler/admin/setting_handler.go
+++ b/backend/internal/handler/admin/setting_handler.go
@@ -1,13 +1,8 @@
package admin
import (
- "crypto/rand"
- "encoding/hex"
- "encoding/json"
"fmt"
"log"
- "net/http"
- "regexp"
"strings"
"time"
@@ -20,21 +15,6 @@ import (
"github.com/gin-gonic/gin"
)
-// semverPattern 预编译 semver 格式校验正则
-var semverPattern = regexp.MustCompile(`^\d+\.\d+\.\d+$`)
-
-// menuItemIDPattern validates custom menu item IDs: alphanumeric, hyphens, underscores only.
-var menuItemIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
-
-// generateMenuItemID generates a short random hex ID for a custom menu item.
-func generateMenuItemID() (string, error) {
- b := make([]byte, 8)
- if _, err := rand.Read(b); err != nil {
- return "", fmt.Errorf("generate menu item ID: %w", err)
- }
- return hex.EncodeToString(b), nil
-}
-
// SettingHandler 系统设置处理器
type SettingHandler struct {
settingService *service.SettingService
@@ -66,13 +46,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
// Check if ops monitoring is enabled (respects config.ops.enabled)
opsEnabled := h.opsService != nil && h.opsService.IsMonitoringEnabled(c.Request.Context())
- defaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(settings.DefaultSubscriptions))
- for _, sub := range settings.DefaultSubscriptions {
- defaultSubscriptions = append(defaultSubscriptions, dto.DefaultSubscriptionSetting{
- GroupID: sub.GroupID,
- ValidityDays: sub.ValidityDays,
- })
- }
response.Success(c, dto.SystemSettings{
RegistrationEnabled: settings.RegistrationEnabled,
@@ -107,10 +80,8 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
- CustomMenuItems: dto.ParseCustomMenuItems(settings.CustomMenuItems),
DefaultConcurrency: settings.DefaultConcurrency,
DefaultBalance: settings.DefaultBalance,
- DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: settings.EnableModelFallback,
FallbackModelAnthropic: settings.FallbackModelAnthropic,
FallbackModelOpenAI: settings.FallbackModelOpenAI,
@@ -122,7 +93,6 @@ func (h *SettingHandler) GetSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: settings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: settings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: settings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: settings.MinClaudeCodeVersion,
})
}
@@ -157,23 +127,21 @@ type UpdateSettingsRequest struct {
LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"`
// OEM设置
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo"`
- SiteSubtitle string `json:"site_subtitle"`
- APIBaseURL string `json:"api_base_url"`
- ContactInfo string `json:"contact_info"`
- DocURL string `json:"doc_url"`
- HomeContent string `json:"home_content"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
- PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
- PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
- CustomMenuItems *[]dto.CustomMenuItem `json:"custom_menu_items"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo"`
+ SiteSubtitle string `json:"site_subtitle"`
+ APIBaseURL string `json:"api_base_url"`
+ ContactInfo string `json:"contact_info"`
+ DocURL string `json:"doc_url"`
+ HomeContent string `json:"home_content"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"`
+ PurchaseSubscriptionURL *string `json:"purchase_subscription_url"`
+ SoraClientEnabled bool `json:"sora_client_enabled"`
// 默认配置
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -191,8 +159,6 @@ type UpdateSettingsRequest struct {
OpsRealtimeMonitoringEnabled *bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault *string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds *int `json:"ops_metrics_interval_seconds"`
-
- MinClaudeCodeVersion string `json:"min_claude_code_version"`
}
// UpdateSettings 更新系统设置
@@ -220,7 +186,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
if req.SMTPPort <= 0 {
req.SMTPPort = 587
}
- req.DefaultSubscriptions = normalizeDefaultSubscriptions(req.DefaultSubscriptions)
// Turnstile 参数验证
if req.TurnstileEnabled {
@@ -316,84 +281,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
}
- // 自定义菜单项验证
- const (
- maxCustomMenuItems = 20
- maxMenuItemLabelLen = 50
- maxMenuItemURLLen = 2048
- maxMenuItemIconSVGLen = 10 * 1024 // 10KB
- maxMenuItemIDLen = 32
- )
-
- customMenuJSON := previousSettings.CustomMenuItems
- if req.CustomMenuItems != nil {
- items := *req.CustomMenuItems
- if len(items) > maxCustomMenuItems {
- response.BadRequest(c, "Too many custom menu items (max 20)")
- return
- }
- for i, item := range items {
- if strings.TrimSpace(item.Label) == "" {
- response.BadRequest(c, "Custom menu item label is required")
- return
- }
- if len(item.Label) > maxMenuItemLabelLen {
- response.BadRequest(c, "Custom menu item label is too long (max 50 characters)")
- return
- }
- if strings.TrimSpace(item.URL) == "" {
- response.BadRequest(c, "Custom menu item URL is required")
- return
- }
- if len(item.URL) > maxMenuItemURLLen {
- response.BadRequest(c, "Custom menu item URL is too long (max 2048 characters)")
- return
- }
- if err := config.ValidateAbsoluteHTTPURL(strings.TrimSpace(item.URL)); err != nil {
- response.BadRequest(c, "Custom menu item URL must be an absolute http(s) URL")
- return
- }
- if item.Visibility != "user" && item.Visibility != "admin" {
- response.BadRequest(c, "Custom menu item visibility must be 'user' or 'admin'")
- return
- }
- if len(item.IconSVG) > maxMenuItemIconSVGLen {
- response.BadRequest(c, "Custom menu item icon SVG is too large (max 10KB)")
- return
- }
- // Auto-generate ID if missing
- if strings.TrimSpace(item.ID) == "" {
- id, err := generateMenuItemID()
- if err != nil {
- response.Error(c, http.StatusInternalServerError, "Failed to generate menu item ID")
- return
- }
- items[i].ID = id
- } else if len(item.ID) > maxMenuItemIDLen {
- response.BadRequest(c, "Custom menu item ID is too long (max 32 characters)")
- return
- } else if !menuItemIDPattern.MatchString(item.ID) {
- response.BadRequest(c, "Custom menu item ID contains invalid characters (only a-z, A-Z, 0-9, - and _ are allowed)")
- return
- }
- }
- // ID uniqueness check
- seen := make(map[string]struct{}, len(items))
- for _, item := range items {
- if _, exists := seen[item.ID]; exists {
- response.BadRequest(c, "Duplicate custom menu item ID: "+item.ID)
- return
- }
- seen[item.ID] = struct{}{}
- }
- menuBytes, err := json.Marshal(items)
- if err != nil {
- response.BadRequest(c, "Failed to serialize custom menu items")
- return
- }
- customMenuJSON = string(menuBytes)
- }
-
// Ops metrics collector interval validation (seconds).
if req.OpsMetricsIntervalSeconds != nil {
v := *req.OpsMetricsIntervalSeconds
@@ -405,21 +292,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
}
req.OpsMetricsIntervalSeconds = &v
}
- defaultSubscriptions := make([]service.DefaultSubscriptionSetting, 0, len(req.DefaultSubscriptions))
- for _, sub := range req.DefaultSubscriptions {
- defaultSubscriptions = append(defaultSubscriptions, service.DefaultSubscriptionSetting{
- GroupID: sub.GroupID,
- ValidityDays: sub.ValidityDays,
- })
- }
-
- // 验证最低版本号格式(空字符串=禁用,或合法 semver)
- if req.MinClaudeCodeVersion != "" {
- if !semverPattern.MatchString(req.MinClaudeCodeVersion) {
- response.Error(c, http.StatusBadRequest, "min_claude_code_version must be empty or a valid semver (e.g. 2.1.63)")
- return
- }
- }
settings := &service.SystemSettings{
RegistrationEnabled: req.RegistrationEnabled,
@@ -453,10 +325,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: purchaseEnabled,
PurchaseSubscriptionURL: purchaseURL,
SoraClientEnabled: req.SoraClientEnabled,
- CustomMenuItems: customMenuJSON,
DefaultConcurrency: req.DefaultConcurrency,
DefaultBalance: req.DefaultBalance,
- DefaultSubscriptions: defaultSubscriptions,
EnableModelFallback: req.EnableModelFallback,
FallbackModelAnthropic: req.FallbackModelAnthropic,
FallbackModelOpenAI: req.FallbackModelOpenAI,
@@ -464,7 +334,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
FallbackModelAntigravity: req.FallbackModelAntigravity,
EnableIdentityPatch: req.EnableIdentityPatch,
IdentityPatchPrompt: req.IdentityPatchPrompt,
- MinClaudeCodeVersion: req.MinClaudeCodeVersion,
OpsMonitoringEnabled: func() bool {
if req.OpsMonitoringEnabled != nil {
return *req.OpsMonitoringEnabled
@@ -504,13 +373,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
response.ErrorFrom(c, err)
return
}
- updatedDefaultSubscriptions := make([]dto.DefaultSubscriptionSetting, 0, len(updatedSettings.DefaultSubscriptions))
- for _, sub := range updatedSettings.DefaultSubscriptions {
- updatedDefaultSubscriptions = append(updatedDefaultSubscriptions, dto.DefaultSubscriptionSetting{
- GroupID: sub.GroupID,
- ValidityDays: sub.ValidityDays,
- })
- }
response.Success(c, dto.SystemSettings{
RegistrationEnabled: updatedSettings.RegistrationEnabled,
@@ -545,10 +407,8 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL,
SoraClientEnabled: updatedSettings.SoraClientEnabled,
- CustomMenuItems: dto.ParseCustomMenuItems(updatedSettings.CustomMenuItems),
DefaultConcurrency: updatedSettings.DefaultConcurrency,
DefaultBalance: updatedSettings.DefaultBalance,
- DefaultSubscriptions: updatedDefaultSubscriptions,
EnableModelFallback: updatedSettings.EnableModelFallback,
FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic,
FallbackModelOpenAI: updatedSettings.FallbackModelOpenAI,
@@ -560,7 +420,6 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) {
OpsRealtimeMonitoringEnabled: updatedSettings.OpsRealtimeMonitoringEnabled,
OpsQueryModeDefault: updatedSettings.OpsQueryModeDefault,
OpsMetricsIntervalSeconds: updatedSettings.OpsMetricsIntervalSeconds,
- MinClaudeCodeVersion: updatedSettings.MinClaudeCodeVersion,
})
}
@@ -670,9 +529,6 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.DefaultBalance != after.DefaultBalance {
changed = append(changed, "default_balance")
}
- if !equalDefaultSubscriptions(before.DefaultSubscriptions, after.DefaultSubscriptions) {
- changed = append(changed, "default_subscriptions")
- }
if before.EnableModelFallback != after.EnableModelFallback {
changed = append(changed, "enable_model_fallback")
}
@@ -706,50 +562,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
if before.OpsMetricsIntervalSeconds != after.OpsMetricsIntervalSeconds {
changed = append(changed, "ops_metrics_interval_seconds")
}
- if before.MinClaudeCodeVersion != after.MinClaudeCodeVersion {
- changed = append(changed, "min_claude_code_version")
- }
- if before.PurchaseSubscriptionEnabled != after.PurchaseSubscriptionEnabled {
- changed = append(changed, "purchase_subscription_enabled")
- }
- if before.PurchaseSubscriptionURL != after.PurchaseSubscriptionURL {
- changed = append(changed, "purchase_subscription_url")
- }
- if before.CustomMenuItems != after.CustomMenuItems {
- changed = append(changed, "custom_menu_items")
- }
return changed
}
-func normalizeDefaultSubscriptions(input []dto.DefaultSubscriptionSetting) []dto.DefaultSubscriptionSetting {
- if len(input) == 0 {
- return nil
- }
- normalized := make([]dto.DefaultSubscriptionSetting, 0, len(input))
- for _, item := range input {
- if item.GroupID <= 0 || item.ValidityDays <= 0 {
- continue
- }
- if item.ValidityDays > service.MaxValidityDays {
- item.ValidityDays = service.MaxValidityDays
- }
- normalized = append(normalized, item)
- }
- return normalized
-}
-
-func equalDefaultSubscriptions(a, b []service.DefaultSubscriptionSetting) bool {
- if len(a) != len(b) {
- return false
- }
- for i := range a {
- if a[i].GroupID != b[i].GroupID || a[i].ValidityDays != b[i].ValidityDays {
- return false
- }
- }
- return true
-}
-
// TestSMTPRequest 测试SMTP连接请求
type TestSMTPRequest struct {
SMTPHost string `json:"smtp_host" binding:"required"`
diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go
new file mode 100644
index 000000000..c63ed9af2
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go
@@ -0,0 +1,228 @@
+package admin
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+)
+
+type UpsertBulkEditTemplateRequest struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ ScopePlatform string `json:"scope_platform"`
+ ScopeType string `json:"scope_type"`
+ ShareScope string `json:"share_scope"`
+ GroupIDs []int64 `json:"group_ids"`
+ State map[string]any `json:"state"`
+}
+
+type RollbackBulkEditTemplateRequest struct {
+ VersionID string `json:"version_id"`
+}
+
+// ListBulkEditTemplates 获取批量编辑模板列表
+// GET /api/v1/admin/settings/bulk-edit-templates
+func (h *SettingHandler) ListBulkEditTemplates(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids"))
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ items, listErr := h.settingService.ListBulkEditTemplates(c.Request.Context(), service.BulkEditTemplateQuery{
+ ScopePlatform: c.Query("scope_platform"),
+ ScopeType: c.Query("scope_type"),
+ ScopeGroupIDs: scopeGroupIDs,
+ RequesterUserID: subject.UserID,
+ })
+ if listErr != nil {
+ response.ErrorFrom(c, listErr)
+ return
+ }
+
+ response.Success(c, gin.H{"items": items})
+}
+
+// UpsertBulkEditTemplate 创建/更新批量编辑模板
+// POST /api/v1/admin/settings/bulk-edit-templates
+func (h *SettingHandler) UpsertBulkEditTemplate(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ var req UpsertBulkEditTemplateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ item, upsertErr := h.settingService.UpsertBulkEditTemplate(
+ c.Request.Context(),
+ service.BulkEditTemplateUpsertInput{
+ ID: req.ID,
+ Name: req.Name,
+ ScopePlatform: req.ScopePlatform,
+ ScopeType: req.ScopeType,
+ ShareScope: req.ShareScope,
+ GroupIDs: req.GroupIDs,
+ State: req.State,
+ RequesterUserID: subject.UserID,
+ },
+ )
+ if upsertErr != nil {
+ response.ErrorFrom(c, upsertErr)
+ return
+ }
+
+ response.Success(c, item)
+}
+
+// DeleteBulkEditTemplate 删除批量编辑模板
+// DELETE /api/v1/admin/settings/bulk-edit-templates/:template_id
+func (h *SettingHandler) DeleteBulkEditTemplate(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ templateID := strings.TrimSpace(c.Param("template_id"))
+ if templateID == "" {
+ response.BadRequest(c, "template_id is required")
+ return
+ }
+
+ if err := h.settingService.DeleteBulkEditTemplate(c.Request.Context(), templateID, subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"deleted": true})
+}
+
+// ListBulkEditTemplateVersions 获取模板版本历史
+// GET /api/v1/admin/settings/bulk-edit-templates/:template_id/versions
+func (h *SettingHandler) ListBulkEditTemplateVersions(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ templateID := strings.TrimSpace(c.Param("template_id"))
+ if templateID == "" {
+ response.BadRequest(c, "template_id is required")
+ return
+ }
+
+ scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids"))
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ items, listErr := h.settingService.ListBulkEditTemplateVersions(
+ c.Request.Context(),
+ service.BulkEditTemplateVersionQuery{
+ TemplateID: templateID,
+ ScopeGroupIDs: scopeGroupIDs,
+ RequesterUserID: subject.UserID,
+ },
+ )
+ if listErr != nil {
+ response.ErrorFrom(c, listErr)
+ return
+ }
+
+ response.Success(c, gin.H{"items": items})
+}
+
+// RollbackBulkEditTemplate 回滚模板到指定版本
+// POST /api/v1/admin/settings/bulk-edit-templates/:template_id/rollback
+func (h *SettingHandler) RollbackBulkEditTemplate(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ templateID := strings.TrimSpace(c.Param("template_id"))
+ if templateID == "" {
+ response.BadRequest(c, "template_id is required")
+ return
+ }
+
+ scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids"))
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ var req RollbackBulkEditTemplateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ item, rollbackErr := h.settingService.RollbackBulkEditTemplate(
+ c.Request.Context(),
+ service.BulkEditTemplateRollbackInput{
+ TemplateID: templateID,
+ VersionID: req.VersionID,
+ ScopeGroupIDs: scopeGroupIDs,
+ RequesterUserID: subject.UserID,
+ },
+ )
+ if rollbackErr != nil {
+ response.ErrorFrom(c, rollbackErr)
+ return
+ }
+
+ response.Success(c, item)
+}
+
+func parseScopeGroupIDs(raw string) ([]int64, error) {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return nil, nil
+ }
+
+ parts := strings.Split(trimmed, ",")
+ if len(parts) == 0 {
+ return nil, nil
+ }
+
+ seen := make(map[int64]struct{}, len(parts))
+ groupIDs := make([]int64, 0, len(parts))
+ for _, part := range parts {
+ candidate := strings.TrimSpace(part)
+ if candidate == "" {
+ continue
+ }
+
+ groupID, err := strconv.ParseInt(candidate, 10, 64)
+ if err != nil || groupID <= 0 {
+ return nil, fmt.Errorf("scope_group_ids must be comma-separated positive integers")
+ }
+ if _, exists := seen[groupID]; exists {
+ continue
+ }
+ seen[groupID] = struct{}{}
+ groupIDs = append(groupIDs, groupID)
+ }
+
+ return groupIDs, nil
+}
diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go
new file mode 100644
index 000000000..c72d412bf
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go
@@ -0,0 +1,603 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerTemplateRepoStub struct {
+ values map[string]string
+}
+
+func newSettingHandlerTemplateRepoStub() *settingHandlerTemplateRepoStub {
+ return &settingHandlerTemplateRepoStub{values: map[string]string{}}
+}
+
+func (s *settingHandlerTemplateRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &service.Setting{Key: key, Value: value}, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) Set(ctx context.Context, key, value string) error {
+ s.values[key] = value
+ return nil
+}
+
+func (s *settingHandlerTemplateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ for key, value := range settings {
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerTemplateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) Delete(ctx context.Context, key string) error {
+ delete(s.values, key)
+ return nil
+}
+
+type failingSettingRepoStub struct{}
+
+func (s *failingSettingRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ return nil, errors.New("boom")
+}
+func (s *failingSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ return "", errors.New("boom")
+}
+func (s *failingSettingRepoStub) Set(ctx context.Context, key, value string) error {
+ return errors.New("boom")
+}
+func (s *failingSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ return nil, errors.New("boom")
+}
+func (s *failingSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ return errors.New("boom")
+}
+func (s *failingSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ return nil, errors.New("boom")
+}
+func (s *failingSettingRepoStub) Delete(ctx context.Context, key string) error {
+ return errors.New("boom")
+}
+
+func setupBulkEditTemplateRouter() *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ repo := newSettingHandlerTemplateRepoStub()
+ settingService := service.NewSettingService(repo, nil)
+ handler := NewSettingHandler(settingService, nil, nil, nil, nil)
+
+ router := gin.New()
+ router.Use(func(c *gin.Context) {
+ uid := int64(1)
+ if header := c.GetHeader("X-User-ID"); header != "" {
+ if parsed, err := strconv.ParseInt(header, 10, 64); err == nil && parsed > 0 {
+ uid = parsed
+ }
+ }
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: uid})
+ c.Next()
+ })
+
+ router.GET("/api/v1/admin/settings/bulk-edit-templates", handler.ListBulkEditTemplates)
+ router.POST("/api/v1/admin/settings/bulk-edit-templates", handler.UpsertBulkEditTemplate)
+ router.DELETE("/api/v1/admin/settings/bulk-edit-templates/:template_id", handler.DeleteBulkEditTemplate)
+ router.GET("/api/v1/admin/settings/bulk-edit-templates/:template_id/versions", handler.ListBulkEditTemplateVersions)
+ router.POST("/api/v1/admin/settings/bulk-edit-templates/:template_id/rollback", handler.RollbackBulkEditTemplate)
+
+ return router
+}
+
+func decodeResponseDataMap(t *testing.T, body []byte) map[string]any {
+ t.Helper()
+ var payload response.Response
+ require.NoError(t, json.Unmarshal(body, &payload))
+ if payload.Data == nil {
+ return map[string]any{}
+ }
+ asMap, ok := payload.Data.(map[string]any)
+ require.True(t, ok)
+ return asMap
+}
+
// TestSettingHandlerBulkEditTemplate_CRUDFlow covers the full create -> list
// -> delete lifecycle of a bulk-edit template through the HTTP layer.
func TestSettingHandlerBulkEditTemplate_CRUDFlow(t *testing.T) {
	router := setupBulkEditTemplateRouter()

	// Create a team-shared template for the openai/oauth scope.
	createBody := map[string]any{
		"name": "OpenAI OAuth Baseline",
		"scope_platform": "openai",
		"scope_type": "oauth",
		"share_scope": "team",
		"state": map[string]any{
			"enableOpenAIPassthrough": true,
		},
	}
	raw, err := json.Marshal(createBody)
	require.NoError(t, err)

	createRec := httptest.NewRecorder()
	createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
	createReq.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(createRec, createReq)
	require.Equal(t, http.StatusOK, createRec.Code)

	// The created template must come back with a non-empty string ID.
	createData := decodeResponseDataMap(t, createRec.Body.Bytes())
	templateID, ok := createData["id"].(string)
	require.True(t, ok)
	require.NotEmpty(t, templateID)

	// Listing with the matching scope filter returns exactly the new item.
	listRec := httptest.NewRecorder()
	listReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth",
		nil,
	)
	router.ServeHTTP(listRec, listReq)
	require.Equal(t, http.StatusOK, listRec.Code)

	listData := decodeResponseDataMap(t, listRec.Body.Bytes())
	items, ok := listData["items"].([]any)
	require.True(t, ok)
	require.Len(t, items, 1)

	// Delete the template by ID ...
	deleteRec := httptest.NewRecorder()
	deleteReq := httptest.NewRequest(
		http.MethodDelete,
		"/api/v1/admin/settings/bulk-edit-templates/"+templateID,
		nil,
	)
	router.ServeHTTP(deleteRec, deleteReq)
	require.Equal(t, http.StatusOK, deleteRec.Code)

	// ... after which the same listing must be empty again.
	listAfterDeleteRec := httptest.NewRecorder()
	listAfterDeleteReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth",
		nil,
	)
	router.ServeHTTP(listAfterDeleteRec, listAfterDeleteReq)
	require.Equal(t, http.StatusOK, listAfterDeleteRec.Code)

	listAfterDeleteData := decodeResponseDataMap(t, listAfterDeleteRec.Body.Bytes())
	itemsAfterDelete, ok := listAfterDeleteData["items"].([]any)
	require.True(t, ok)
	require.Len(t, itemsAfterDelete, 0)
}
+
// TestSettingHandlerBulkEditTemplate_VersionsAndRollback verifies that
// updating a template records a prior version, that rolling back restores the
// original share scope / group IDs / state, and that the rollback itself adds
// another version entry.
func TestSettingHandlerBulkEditTemplate_VersionsAndRollback(t *testing.T) {
	router := setupBulkEditTemplateRouter()

	// Create the initial, groups-shared version (this is the rollback target).
	createBody := map[string]any{
		"name": "Rollback Target",
		"scope_platform": "openai",
		"scope_type": "oauth",
		"share_scope": "groups",
		"group_ids": []int64{2},
		"state": map[string]any{
			"enableOpenAIWSMode": true,
		},
	}
	createRaw, err := json.Marshal(createBody)
	require.NoError(t, err)

	createRec := httptest.NewRecorder()
	createReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates",
		bytes.NewReader(createRaw),
	)
	createReq.Header.Set("Content-Type", "application/json")
	createReq.Header.Set("X-User-ID", "9")
	router.ServeHTTP(createRec, createReq)
	require.Equal(t, http.StatusOK, createRec.Code)
	createData := decodeResponseDataMap(t, createRec.Body.Bytes())
	templateID, ok := createData["id"].(string)
	require.True(t, ok)

	// Overwrite the template in place (same ID), flipping scope and state.
	updateBody := map[string]any{
		"id": templateID,
		"name": "Rollback Target",
		"scope_platform": "openai",
		"scope_type": "oauth",
		"share_scope": "team",
		"group_ids": []int64{},
		"state": map[string]any{
			"enableOpenAIWSMode": false,
		},
	}
	updateRaw, err := json.Marshal(updateBody)
	require.NoError(t, err)

	updateRec := httptest.NewRecorder()
	updateReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates",
		bytes.NewReader(updateRaw),
	)
	updateReq.Header.Set("Content-Type", "application/json")
	updateReq.Header.Set("X-User-ID", "9")
	router.ServeHTTP(updateRec, updateReq)
	require.Equal(t, http.StatusOK, updateRec.Code)

	// The update must have produced exactly one archived version.
	versionsRec := httptest.NewRecorder()
	versionsReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2",
		nil,
	)
	versionsReq.Header.Set("X-User-ID", "9")
	router.ServeHTTP(versionsRec, versionsReq)
	require.Equal(t, http.StatusOK, versionsRec.Code)
	versionsData := decodeResponseDataMap(t, versionsRec.Body.Bytes())
	versions, ok := versionsData["items"].([]any)
	require.True(t, ok)
	require.Len(t, versions, 1)
	versionData, ok := versions[0].(map[string]any)
	require.True(t, ok)
	versionID, ok := versionData["version_id"].(string)
	require.True(t, ok)

	// Roll back to the archived version.
	rollbackBody := map[string]any{"version_id": versionID}
	rollbackRaw, err := json.Marshal(rollbackBody)
	require.NoError(t, err)

	rollbackRec := httptest.NewRecorder()
	rollbackReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/rollback?scope_group_ids=2",
		bytes.NewReader(rollbackRaw),
	)
	rollbackReq.Header.Set("Content-Type", "application/json")
	rollbackReq.Header.Set("X-User-ID", "9")
	router.ServeHTTP(rollbackRec, rollbackReq)
	require.Equal(t, http.StatusOK, rollbackRec.Code)
	rollbackData := decodeResponseDataMap(t, rollbackRec.Body.Bytes())
	// The rollback response reflects the original version's contents.
	// NOTE: JSON numbers decode as float64, hence the float comparison.
	require.Equal(t, "groups", rollbackData["share_scope"])
	groupIDs, ok := rollbackData["group_ids"].([]any)
	require.True(t, ok)
	require.Equal(t, []any{float64(2)}, groupIDs)
	state, ok := rollbackData["state"].(map[string]any)
	require.True(t, ok)
	require.Equal(t, true, state["enableOpenAIWSMode"])

	// The rollback itself archives the pre-rollback state: two versions now.
	versionsAfterRollbackRec := httptest.NewRecorder()
	versionsAfterRollbackReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2",
		nil,
	)
	versionsAfterRollbackReq.Header.Set("X-User-ID", "9")
	router.ServeHTTP(versionsAfterRollbackRec, versionsAfterRollbackReq)
	require.Equal(t, http.StatusOK, versionsAfterRollbackRec.Code)
	versionsAfterRollbackData := decodeResponseDataMap(t, versionsAfterRollbackRec.Body.Bytes())
	versionsAfterRollback, ok := versionsAfterRollbackData["items"].([]any)
	require.True(t, ok)
	require.Len(t, versionsAfterRollback, 2)
}
+
// TestSettingHandlerBulkEditTemplate_Validation checks request validation:
// a groups-shared template with no group IDs, and a non-numeric
// scope_group_ids query value, must both be rejected with 400.
func TestSettingHandlerBulkEditTemplate_Validation(t *testing.T) {
	router := setupBulkEditTemplateRouter()

	// share_scope "groups" requires a non-empty group_ids list.
	invalidCreateBody := map[string]any{
		"name": "Groups Template",
		"scope_platform": "openai",
		"scope_type": "oauth",
		"share_scope": "groups",
		"group_ids": []int64{},
		"state": map[string]any{},
	}
	raw, err := json.Marshal(invalidCreateBody)
	require.NoError(t, err)

	invalidCreateRec := httptest.NewRecorder()
	invalidCreateReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
	invalidCreateReq.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(invalidCreateRec, invalidCreateReq)
	require.Equal(t, http.StatusBadRequest, invalidCreateRec.Code)

	// scope_group_ids must be comma-separated positive integers.
	invalidListRec := httptest.NewRecorder()
	invalidListReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates?scope_group_ids=abc",
		nil,
	)
	router.ServeHTTP(invalidListRec, invalidListReq)
	require.Equal(t, http.StatusBadRequest, invalidListRec.Code)
}
+
// TestSettingHandlerBulkEditTemplate_PrivateVisibilityAndDeletePermission
// verifies that a private template is invisible to other users, that another
// user cannot delete it (403), and that the owner can (200).
func TestSettingHandlerBulkEditTemplate_PrivateVisibilityAndDeletePermission(t *testing.T) {
	router := setupBulkEditTemplateRouter()

	// User 100 creates a private template.
	createBody := map[string]any{
		"name": "Private Template",
		"scope_platform": "openai",
		"scope_type": "oauth",
		"share_scope": "private",
		"state": map[string]any{
			"enableBaseUrl": true,
		},
	}
	raw, err := json.Marshal(createBody)
	require.NoError(t, err)

	createRec := httptest.NewRecorder()
	createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
	createReq.Header.Set("Content-Type", "application/json")
	createReq.Header.Set("X-User-ID", "100")
	router.ServeHTTP(createRec, createReq)
	require.Equal(t, http.StatusOK, createRec.Code)

	createData := decodeResponseDataMap(t, createRec.Body.Bytes())
	templateID, ok := createData["id"].(string)
	require.True(t, ok)

	// User 200 must not see user 100's private template.
	listByOtherRec := httptest.NewRecorder()
	listByOtherReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth",
		nil,
	)
	listByOtherReq.Header.Set("X-User-ID", "200")
	router.ServeHTTP(listByOtherRec, listByOtherReq)
	require.Equal(t, http.StatusOK, listByOtherRec.Code)

	listByOtherData := decodeResponseDataMap(t, listByOtherRec.Body.Bytes())
	items, ok := listByOtherData["items"].([]any)
	require.True(t, ok)
	require.Len(t, items, 0)

	// User 200 must not be able to delete it either.
	deleteByOtherRec := httptest.NewRecorder()
	deleteByOtherReq := httptest.NewRequest(
		http.MethodDelete,
		"/api/v1/admin/settings/bulk-edit-templates/"+templateID,
		nil,
	)
	deleteByOtherReq.Header.Set("X-User-ID", "200")
	router.ServeHTTP(deleteByOtherRec, deleteByOtherReq)
	require.Equal(t, http.StatusForbidden, deleteByOtherRec.Code)

	// The owner (user 100) can delete the template.
	deleteByOwnerRec := httptest.NewRecorder()
	deleteByOwnerReq := httptest.NewRequest(
		http.MethodDelete,
		"/api/v1/admin/settings/bulk-edit-templates/"+templateID,
		nil,
	)
	deleteByOwnerReq.Header.Set("X-User-ID", "100")
	router.ServeHTTP(deleteByOwnerRec, deleteByOwnerReq)
	require.Equal(t, http.StatusOK, deleteByOwnerRec.Code)
}
+
// TestSettingHandlerBulkEditTemplate_GroupsVisibilityByScopeGroupIDs verifies
// that a groups-shared template is only listed for other users when their
// scope_group_ids query intersects the template's group_ids.
func TestSettingHandlerBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) {
	router := setupBulkEditTemplateRouter()

	// User 1 shares a template with groups 3 and 8.
	createBody := map[string]any{
		"name": "Group Shared",
		"scope_platform": "openai",
		"scope_type": "oauth",
		"share_scope": "groups",
		"group_ids": []int64{3, 8},
		"state": map[string]any{"enableOpenAIWSMode": true},
	}
	raw, err := json.Marshal(createBody)
	require.NoError(t, err)

	createRec := httptest.NewRecorder()
	createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
	createReq.Header.Set("Content-Type", "application/json")
	createReq.Header.Set("X-User-ID", "1")
	router.ServeHTTP(createRec, createReq)
	require.Equal(t, http.StatusOK, createRec.Code)

	// Group 9 is not in {3, 8}: user 2 sees nothing.
	invisibleRec := httptest.NewRecorder()
	invisibleReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=9",
		nil,
	)
	invisibleReq.Header.Set("X-User-ID", "2")
	router.ServeHTTP(invisibleRec, invisibleReq)
	require.Equal(t, http.StatusOK, invisibleRec.Code)
	invisibleData := decodeResponseDataMap(t, invisibleRec.Body.Bytes())
	invisibleItems, ok := invisibleData["items"].([]any)
	require.True(t, ok)
	require.Len(t, invisibleItems, 0)

	// Group 8 matches: user 2 sees the shared template.
	visibleRec := httptest.NewRecorder()
	visibleReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=8",
		nil,
	)
	visibleReq.Header.Set("X-User-ID", "2")
	router.ServeHTTP(visibleRec, visibleReq)
	require.Equal(t, http.StatusOK, visibleRec.Code)
	visibleData := decodeResponseDataMap(t, visibleRec.Body.Bytes())
	visibleItems, ok := visibleData["items"].([]any)
	require.True(t, ok)
	require.Len(t, visibleItems, 1)
}
+
// TestSettingHandlerBulkEditTemplate_UnauthorizedAndInvalidRequests mounts
// the handlers WITHOUT the auth middleware and asserts every endpoint
// rejects the request with 401 before any other validation runs.
func TestSettingHandlerBulkEditTemplate_UnauthorizedAndInvalidRequests(t *testing.T) {
	gin.SetMode(gin.TestMode)
	repo := newSettingHandlerTemplateRepoStub()
	settingService := service.NewSettingService(repo, nil)
	handler := NewSettingHandler(settingService, nil, nil, nil, nil)

	// Note: no auth middleware is registered on purpose.
	router := gin.New()
	router.GET("/list", handler.ListBulkEditTemplates)
	router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions)
	router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate)
	router.POST("/upsert", handler.UpsertBulkEditTemplate)
	router.DELETE("/delete/:template_id", handler.DeleteBulkEditTemplate)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/list", nil)
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusUnauthorized, rec.Code)

	// Auth is checked before body parsing: malformed JSON still yields 401.
	rec = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodPost, "/upsert", bytes.NewBufferString("{bad-json"))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusUnauthorized, rec.Code)

	// Auth is checked before template_id validation too.
	rec = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodDelete, "/delete/%20", nil)
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusUnauthorized, rec.Code)

	rec = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodGet, "/versions/abc", nil)
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusUnauthorized, rec.Code)

	rec = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodPost, "/rollback/abc", bytes.NewBufferString(`{"version_id":"v1"}`))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusUnauthorized, rec.Code)
}
+
+func TestParseScopeGroupIDs(t *testing.T) {
+ ids, err := parseScopeGroupIDs("")
+ require.NoError(t, err)
+ require.Nil(t, ids)
+
+ ids, err = parseScopeGroupIDs("1, 2,2,3")
+ require.NoError(t, err)
+ require.Equal(t, []int64{1, 2, 3}, ids)
+
+ _, err = parseScopeGroupIDs("x,2")
+ require.Error(t, err)
+}
+
// TestSettingHandlerBulkEditTemplate_BindErrorAndMissingTemplateID asserts
// the 400 paths of the authenticated routes: malformed JSON bodies,
// whitespace-only template IDs, and invalid scope_group_ids query values.
func TestSettingHandlerBulkEditTemplate_BindErrorAndMissingTemplateID(t *testing.T) {
	router := setupBulkEditTemplateRouter()

	// Upsert with malformed JSON -> 400.
	bindErrRec := httptest.NewRecorder()
	bindErrReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates",
		bytes.NewBufferString("{bad-json"),
	)
	bindErrReq.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(bindErrRec, bindErrReq)
	require.Equal(t, http.StatusBadRequest, bindErrRec.Code)

	// Delete with a whitespace-only (%20) template_id -> 400.
	missingIDRec := httptest.NewRecorder()
	missingIDReq := httptest.NewRequest(
		http.MethodDelete,
		"/api/v1/admin/settings/bulk-edit-templates/%20",
		nil,
	)
	router.ServeHTTP(missingIDRec, missingIDReq)
	require.Equal(t, http.StatusBadRequest, missingIDRec.Code)

	// Versions listing with a non-numeric scope_group_ids -> 400.
	invalidScopeRec := httptest.NewRecorder()
	invalidScopeReq := httptest.NewRequest(
		http.MethodGet,
		"/api/v1/admin/settings/bulk-edit-templates/abc/versions?scope_group_ids=bad",
		nil,
	)
	router.ServeHTTP(invalidScopeRec, invalidScopeReq)
	require.Equal(t, http.StatusBadRequest, invalidScopeRec.Code)

	// Rollback with a whitespace-only template_id -> 400.
	rollbackMissingIDRec := httptest.NewRecorder()
	rollbackMissingIDReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates/%20/rollback",
		bytes.NewBufferString(`{"version_id":"v1"}`),
	)
	rollbackMissingIDReq.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rollbackMissingIDRec, rollbackMissingIDReq)
	require.Equal(t, http.StatusBadRequest, rollbackMissingIDRec.Code)

	// Rollback with an invalid scope_group_ids query -> 400.
	rollbackInvalidScopeRec := httptest.NewRecorder()
	rollbackInvalidScopeReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates/abc/rollback?scope_group_ids=bad",
		bytes.NewBufferString(`{"version_id":"v1"}`),
	)
	rollbackInvalidScopeReq.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rollbackInvalidScopeRec, rollbackInvalidScopeReq)
	require.Equal(t, http.StatusBadRequest, rollbackInvalidScopeRec.Code)

	// Rollback with malformed JSON -> 400.
	rollbackBindErrRec := httptest.NewRecorder()
	rollbackBindErrReq := httptest.NewRequest(
		http.MethodPost,
		"/api/v1/admin/settings/bulk-edit-templates/abc/rollback",
		bytes.NewBufferString("{bad-json"),
	)
	rollbackBindErrReq.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rollbackBindErrRec, rollbackBindErrReq)
	require.Equal(t, http.StatusBadRequest, rollbackBindErrRec.Code)
}
+
// TestSettingHandlerBulkEditTemplate_ListErrorFromService backs the service
// with an always-failing repository and asserts that repository errors
// surface as 500 on the list, versions, and rollback endpoints.
func TestSettingHandlerBulkEditTemplate_ListErrorFromService(t *testing.T) {
	gin.SetMode(gin.TestMode)
	settingService := service.NewSettingService(&failingSettingRepoStub{}, nil)
	handler := NewSettingHandler(settingService, nil, nil, nil, nil)
	router := gin.New()
	// Minimal auth middleware so requests reach the service layer.
	router.Use(func(c *gin.Context) {
		c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 1})
		c.Next()
	})
	router.GET("/list", handler.ListBulkEditTemplates)
	router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions)
	router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate)

	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/list", nil)
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusInternalServerError, rec.Code)

	rec = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodGet, "/versions/tpl-1", nil)
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusInternalServerError, rec.Code)

	rec = httptest.NewRecorder()
	req = httptest.NewRequest(http.MethodPost, "/rollback/tpl-1", bytes.NewBufferString(`{"version_id":"v1"}`))
	req.Header.Set("Content-Type", "application/json")
	router.ServeHTTP(rec, req)
	require.Equal(t, http.StatusInternalServerError, rec.Code)
}
diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go
index fe2a1d773..49c74522a 100644
--- a/backend/internal/handler/dto/mappers.go
+++ b/backend/internal/handler/dto/mappers.go
@@ -72,31 +72,22 @@ func APIKeyFromService(k *service.APIKey) *APIKey {
return nil
}
return &APIKey{
- ID: k.ID,
- UserID: k.UserID,
- Key: k.Key,
- Name: k.Name,
- GroupID: k.GroupID,
- Status: k.Status,
- IPWhitelist: k.IPWhitelist,
- IPBlacklist: k.IPBlacklist,
- LastUsedAt: k.LastUsedAt,
- Quota: k.Quota,
- QuotaUsed: k.QuotaUsed,
- ExpiresAt: k.ExpiresAt,
- CreatedAt: k.CreatedAt,
- UpdatedAt: k.UpdatedAt,
- RateLimit5h: k.RateLimit5h,
- RateLimit1d: k.RateLimit1d,
- RateLimit7d: k.RateLimit7d,
- Usage5h: k.Usage5h,
- Usage1d: k.Usage1d,
- Usage7d: k.Usage7d,
- Window5hStart: k.Window5hStart,
- Window1dStart: k.Window1dStart,
- Window7dStart: k.Window7dStart,
- User: UserFromServiceShallow(k.User),
- Group: GroupFromServiceShallow(k.Group),
+ ID: k.ID,
+ UserID: k.UserID,
+ Key: k.Key,
+ Name: k.Name,
+ GroupID: k.GroupID,
+ Status: k.Status,
+ IPWhitelist: k.IPWhitelist,
+ IPBlacklist: k.IPBlacklist,
+ LastUsedAt: k.LastUsedAt,
+ Quota: k.Quota,
+ QuotaUsed: k.QuotaUsed,
+ ExpiresAt: k.ExpiresAt,
+ CreatedAt: k.CreatedAt,
+ UpdatedAt: k.UpdatedAt,
+ User: UserFromServiceShallow(k.User),
+ Group: GroupFromServiceShallow(k.Group),
}
}
@@ -218,17 +209,6 @@ func AccountFromServiceShallow(a *service.Account) *Account {
if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
out.SessionIdleTimeoutMin = &idleTimeout
}
- if rpm := a.GetBaseRPM(); rpm > 0 {
- out.BaseRPM = &rpm
- strategy := a.GetRPMStrategy()
- out.RPMStrategy = &strategy
- buffer := a.GetRPMStickyBuffer()
- out.RPMStickyBuffer = &buffer
- }
- // 用户消息队列模式
- if mode := a.GetUserMsgQueueMode(); mode != "" {
- out.UserMsgQueueMode = &mode
- }
// TLS指纹伪装开关
if a.IsTLSFingerprintEnabled() {
enabled := true
@@ -306,6 +286,7 @@ func ProxyFromService(p *service.Proxy) *Proxy {
Host: p.Host,
Port: p.Port,
Username: p.Username,
+ Password: p.Password,
Status: p.Status,
CreatedAt: p.CreatedAt,
UpdatedAt: p.UpdatedAt,
@@ -335,51 +316,6 @@ func ProxyWithAccountCountFromService(p *service.ProxyWithAccountCount) *ProxyWi
}
}
-// ProxyFromServiceAdmin converts a service Proxy to AdminProxy DTO for admin users.
-// It includes the password field - user-facing endpoints must not use this.
-func ProxyFromServiceAdmin(p *service.Proxy) *AdminProxy {
- if p == nil {
- return nil
- }
- base := ProxyFromService(p)
- if base == nil {
- return nil
- }
- return &AdminProxy{
- Proxy: *base,
- Password: p.Password,
- }
-}
-
-// ProxyWithAccountCountFromServiceAdmin converts a service ProxyWithAccountCount to AdminProxyWithAccountCount DTO.
-// It includes the password field - user-facing endpoints must not use this.
-func ProxyWithAccountCountFromServiceAdmin(p *service.ProxyWithAccountCount) *AdminProxyWithAccountCount {
- if p == nil {
- return nil
- }
- admin := ProxyFromServiceAdmin(&p.Proxy)
- if admin == nil {
- return nil
- }
- return &AdminProxyWithAccountCount{
- AdminProxy: *admin,
- AccountCount: p.AccountCount,
- LatencyMs: p.LatencyMs,
- LatencyStatus: p.LatencyStatus,
- LatencyMessage: p.LatencyMessage,
- IPAddress: p.IPAddress,
- Country: p.Country,
- CountryCode: p.CountryCode,
- Region: p.Region,
- City: p.City,
- QualityStatus: p.QualityStatus,
- QualityScore: p.QualityScore,
- QualityGrade: p.QualityGrade,
- QualitySummary: p.QualitySummary,
- QualityChecked: p.QualityChecked,
- }
-}
-
func ProxyAccountSummaryFromService(a *service.ProxyAccountSummary) *ProxyAccountSummary {
if a == nil {
return nil
diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go
index beb03e679..41676b831 100644
--- a/backend/internal/handler/dto/settings.go
+++ b/backend/internal/handler/dto/settings.go
@@ -55,9 +55,8 @@ type SystemSettings struct {
SoraClientEnabled bool `json:"sora_client_enabled"`
CustomMenuItems []CustomMenuItem `json:"custom_menu_items"`
- DefaultConcurrency int `json:"default_concurrency"`
- DefaultBalance float64 `json:"default_balance"`
- DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"`
+ DefaultConcurrency int `json:"default_concurrency"`
+ DefaultBalance float64 `json:"default_balance"`
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -75,13 +74,6 @@ type SystemSettings struct {
OpsRealtimeMonitoringEnabled bool `json:"ops_realtime_monitoring_enabled"`
OpsQueryModeDefault string `json:"ops_query_mode_default"`
OpsMetricsIntervalSeconds int `json:"ops_metrics_interval_seconds"`
-
- MinClaudeCodeVersion string `json:"min_claude_code_version"`
-}
-
-type DefaultSubscriptionSetting struct {
- GroupID int64 `json:"group_id"`
- ValidityDays int `json:"validity_days"`
}
type PublicSettings struct {
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 920615f70..732433975 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -47,17 +47,6 @@ type APIKey struct {
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
- // Rate limit fields
- RateLimit5h float64 `json:"rate_limit_5h"`
- RateLimit1d float64 `json:"rate_limit_1d"`
- RateLimit7d float64 `json:"rate_limit_7d"`
- Usage5h float64 `json:"usage_5h"`
- Usage1d float64 `json:"usage_1d"`
- Usage7d float64 `json:"usage_7d"`
- Window5hStart *time.Time `json:"window_5h_start"`
- Window1dStart *time.Time `json:"window_1d_start"`
- Window7dStart *time.Time `json:"window_7d_start"`
-
User *User `json:"user,omitempty"`
Group *Group `json:"group,omitempty"`
}
@@ -164,13 +153,6 @@ type Account struct {
MaxSessions *int `json:"max_sessions,omitempty"`
SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
- // RPM 限制(仅 Anthropic OAuth/SetupToken 账号有效)
- // 从 extra 字段提取,方便前端显示和编辑
- BaseRPM *int `json:"base_rpm,omitempty"`
- RPMStrategy *string `json:"rpm_strategy,omitempty"`
- RPMStickyBuffer *int `json:"rpm_sticky_buffer,omitempty"`
- UserMsgQueueMode *string `json:"user_msg_queue_mode,omitempty"`
-
// TLS指纹伪装(仅 Anthropic OAuth/SetupToken 账号有效)
// 从 extra 字段提取,方便前端显示和编辑
EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
@@ -233,32 +215,6 @@ type ProxyWithAccountCount struct {
QualityChecked *int64 `json:"quality_checked,omitempty"`
}
-// AdminProxy 是管理员接口使用的 proxy DTO(包含密码等敏感字段)。
-// 注意:普通接口不得使用此 DTO。
-type AdminProxy struct {
- Proxy
- Password string `json:"password,omitempty"`
-}
-
-// AdminProxyWithAccountCount 是管理员接口使用的带账号统计的 proxy DTO。
-type AdminProxyWithAccountCount struct {
- AdminProxy
- AccountCount int64 `json:"account_count"`
- LatencyMs *int64 `json:"latency_ms,omitempty"`
- LatencyStatus string `json:"latency_status,omitempty"`
- LatencyMessage string `json:"latency_message,omitempty"`
- IPAddress string `json:"ip_address,omitempty"`
- Country string `json:"country,omitempty"`
- CountryCode string `json:"country_code,omitempty"`
- Region string `json:"region,omitempty"`
- City string `json:"city,omitempty"`
- QualityStatus string `json:"quality_status,omitempty"`
- QualityScore *int `json:"quality_score,omitempty"`
- QualityGrade string `json:"quality_grade,omitempty"`
- QualitySummary string `json:"quality_summary,omitempty"`
- QualityChecked *int64 `json:"quality_checked,omitempty"`
-}
-
type ProxyAccountSummary struct {
ID int64 `json:"id"`
Name string `json:"name"`
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index c47e66df3..191d1d0c9 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -49,7 +49,6 @@ type GatewayHandler struct {
maxAccountSwitches int
maxAccountSwitchesGemini int
cfg *config.Config
- settingService *service.SettingService
}
// NewGatewayHandler creates a new GatewayHandler
@@ -66,7 +65,6 @@ func NewGatewayHandler(
errorPassthroughService *service.ErrorPassthroughService,
userMsgQueueService *service.UserMessageQueueService,
cfg *config.Config,
- settingService *service.SettingService,
) *GatewayHandler {
pingInterval := time.Duration(0)
maxAccountSwitches := 10
@@ -102,7 +100,6 @@ func NewGatewayHandler(
maxAccountSwitches: maxAccountSwitches,
maxAccountSwitchesGemini: maxAccountSwitchesGemini,
cfg: cfg,
- settingService: settingService,
}
}
@@ -168,11 +165,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
SetClaudeCodeClientContext(c, body, parsedReq)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
- // 版本检查:仅对 Claude Code 客户端,拒绝低于最低版本的请求
- if !h.checkClaudeCodeVersion(c) {
- return
- }
-
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
@@ -421,15 +413,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
- // RPM 计数递增(Forward 成功后)
- // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
- // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
- if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
- if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
- reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- }
- }
-
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -682,7 +665,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
h.handleStreamingAwareError(c, status, code, message, streamStarted)
return
}
- // 兜底重试按"直接请求兜底分组"处理:清除强制平台,允许按分组平台调度
+ // 兜底重试按“直接请求兜底分组”处理:清除强制平台,允许按分组平台调度
ctx := context.WithValue(c.Request.Context(), ctxkey.ForcePlatform, "")
c.Request = c.Request.WithContext(ctx)
currentAPIKey = fallbackAPIKey
@@ -716,15 +699,6 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
return
}
- // RPM 计数递增(Forward 成功后)
- // 注意:TOCTOU 竞态是已知且可接受的设计权衡,与 WindowCost 一致的 soft-limit 模式。
- // 在高并发下可能短暂超出 RPM 限制,但不会导致请求失败。
- if account.IsAnthropicOAuthOrSetupToken() && account.GetBaseRPM() > 0 {
- if err := h.gatewayService.IncrementAccountRPM(c.Request.Context(), account.ID); err != nil {
- reqLog.Warn("gateway.rpm_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
- }
- }
-
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
@@ -1081,41 +1055,6 @@ func (h *GatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarte
return true
}
-// checkClaudeCodeVersion 检查 Claude Code 客户端版本是否满足最低要求
-// 仅对已识别的 Claude Code 客户端执行,count_tokens 路径除外
-func (h *GatewayHandler) checkClaudeCodeVersion(c *gin.Context) bool {
- ctx := c.Request.Context()
- if !service.IsClaudeCodeClient(ctx) {
- return true
- }
-
- // 排除 count_tokens 子路径
- if strings.HasSuffix(c.Request.URL.Path, "/count_tokens") {
- return true
- }
-
- minVersion := h.settingService.GetMinClaudeCodeVersion(ctx)
- if minVersion == "" {
- return true // 未设置,不检查
- }
-
- clientVersion := service.GetClaudeCodeVersion(ctx)
- if clientVersion == "" {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
- "Unable to determine Claude Code version. Please update Claude Code: npm update -g @anthropic-ai/claude-code")
- return false
- }
-
- if service.CompareVersions(clientVersion, minVersion) < 0 {
- h.errorResponse(c, http.StatusBadRequest, "invalid_request_error",
- fmt.Sprintf("Your Claude Code version (%s) is below the minimum required version (%s). Please update: npm update -g @anthropic-ai/claude-code",
- clientVersion, minVersion))
- return false
- }
-
- return true
-}
-
// errorResponse 返回Claude API格式的错误响应
func (h *GatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
c.JSON(status, gin.H{
@@ -1493,25 +1432,12 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.gateway.messages"),
- zap.Any("panic", recovered),
- ).Error("gateway.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
+ submitUsageRecordTaskWithFallback(
+ "handler.gateway.messages",
+ h.usageRecordWorkerPool,
+ h.cfg,
+ task,
+ )
}
// getUserMsgQueueMode 获取当前请求的 UMQ 模式
diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
index c07c568d3..74f0861a9 100644
--- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
+++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go
@@ -153,7 +153,6 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi
nil, // deferredService
nil, // claudeTokenProvider
nil, // sessionLimitCache
- nil, // rpmCache
nil, // digestStore
)
diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go
index 09e6c09ba..ea8a5f1a9 100644
--- a/backend/internal/handler/gateway_helper.go
+++ b/backend/internal/handler/gateway_helper.go
@@ -29,10 +29,8 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.
if parsedReq != nil {
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
}
-
- ua := c.GetHeader("User-Agent")
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
- if !claudeCodeValidator.ValidateUserAgent(ua) {
+ if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
c.Request = c.Request.WithContext(ctx)
return
@@ -56,14 +54,6 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.
// 更新 request context
ctx := service.SetClaudeCodeClient(c.Request.Context(), isClaudeCode)
-
- // 仅在确认为 Claude Code 客户端时提取版本号写入 context
- if isClaudeCode {
- if version := claudeCodeValidator.ExtractVersion(ua); version != "" {
- ctx = service.SetClaudeCodeVersion(ctx, version)
- }
- }
-
c.Request = c.Request.WithContext(ctx)
}
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 4bbd17bae..770d8ca5b 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -9,6 +9,7 @@ import (
"runtime/debug"
"strconv"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -32,7 +33,31 @@ type OpenAIGatewayHandler struct {
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
+ cfg *config.Config
maxAccountSwitches int
+
+ // Test hooks for websocket ingress path. Production keeps them nil.
+ wsSelectAccountWithSchedulerFn func(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport service.OpenAIUpstreamTransport,
+ ) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error)
+ wsBindStickySessionFn func(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error
+ wsGetAccessTokenFn func(ctx context.Context, account *service.Account) (string, string, error)
+ wsProxyResponsesWSFn func(
+ ctx context.Context,
+ c *gin.Context,
+ clientConn *coderws.Conn,
+ account *service.Account,
+ token string,
+ firstClientMessageType coderws.MessageType,
+ firstClientMessage []byte,
+ hooks *service.OpenAIWSIngressHooks,
+ ) error
}
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
@@ -60,6 +85,7 @@ func NewOpenAIGatewayHandler(
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
+ cfg: cfg,
maxAccountSwitches: maxAccountSwitches,
}
}
@@ -113,30 +139,29 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
-
- // 校验请求体 JSON 合法性
+ // 校验请求体 JSON 合法性,避免畸形 JSON 被 gjson 部分解析后继续下游处理。
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
- // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
- modelResult := gjson.GetBytes(body, "model")
+ // 使用 GetManyBytes 一次扫描提取所有字段,避免多次遍历大请求体
+ results := gjson.GetManyBytes(body, "model", "stream", "previous_response_id")
+ modelResult := results[0]
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
- streamResult := gjson.GetBytes(body, "stream")
+ streamResult := results[1]
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
- previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
+ previousResponseID := strings.TrimSpace(results[2].String())
if previousResponseID != "" {
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
reqLog = reqLog.With(
@@ -155,6 +180,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
+ // 缓存已提取的 meta 到 context,供 Service 层复用,避免重复解析请求体。
+ // prompt_cache_key 也在此处提前提取,确保 meta 设置后只读不写,避免并发竞态。
+ promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
+ c.Set(service.OpenAIRequestMetaKey, &service.OpenAIRequestMeta{
+ Model: reqModel,
+ Stream: reqStream,
+ PromptCacheKey: promptCacheKey,
+ })
+
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
@@ -269,7 +303,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
@@ -286,7 +320,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
continue
}
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
@@ -301,26 +335,32 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
if result != nil {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ var ttftMsVal float64
+ if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 {
+ ttftMsVal = float64(*result.FirstTokenMs)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, ttftMsVal)
} else {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil, reqModel, 0)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
+ requestID := strings.TrimSpace(c.GetHeader("X-Request-ID"))
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- UserAgent: userAgent,
- IPAddress: clientIP,
- APIKeyService: h.apiKeyService,
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ FallbackRequestID: requestID,
+ APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -550,9 +590,10 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqLog.Info("openai.websocket_ingress_started")
clientIP := ip.GetClientIP(c)
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
+ requestID := strings.TrimSpace(c.GetHeader("X-Request-ID"))
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
+ CompressionMode: coderws.CompressionNoContextTakeover,
})
if err != nil {
reqLog.Warn("openai.websocket_accept_failed",
@@ -603,16 +644,21 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
- if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
- closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
- return
- }
+ service.SetOpenAIWSFirstMessageMeta(c, reqModel, previousResponseID, previousResponseIDKind)
reqLog = reqLog.With(
zap.Bool("ws_ingress", true),
zap.String("model", reqModel),
zap.Bool("has_previous_response_id", previousResponseID != ""),
zap.String("previous_response_id_kind", previousResponseIDKind),
)
+ if h.shouldRejectWSMessageIDPreviousResponseIDEarly(previousResponseIDKind) {
+ reqLog.Warn("openai.websocket_request_validation_failed",
+ zap.String("reason", "previous_response_id_looks_like_message_id"),
+ zap.String("openai_ws_ingress_mode_default", h.cfg.Gateway.OpenAIWS.IngressModeDefault),
+ )
+ closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
+ return
+ }
setOpsRequestContext(c, reqModel, true, firstMessage)
var currentUserRelease func()
@@ -654,7 +700,7 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
firstMessage,
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
)
- selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
+ selection, scheduleDecision, err := h.wsSelectAccountWithScheduler(
ctx,
apiKey.GroupID,
previousResponseID,
@@ -674,6 +720,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}
account := selection.Account
+ wsIngressMode, wsModeRouterV2Enabled := h.resolveOpenAIWSIngressMode(account)
+ reqLog = reqLog.With(
+ zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled),
+ zap.String("openai_ws_ingress_mode", wsIngressMode),
+ )
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
@@ -701,11 +752,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
accountReleaseFunc = fastReleaseFunc
}
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
- if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
+ if err := h.wsBindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
}
- token, _, err := h.gatewayService.GetAccessToken(ctx, account)
+ token, _, err := h.wsGetAccessToken(ctx, account)
if err != nil {
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
@@ -717,13 +768,112 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled),
+ zap.String("openai_ws_ingress_mode", wsIngressMode),
)
+ var turnScheduleReported atomic.Bool
+ afterTurn := func(releaseSlots bool) func(turn int, result *service.OpenAIForwardResult, turnErr error) {
+ return func(turn int, result *service.OpenAIForwardResult, turnErr error) {
+ if releaseSlots {
+ releaseTurnSlots()
+ }
+ if turnErr != nil {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
+ turnScheduleReported.Store(true)
+ if partialResult, ok := service.OpenAIWSIngressTurnPartialResult(turnErr); ok && partialResult != nil {
+ h.submitUsageRecordTask(func(taskCtx context.Context) {
+ if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
+ Result: partialResult,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ FallbackRequestID: requestID,
+ APIKeyService: h.apiKeyService,
+ }); err != nil {
+ reqLog.Error("openai.websocket_record_partial_usage_failed",
+ zap.Int64("account_id", account.ID),
+ zap.String("request_id", partialResult.RequestID),
+ zap.Error(err),
+ )
+ }
+ })
+ }
+ return
+ }
+ if result == nil {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
+ turnScheduleReported.Store(true)
+ return
+ }
+ var turnTTFTMs float64
+ if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 {
+ turnTTFTMs = float64(*result.FirstTokenMs)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, turnTTFTMs)
+ turnScheduleReported.Store(true)
+ h.submitUsageRecordTask(func(taskCtx context.Context) {
+ if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ FallbackRequestID: requestID,
+ APIKeyService: h.apiKeyService,
+ }); err != nil {
+ reqLog.Error("openai.websocket_record_usage_failed",
+ zap.Int64("account_id", account.ID),
+ zap.String("request_id", result.RequestID),
+ zap.Error(err),
+ )
+ }
+ })
+ }
+ }
+ handleProxyError := func(proxyErr error) {
+ if !turnScheduleReported.Load() {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
+ }
+ closeStatus, closeReason := summarizeWSCloseErrorForLog(proxyErr)
+ reqLog.Warn("openai.websocket_proxy_failed",
+ zap.Int64("account_id", account.ID),
+ zap.Error(proxyErr),
+ zap.String("close_status", closeStatus),
+ zap.String("close_reason", closeReason),
+ )
+ var closeErr *service.OpenAIWSClientCloseError
+ if errors.As(proxyErr, &closeErr) {
+ closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
+ return
+ }
+ closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
+ }
+ if wsIngressMode == service.OpenAIWSIngressModePassthrough {
+ passthroughHooks := &service.OpenAIWSIngressHooks{
+ AfterTurn: afterTurn(false),
+ }
+ if err := h.wsProxyResponses(ctx, c, wsConn, account, token, msgType, firstMessage, passthroughHooks); err != nil {
+ handleProxyError(err)
+ return
+ }
+ reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
+ return
+ }
+
hooks := &service.OpenAIWSIngressHooks{
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
+ if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "billing check failed", err)
+ }
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
@@ -751,48 +901,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
return nil
},
- AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
- releaseTurnSlots()
- if turnErr != nil || result == nil {
- return
- }
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
- h.submitUsageRecordTask(func(taskCtx context.Context) {
- if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- UserAgent: userAgent,
- IPAddress: clientIP,
- APIKeyService: h.apiKeyService,
- }); err != nil {
- reqLog.Error("openai.websocket_record_usage_failed",
- zap.Int64("account_id", account.ID),
- zap.String("request_id", result.RequestID),
- zap.Error(err),
- )
- }
- })
- },
+ AfterTurn: afterTurn(true),
}
- if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
- closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
- reqLog.Warn("openai.websocket_proxy_failed",
- zap.Int64("account_id", account.ID),
- zap.Error(err),
- zap.String("close_status", closeStatus),
- zap.String("close_reason", closeReason),
- )
- var closeErr *service.OpenAIWSClientCloseError
- if errors.As(err, &closeErr) {
- closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
- return
- }
- closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
+ if err := h.wsProxyResponses(ctx, c, wsConn, account, token, msgType, firstMessage, hooks); err != nil {
+ handleProxyError(err)
return
}
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
@@ -882,25 +995,12 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.openai_gateway.responses"),
- zap.Any("panic", recovered),
- ).Error("openai.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
+ submitUsageRecordTaskWithFallback(
+ "handler.openai_gateway.responses",
+ h.usageRecordWorkerPool,
+ h.cfg,
+ task,
+ )
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
@@ -1030,6 +1130,124 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
}
+func (h *OpenAIGatewayHandler) wsSelectAccountWithScheduler(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport service.OpenAIUpstreamTransport,
+) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) {
+ if h != nil && h.wsSelectAccountWithSchedulerFn != nil {
+ return h.wsSelectAccountWithSchedulerFn(
+ ctx,
+ groupID,
+ previousResponseID,
+ sessionHash,
+ requestedModel,
+ excludedIDs,
+ requiredTransport,
+ )
+ }
+ return h.gatewayService.SelectAccountWithScheduler(
+ ctx,
+ groupID,
+ previousResponseID,
+ sessionHash,
+ requestedModel,
+ excludedIDs,
+ requiredTransport,
+ )
+}
+
+func (h *OpenAIGatewayHandler) wsBindStickySession(
+ ctx context.Context,
+ groupID *int64,
+ sessionHash string,
+ accountID int64,
+) error {
+ if h != nil && h.wsBindStickySessionFn != nil {
+ return h.wsBindStickySessionFn(ctx, groupID, sessionHash, accountID)
+ }
+ return h.gatewayService.BindStickySession(ctx, groupID, sessionHash, accountID)
+}
+
+func (h *OpenAIGatewayHandler) wsGetAccessToken(
+ ctx context.Context,
+ account *service.Account,
+) (string, string, error) {
+ if h != nil && h.wsGetAccessTokenFn != nil {
+ return h.wsGetAccessTokenFn(ctx, account)
+ }
+ return h.gatewayService.GetAccessToken(ctx, account)
+}
+
+func (h *OpenAIGatewayHandler) wsProxyResponses(
+ ctx context.Context,
+ c *gin.Context,
+ clientConn *coderws.Conn,
+ account *service.Account,
+ token string,
+ firstClientMessageType coderws.MessageType,
+ firstClientMessage []byte,
+ hooks *service.OpenAIWSIngressHooks,
+) error {
+ if h != nil && h.wsProxyResponsesWSFn != nil {
+ return h.wsProxyResponsesWSFn(
+ ctx,
+ c,
+ clientConn,
+ account,
+ token,
+ firstClientMessageType,
+ firstClientMessage,
+ hooks,
+ )
+ }
+ return h.gatewayService.ProxyResponsesWebSocketFromClient(
+ ctx,
+ c,
+ clientConn,
+ account,
+ token,
+ firstClientMessageType,
+ firstClientMessage,
+ hooks,
+ )
+}
+
+func (h *OpenAIGatewayHandler) resolveOpenAIWSIngressMode(account *service.Account) (mode string, modeRouterV2Enabled bool) {
+ if account == nil {
+ return "account_missing", false
+ }
+ if h == nil || h.cfg == nil {
+ return "config_missing", false
+ }
+ modeRouterV2Enabled = h.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
+ if !modeRouterV2Enabled {
+ return "legacy", false
+ }
+ resolvedMode := account.ResolveOpenAIResponsesWebSocketV2Mode(h.cfg.Gateway.OpenAIWS.IngressModeDefault)
+ if resolvedMode == "" {
+ resolvedMode = service.OpenAIWSIngressModeOff
+ }
+ return resolvedMode, true
+}
+
+func (h *OpenAIGatewayHandler) shouldRejectWSMessageIDPreviousResponseIDEarly(previousResponseIDKind string) bool {
+ if previousResponseIDKind != service.OpenAIPreviousResponseIDKindMessageID {
+ return false
+ }
+ if h == nil || h.cfg == nil {
+ return false
+ }
+ if !h.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
+ return false
+ }
+ return strings.TrimSpace(h.cfg.Gateway.OpenAIWS.IngressModeDefault) == service.OpenAIWSIngressModeCtxPool
+}
+
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil {
return false
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index a26b3a0c3..9ea62044e 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -7,9 +7,11 @@ import (
"net/http"
"net/http/httptest"
"strings"
+ "sync/atomic"
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/server/middleware"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -431,6 +433,34 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_InvalidJSONBodyReturnsBadRequest(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,invalid}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 201,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Failed to parse request body")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
@@ -461,10 +491,86 @@ func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T
require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c))
}
-func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
+func TestOpenAIResponsesWebSocket_DoesNotEarlyRejectMessageIDPreviousResponseID(t *testing.T) {
gin.SetMode(gin.TestMode)
- h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ cache := &concurrencyCacheMock{
+ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
+ return false, nil
+ },
+ }
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache)
+ wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
+ ))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, err = clientConn.Read(readCtx)
+ cancelRead()
+ require.Error(t, err)
+ var closeErr coderws.CloseError
+ require.ErrorAs(t, err, &closeErr)
+ require.Equal(t, coderws.StatusTryAgainLater, closeErr.Code)
+ require.Contains(t, strings.ToLower(closeErr.Reason), "too many concurrent requests")
+}
+
+func TestOpenAIResponsesWebSocket_CtxPoolRejectsMessageIDBeforeScheduling(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{RunMode: config.RunModeSimple}
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = service.OpenAIWSIngressModeCtxPool
+
+ cache := &concurrencyCacheMock{
+ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
+ return true, nil
+ },
+ acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ return true, nil
+ },
+ }
+
+ h := &OpenAIGatewayHandler{
+ gatewayService: &service.OpenAIGatewayService{},
+ billingCacheService: &service.BillingCacheService{},
+ apiKeyService: &service.APIKeyService{},
+ concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
+ cfg: cfg,
+ }
+
+ var scheduleCalls atomic.Int32
+ var stickyBindCalls atomic.Int32
+ h.wsSelectAccountWithSchedulerFn = func(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport service.OpenAIUpstreamTransport,
+ ) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) {
+ scheduleCalls.Add(1)
+ return nil, service.OpenAIAccountScheduleDecision{}, errors.New("should not be called")
+ }
+ h.wsBindStickySessionFn = func(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
+ stickyBindCalls.Add(1)
+ return nil
+ }
+
wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
defer wsServer.Close()
@@ -487,10 +593,244 @@ func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testin
_, _, err = clientConn.Read(readCtx)
cancelRead()
require.Error(t, err)
+
var closeErr coderws.CloseError
require.ErrorAs(t, err, &closeErr)
require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
- require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id")
+ require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id must be a response.id")
+ require.Equal(t, int32(0), scheduleCalls.Load())
+ require.Equal(t, int32(0), stickyBindCalls.Load())
+}
+
+func TestOpenAIResponsesWebSocket_RejectsEmptyModelInFirstPayload(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
+ `{"type":"response.create","stream":false,"input":[]}`,
+ ))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, err = clientConn.Read(readCtx)
+ cancelRead()
+ require.Error(t, err)
+ var closeErr coderws.CloseError
+ require.ErrorAs(t, err, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
+ require.Contains(t, strings.ToLower(closeErr.Reason), "model is required")
+}
+
+func TestOpenAIResponsesWebSocket_RejectsEmptyModelWithoutCallingProxy(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ var proxyCalls atomic.Int32
+ h.wsProxyResponsesWSFn = func(
+ ctx context.Context,
+ c *gin.Context,
+ clientConn *coderws.Conn,
+ account *service.Account,
+ token string,
+ firstClientMessageType coderws.MessageType,
+ firstClientMessage []byte,
+ hooks *service.OpenAIWSIngressHooks,
+ ) error {
+ proxyCalls.Add(1)
+ return nil
+ }
+
+ wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
+ `{"type":"response.create","stream":false,"input":[]}`,
+ ))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, err = clientConn.Read(readCtx)
+ cancelRead()
+ require.Error(t, err)
+ var closeErr coderws.CloseError
+ require.ErrorAs(t, err, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code)
+ require.Contains(t, strings.ToLower(closeErr.Reason), "model is required")
+ require.Equal(t, int32(0), proxyCalls.Load())
+}
+
+func TestOpenAIResponsesWebSocket_PassthroughAndCtxPoolShareSchedulerInputsAndStickyBind(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ type wsScheduleCapture struct {
+ groupID int64
+ previousResponseID string
+ sessionHash string
+ model string
+ transport service.OpenAIUpstreamTransport
+ stickyBindCount int
+ stickyGroupID int64
+ stickySessionHash string
+ stickyAccountID int64
+ }
+ runCase := func(t *testing.T, mode string) wsScheduleCapture {
+ t.Helper()
+
+ cfg := &config.Config{RunMode: config.RunModeSimple}
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = mode
+ billingSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg)
+ t.Cleanup(func() {
+ billingSvc.Stop()
+ })
+
+ cache := &concurrencyCacheMock{
+ acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
+ return true, nil
+ },
+ acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
+ return true, nil
+ },
+ }
+ h := &OpenAIGatewayHandler{
+ gatewayService: &service.OpenAIGatewayService{},
+ billingCacheService: billingSvc,
+ apiKeyService: &service.APIKeyService{},
+ concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second),
+ cfg: cfg,
+ }
+
+ account := &service.Account{
+ ID: 901,
+ Name: "ws-mode-test",
+ Platform: service.PlatformOpenAI,
+ Type: service.AccountTypeAPIKey,
+ Status: service.StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": mode,
+ },
+ }
+
+ var capture wsScheduleCapture
+ h.wsSelectAccountWithSchedulerFn = func(
+ ctx context.Context,
+ groupID *int64,
+ previousResponseID string,
+ sessionHash string,
+ requestedModel string,
+ excludedIDs map[int64]struct{},
+ requiredTransport service.OpenAIUpstreamTransport,
+ ) (*service.AccountSelectionResult, service.OpenAIAccountScheduleDecision, error) {
+ if groupID != nil {
+ capture.groupID = *groupID
+ }
+ capture.previousResponseID = previousResponseID
+ capture.sessionHash = sessionHash
+ capture.model = requestedModel
+ capture.transport = requiredTransport
+ return &service.AccountSelectionResult{
+ Account: account,
+ Acquired: true,
+ ReleaseFunc: func() {},
+ }, service.OpenAIAccountScheduleDecision{Layer: "unit"}, nil
+ }
+ h.wsBindStickySessionFn = func(ctx context.Context, groupID *int64, sessionHash string, accountID int64) error {
+ capture.stickyBindCount++
+ if groupID != nil {
+ capture.stickyGroupID = *groupID
+ }
+ capture.stickySessionHash = sessionHash
+ capture.stickyAccountID = accountID
+ return nil
+ }
+ h.wsGetAccessTokenFn = func(ctx context.Context, account *service.Account) (string, string, error) {
+ return "sk-test", "apikey", nil
+ }
+ h.wsProxyResponsesWSFn = func(
+ ctx context.Context,
+ c *gin.Context,
+ clientConn *coderws.Conn,
+ account *service.Account,
+ token string,
+ firstClientMessageType coderws.MessageType,
+ firstClientMessage []byte,
+ hooks *service.OpenAIWSIngressHooks,
+ ) error {
+ if hooks != nil && hooks.AfterTurn != nil {
+ hooks.AfterTurn(1, nil, nil)
+ }
+ return nil
+ }
+
+ wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1})
+ t.Cleanup(wsServer.Close)
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", &coderws.DialOptions{
+ HTTPHeader: http.Header{
+ "Session_ID": []string{"session-fixed-123"},
+ },
+ })
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_fixed"}`,
+ ))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, _ = clientConn.Read(readCtx)
+ cancelRead()
+ return capture
+ }
+
+ passthroughCapture := runCase(t, service.OpenAIWSIngressModePassthrough)
+ ctxPoolCapture := runCase(t, service.OpenAIWSIngressModeCtxPool)
+
+ require.Equal(t, passthroughCapture.groupID, ctxPoolCapture.groupID)
+ require.Equal(t, passthroughCapture.previousResponseID, ctxPoolCapture.previousResponseID)
+ require.Equal(t, passthroughCapture.sessionHash, ctxPoolCapture.sessionHash)
+ require.Equal(t, passthroughCapture.model, ctxPoolCapture.model)
+ require.Equal(t, service.OpenAIUpstreamTransportResponsesWebsocketV2, passthroughCapture.transport)
+ require.Equal(t, passthroughCapture.transport, ctxPoolCapture.transport)
+
+ require.Equal(t, 1, passthroughCapture.stickyBindCount)
+ require.Equal(t, 1, ctxPoolCapture.stickyBindCount)
+ require.Equal(t, passthroughCapture.stickyGroupID, ctxPoolCapture.stickyGroupID)
+ require.Equal(t, passthroughCapture.stickySessionHash, ctxPoolCapture.stickySessionHash)
+ require.Equal(t, passthroughCapture.stickyAccountID, ctxPoolCapture.stickyAccountID)
}
func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) {
diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go
index 2f53d655e..6fbf79527 100644
--- a/backend/internal/handler/ops_error_logger.go
+++ b/backend/internal/handler/ops_error_logger.go
@@ -662,10 +662,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
requestID = c.Writer.Header().Get("x-request-id")
}
- normalizedType := normalizeOpsErrorType(parsed.ErrorType, parsed.Code)
-
- phase := classifyOpsPhase(normalizedType, parsed.Message, parsed.Code)
- isBusinessLimited := classifyOpsIsBusinessLimited(normalizedType, phase, parsed.Code, status, parsed.Message)
+ phase := classifyOpsPhase(parsed.ErrorType, parsed.Message, parsed.Code)
+ isBusinessLimited := classifyOpsIsBusinessLimited(parsed.ErrorType, phase, parsed.Code, status, parsed.Message)
errorOwner := classifyOpsErrorOwner(phase, parsed.Message)
errorSource := classifyOpsErrorSource(phase, parsed.Message)
@@ -687,8 +685,8 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
UserAgent: c.GetHeader("User-Agent"),
ErrorPhase: phase,
- ErrorType: normalizedType,
- Severity: classifyOpsSeverity(normalizedType, status),
+ ErrorType: normalizeOpsErrorType(parsed.ErrorType, parsed.Code),
+ Severity: classifyOpsSeverity(parsed.ErrorType, status),
StatusCode: status,
IsBusinessLimited: isBusinessLimited,
IsCountTokens: isCountTokensRequest(c),
@@ -700,7 +698,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
ErrorSource: errorSource,
ErrorOwner: errorOwner,
- IsRetryable: classifyOpsIsRetryable(normalizedType, status),
+ IsRetryable: classifyOpsIsRetryable(parsed.ErrorType, status),
RetryCount: 0,
CreatedAt: time.Now(),
}
@@ -941,29 +939,8 @@ func guessPlatformFromPath(path string) string {
}
}
-// isKnownOpsErrorType returns true if t is a recognized error type used by the
-// ops classification pipeline. Upstream proxies sometimes return garbage values
-// (e.g. the Go-serialized literal "") which would pollute phase/severity
-// classification if accepted blindly.
-func isKnownOpsErrorType(t string) bool {
- switch t {
- case "invalid_request_error",
- "authentication_error",
- "rate_limit_error",
- "billing_error",
- "subscription_error",
- "upstream_error",
- "overloaded_error",
- "api_error",
- "not_found_error",
- "forbidden_error":
- return true
- }
- return false
-}
-
func normalizeOpsErrorType(errType string, code string) string {
- if errType != "" && isKnownOpsErrorType(errType) {
+ if errType != "" {
return errType
}
switch strings.TrimSpace(code) {
diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go
index 679dd4cef..731b36ab9 100644
--- a/backend/internal/handler/ops_error_logger_test.go
+++ b/backend/internal/handler/ops_error_logger_test.go
@@ -214,63 +214,3 @@ func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
})
require.Equal(t, http.StatusNoContent, rec.Code)
}
-
-func TestIsKnownOpsErrorType(t *testing.T) {
- known := []string{
- "invalid_request_error",
- "authentication_error",
- "rate_limit_error",
- "billing_error",
- "subscription_error",
- "upstream_error",
- "overloaded_error",
- "api_error",
- "not_found_error",
- "forbidden_error",
- }
- for _, k := range known {
- require.True(t, isKnownOpsErrorType(k), "expected known: %s", k)
- }
-
-	unknown := []string{"", "null", "<nil>", "random_error", "some_new_type", "\u003e"}
- for _, u := range unknown {
- require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u)
- }
-}
-
-func TestNormalizeOpsErrorType(t *testing.T) {
- tests := []struct {
- name string
- errType string
- code string
- want string
- }{
- // Known types pass through.
- {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"},
- {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"},
- {"known upstream_error", "upstream_error", "", "upstream_error"},
-
- // Unknown/garbage types are rejected and fall through to code-based or default.
-		{"nil literal from upstream", "<nil>", "", "api_error"},
- {"null string", "null", "", "api_error"},
- {"random string", "something_weird", "", "api_error"},
-
- // Unknown type but known code still maps correctly.
-		{"nil with INSUFFICIENT_BALANCE code", "<nil>", "INSUFFICIENT_BALANCE", "billing_error"},
-		{"nil with USAGE_LIMIT_EXCEEDED code", "<nil>", "USAGE_LIMIT_EXCEEDED", "subscription_error"},
-
- // Empty type falls through to code-based mapping.
- {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"},
- {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"},
- {"empty type no code", "", "", "api_error"},
-
- // Known type overrides conflicting code-based mapping.
- {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := normalizeOpsErrorType(tt.errType, tt.code)
- require.Equal(t, tt.want, got)
- })
- }
-}
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
index d933abd7d..890a523db 100644
--- a/backend/internal/handler/sora_client_handler_test.go
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -942,12 +942,12 @@ func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, e
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
-func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
+func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
// ==================== NewSoraClientHandler ====================
@@ -2059,13 +2059,19 @@ func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
- return nil, nil
+ return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
return nil, nil
}
-func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
- return nil, nil
+func (r *stubAccountRepoForHandler) ListByPlatform(_ context.Context, platform string) ([]service.Account, error) {
+ filtered := make([]service.Account, 0, len(r.accounts))
+ for _, account := range r.accounts {
+ if account.Platform == platform {
+ filtered = append(filtered, account)
+ }
+ }
+ return filtered, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
@@ -2199,7 +2205,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
- nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
+ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}
@@ -2233,7 +2239,6 @@ func TestProcessGeneration_SelectAccountError(t *testing.T) {
}
func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2253,7 +2258,6 @@ func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
}
func TestProcessGeneration_ForwardError(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2312,7 +2316,6 @@ func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
}
func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2374,7 +2377,6 @@ func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
}
func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2405,7 +2407,6 @@ func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
}
func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
@@ -2453,7 +2454,6 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
}
func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
@@ -2621,7 +2621,6 @@ func TestDeleteGeneration_DeleteError(t *testing.T) {
// ==================== fetchUpstreamModels 测试 ====================
func TestFetchUpstreamModels_NilGateway(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
h := &SoraClientHandler{}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
@@ -2629,7 +2628,6 @@ func TestFetchUpstreamModels_NilGateway(t *testing.T) {
}
func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
@@ -2639,7 +2637,6 @@ func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
}
func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
@@ -2653,7 +2650,6 @@ func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
}
func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
@@ -2668,7 +2664,6 @@ func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
}
func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo := &stubAccountRepoForHandler{
@@ -2684,7 +2679,6 @@ func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
}
func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
@@ -2704,7 +2698,6 @@ func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
}
func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not json"))
@@ -2725,7 +2718,6 @@ func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
}
func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[]}`))
@@ -2746,7 +2738,6 @@ func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
}
func TestFetchUpstreamModels_Success(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求头
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
@@ -2770,7 +2761,6 @@ func TestFetchUpstreamModels_Success(t *testing.T) {
}
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
@@ -2805,7 +2795,6 @@ func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
}
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
- t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index 48c1e451b..a0045aa53 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -41,6 +41,7 @@ type SoraGatewayHandler struct {
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
+ cfg *config.Config
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
@@ -83,6 +84,7 @@ func NewSoraGatewayHandler(
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
+ cfg: cfg,
}
}
@@ -451,25 +453,12 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string {
}
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.sora_gateway.chat_completions"),
- zap.Any("panic", recovered),
- ).Error("sora.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
+ submitUsageRecordTaskWithFallback(
+ "handler.sora_gateway.chat_completions",
+ h.usageRecordWorkerPool,
+ h.cfg,
+ task,
+ )
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go
index b76ab67da..66877e6ca 100644
--- a/backend/internal/handler/sora_gateway_handler_test.go
+++ b/backend/internal/handler/sora_gateway_handler_test.go
@@ -326,9 +326,6 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
-func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
- return nil, nil
-}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
@@ -435,8 +432,7 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
deferredService,
nil,
testutil.StubSessionLimitCache{},
- nil, // rpmCache
- nil, // digestStore
+ nil,
)
soraClient := &stubSoraClient{imageURLs: []string{"https://example.com/a.png"}}
diff --git a/backend/internal/handler/usage_record_submit_helper.go b/backend/internal/handler/usage_record_submit_helper.go
new file mode 100644
index 000000000..5c1987b1e
--- /dev/null
+++ b/backend/internal/handler/usage_record_submit_helper.go
@@ -0,0 +1,60 @@
+package handler
+
+import (
+ "context"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "go.uber.org/zap"
+)
+
+func submitUsageRecordTaskWithFallback(
+ component string,
+ pool *service.UsageRecordWorkerPool,
+ cfg *config.Config,
+ task service.UsageRecordTask,
+) {
+ if task == nil {
+ return
+ }
+ if pool != nil {
+ mode := pool.Submit(task)
+ if mode != service.UsageRecordSubmitModeDropped {
+ return
+ }
+ // 队列溢出导致 submit 丢弃时,同步兜底执行,避免 usage 漏记费。
+ logger.L().With(
+ zap.String("component", component),
+ zap.String("submit_mode", mode.String()),
+ ).Warn("usage_record.task_submit_dropped_sync_fallback")
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), usageRecordSyncFallbackTimeout(cfg))
+ defer cancel()
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.L().With(
+ zap.String("component", component),
+ zap.Any("panic", recovered),
+ ).Error("usage_record.task_panic_recovered")
+ }
+ }()
+ task(ctx)
+}
+
+func usageRecordSyncFallbackTimeout(cfg *config.Config) time.Duration {
+ timeout := 10 * time.Second
+ if cfg != nil && cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 {
+ timeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second
+ }
+ // keep a sane bound on synchronous fallback to limit request-path blocking.
+ if timeout < time.Second {
+ return time.Second
+ }
+ if timeout > 10*time.Second {
+ return 10 * time.Second
+ }
+ return timeout
+}
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
index c7c48e14b..20e8e87c3 100644
--- a/backend/internal/handler/usage_record_submit_task_test.go
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -54,6 +55,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
require.True(t, called.Load())
}
+func TestGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) {
+ pool := newUsageRecordTestPool(t)
+ pool.Stop()
+ h := &GatewayHandler{usageRecordWorkerPool: pool}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if _, ok := ctx.Deadline(); !ok {
+ t.Fatal("expected deadline in dropped sync fallback context")
+ }
+ called.Store(true)
+ })
+
+ require.True(t, called.Load(), "dropped task should run via sync fallback")
+}
+
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
@@ -93,6 +110,40 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
}
}
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) {
+ pool := newUsageRecordTestPool(t)
+ pool.Stop()
+ h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if _, ok := ctx.Deadline(); !ok {
+ t.Fatal("expected deadline in dropped sync fallback context")
+ }
+ called.Store(true)
+ })
+
+ require.True(t, called.Load(), "dropped task should run via sync fallback")
+}
+
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithConfigFallbackTimeout(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 2
+ h := &OpenAIGatewayHandler{cfg: cfg}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ deadline, ok := ctx.Deadline()
+ require.True(t, ok, "expected deadline in fallback context")
+ remaining := time.Until(deadline)
+ require.LessOrEqual(t, remaining, 2200*time.Millisecond)
+ require.GreaterOrEqual(t, remaining, 1200*time.Millisecond)
+ called.Store(true)
+ })
+
+ require.True(t, called.Load())
+}
+
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
@@ -160,6 +211,22 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *test
require.True(t, called.Load())
}
+func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) {
+ pool := newUsageRecordTestPool(t)
+ pool.Stop()
+ h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if _, ok := ctx.Deadline(); !ok {
+ t.Fatal("expected deadline in dropped sync fallback context")
+ }
+ called.Store(true)
+ })
+
+ require.True(t, called.Load(), "dropped task should run via sync fallback")
+}
+
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go
index d46bbc454..1998221a3 100644
--- a/backend/internal/pkg/antigravity/client.go
+++ b/backend/internal/pkg/antigravity/client.go
@@ -14,9 +14,6 @@ import (
"net/url"
"strings"
"time"
-
- "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
- "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
)
// NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点)
@@ -152,26 +149,22 @@ type Client struct {
httpClient *http.Client
}
-func NewClient(proxyURL string) (*Client, error) {
+func NewClient(proxyURL string) *Client {
client := &http.Client{
Timeout: 30 * time.Second,
}
- _, parsed, err := proxyurl.Parse(proxyURL)
- if err != nil {
- return nil, err
- }
- if parsed != nil {
- transport := &http.Transport{}
- if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
- return nil, fmt.Errorf("configure proxy: %w", err)
+ if strings.TrimSpace(proxyURL) != "" {
+ if proxyURLParsed, err := url.Parse(proxyURL); err == nil {
+ client.Transport = &http.Transport{
+ Proxy: http.ProxyURL(proxyURLParsed),
+ }
}
- client.Transport = transport
}
return &Client{
httpClient: client,
- }, nil
+ }
}
// isConnectionError 判断是否为连接错误(网络超时、DNS 失败、连接拒绝)
diff --git a/backend/internal/pkg/antigravity/client_test.go b/backend/internal/pkg/antigravity/client_test.go
index 20b57833a..394b6128c 100644
--- a/backend/internal/pkg/antigravity/client_test.go
+++ b/backend/internal/pkg/antigravity/client_test.go
@@ -228,20 +228,8 @@ func TestGetTier_两者都为nil(t *testing.T) {
// NewClient
// ---------------------------------------------------------------------------
-func mustNewClient(t *testing.T, proxyURL string) *Client {
- t.Helper()
- client, err := NewClient(proxyURL)
- if err != nil {
- t.Fatalf("NewClient(%q) failed: %v", proxyURL, err)
- }
- return client
-}
-
func TestNewClient_无代理(t *testing.T) {
- client, err := NewClient("")
- if err != nil {
- t.Fatalf("NewClient 返回错误: %v", err)
- }
+ client := NewClient("")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
@@ -258,10 +246,7 @@ func TestNewClient_无代理(t *testing.T) {
}
func TestNewClient_有代理(t *testing.T) {
- client, err := NewClient("http://proxy.example.com:8080")
- if err != nil {
- t.Fatalf("NewClient 返回错误: %v", err)
- }
+ client := NewClient("http://proxy.example.com:8080")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
@@ -271,10 +256,7 @@ func TestNewClient_有代理(t *testing.T) {
}
func TestNewClient_空格代理(t *testing.T) {
- client, err := NewClient(" ")
- if err != nil {
- t.Fatalf("NewClient 返回错误: %v", err)
- }
+ client := NewClient(" ")
if client == nil {
t.Fatal("NewClient 返回 nil")
}
@@ -285,13 +267,15 @@ func TestNewClient_空格代理(t *testing.T) {
}
func TestNewClient_无效代理URL(t *testing.T) {
- // 无效 URL 应返回 error
- _, err := NewClient("://invalid")
- if err == nil {
- t.Fatal("无效代理 URL 应返回错误")
+ // 无效 URL 时 url.Parse 不一定返回错误(Go 的 url.Parse 很宽容),
+ // 但 ://invalid 会导致解析错误
+ client := NewClient("://invalid")
+ if client == nil {
+ t.Fatal("NewClient 返回 nil")
}
- if !strings.Contains(err.Error(), "invalid proxy URL") {
- t.Errorf("错误信息应包含 'invalid proxy URL': got %s", err.Error())
+ // 无效 URL 解析失败时,Transport 应保持 nil
+ if client.httpClient.Transport != nil {
+ t.Error("无效代理 URL 时 Transport 应为 nil")
}
}
@@ -515,7 +499,7 @@ func TestClient_ExchangeCode_无ClientSecret(t *testing.T) {
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
- client := mustNewClient(t, "")
+ client := NewClient("")
_, err := client.ExchangeCode(context.Background(), "code", "verifier")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
@@ -618,7 +602,7 @@ func TestClient_RefreshToken_无ClientSecret(t *testing.T) {
defaultClientSecret = ""
t.Cleanup(func() { defaultClientSecret = old })
- client := mustNewClient(t, "")
+ client := NewClient("")
_, err := client.RefreshToken(context.Background(), "refresh-tok")
if err == nil {
t.Fatal("缺少 client_secret 时应返回错误")
@@ -1258,7 +1242,7 @@ func TestClient_LoadCodeAssist_Success_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, rawResp, err := client.LoadCodeAssist(context.Background(), "test-token")
if err != nil {
t.Fatalf("LoadCodeAssist 失败: %v", err)
@@ -1293,7 +1277,7 @@ func TestClient_LoadCodeAssist_HTTPError_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
_, _, err := client.LoadCodeAssist(context.Background(), "bad-token")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
@@ -1316,7 +1300,7 @@ func TestClient_LoadCodeAssist_InvalidJSON_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
@@ -1349,7 +1333,7 @@ func TestClient_LoadCodeAssist_URLFallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 fallback 后成功: %v", err)
@@ -1377,7 +1361,7 @@ func TestClient_LoadCodeAssist_AllURLsFail_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
_, _, err := client.LoadCodeAssist(context.Background(), "token")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
@@ -1393,7 +1377,7 @@ func TestClient_LoadCodeAssist_ContextCanceled_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -1457,7 +1441,7 @@ func TestClient_FetchAvailableModels_Success_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "test-token", "project-abc")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
@@ -1512,7 +1496,7 @@ func TestClient_FetchAvailableModels_HTTPError_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
_, _, err := client.FetchAvailableModels(context.Background(), "bad-token", "proj")
if err == nil {
t.Fatal("服务器返回 403 时应返回错误")
@@ -1532,7 +1516,7 @@ func TestClient_FetchAvailableModels_InvalidJSON_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("无效 JSON 响应应返回错误")
@@ -1562,7 +1546,7 @@ func TestClient_FetchAvailableModels_URLFallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 fallback 后成功: %v", err)
@@ -1590,7 +1574,7 @@ func TestClient_FetchAvailableModels_AllURLsFail_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
_, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err == nil {
t.Fatal("所有 URL 都失败时应返回错误")
@@ -1606,7 +1590,7 @@ func TestClient_FetchAvailableModels_ContextCanceled_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
ctx, cancel := context.WithCancel(context.Background())
cancel()
@@ -1626,7 +1610,7 @@ func TestClient_FetchAvailableModels_EmptyModels_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, rawResp, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 失败: %v", err)
@@ -1662,7 +1646,7 @@ func TestClient_LoadCodeAssist_408Fallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, _, err := client.LoadCodeAssist(context.Background(), "token")
if err != nil {
t.Fatalf("LoadCodeAssist 应在 408 fallback 后成功: %v", err)
@@ -1688,7 +1672,7 @@ func TestClient_FetchAvailableModels_404Fallback_RealCall(t *testing.T) {
withMockBaseURLs(t, []string{server1.URL, server2.URL})
- client := mustNewClient(t, "")
+ client := NewClient("")
resp, _, err := client.FetchAvailableModels(context.Background(), "token", "proj")
if err != nil {
t.Fatalf("FetchAvailableModels 应在 404 fallback 后成功: %v", err)
diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go
index 25782c551..b13d66cb4 100644
--- a/backend/internal/pkg/ctxkey/ctxkey.go
+++ b/backend/internal/pkg/ctxkey/ctxkey.go
@@ -52,7 +52,4 @@ const (
// PrefetchedStickyGroupID 标识上游预取 sticky session 时所使用的分组 ID。
// Service 层仅在分组匹配时复用 PrefetchedStickyAccountID,避免分组切换重试误用旧 sticky。
PrefetchedStickyGroupID Key = "ctx_prefetched_sticky_group_id"
-
- // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22")
- ClaudeCodeVersion Key = "ctx_claude_code_version"
)
diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go
index 32e4bc5b2..6ef3d7141 100644
--- a/backend/internal/pkg/httpclient/pool.go
+++ b/backend/internal/pkg/httpclient/pool.go
@@ -18,11 +18,11 @@ package httpclient
import (
"fmt"
"net/http"
+ "net/url"
"strings"
"sync"
"time"
- "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
)
@@ -41,6 +41,7 @@ type Options struct {
Timeout time.Duration // 请求总超时时间
ResponseHeaderTimeout time.Duration // 等待响应头超时时间
InsecureSkipVerify bool // 是否跳过 TLS 证书验证(已禁用,不允许设置为 true)
+ ProxyStrict bool // 严格代理模式:代理失败时返回错误而非回退
ValidateResolvedIP bool // 是否校验解析后的 IP(防止 DNS Rebinding)
AllowPrivateHosts bool // 允许私有地址解析(与 ValidateResolvedIP 一起使用)
@@ -119,13 +120,15 @@ func buildTransport(opts Options) (*http.Transport, error) {
return nil, fmt.Errorf("insecure_skip_verify is not allowed; install a trusted certificate instead")
}
- _, parsed, err := proxyurl.Parse(opts.ProxyURL)
+ proxyURL := strings.TrimSpace(opts.ProxyURL)
+ if proxyURL == "" {
+ return transport, nil
+ }
+
+ parsed, err := url.Parse(proxyURL)
if err != nil {
return nil, err
}
- if parsed == nil {
- return transport, nil
- }
if err := proxyutil.ConfigureTransportProxy(transport, parsed); err != nil {
return nil, err
@@ -135,11 +138,12 @@ func buildTransport(opts Options) (*http.Transport, error) {
}
func buildClientKey(opts Options) string {
- return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%d|%d|%d",
+ return fmt.Sprintf("%s|%s|%s|%t|%t|%t|%t|%d|%d|%d",
strings.TrimSpace(opts.ProxyURL),
opts.Timeout.String(),
opts.ResponseHeaderTimeout.String(),
opts.InsecureSkipVerify,
+ opts.ProxyStrict,
opts.ValidateResolvedIP,
opts.AllowPrivateHosts,
opts.MaxIdleConns,
diff --git a/backend/internal/pkg/proxyutil/dialer.go b/backend/internal/pkg/proxyutil/dialer.go
index e437cae34..91b224a28 100644
--- a/backend/internal/pkg/proxyutil/dialer.go
+++ b/backend/internal/pkg/proxyutil/dialer.go
@@ -2,11 +2,7 @@
//
// 支持的代理协议:
// - HTTP/HTTPS: 通过 Transport.Proxy 设置
-// - SOCKS5: 通过 Transport.DialContext 设置(客户端本地解析 DNS)
-// - SOCKS5H: 通过 Transport.DialContext 设置(代理端远程解析 DNS,推荐)
-//
-// 注意:proxyurl.Parse() 会自动将 socks5:// 升级为 socks5h://,
-// 确保 DNS 也由代理端解析,防止 DNS 泄漏。
+// - SOCKS5/SOCKS5H: 通过 Transport.DialContext 设置(服务端解析 DNS)
package proxyutil
import (
@@ -24,8 +20,7 @@ import (
//
// 支持的协议:
// - http/https: 设置 transport.Proxy
-// - socks5: 设置 transport.DialContext(客户端本地解析 DNS)
-// - socks5h: 设置 transport.DialContext(代理端远程解析 DNS,推荐)
+// - socks5/socks5h: 设置 transport.DialContext(由代理服务端解析 DNS)
//
// 参数:
// - transport: 需要配置的 http.Transport
diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go
index 0519c2cc1..f09bee8dd 100644
--- a/backend/internal/pkg/response/response.go
+++ b/backend/internal/pkg/response/response.go
@@ -2,6 +2,8 @@
package response
import (
+ "context"
+ "errors"
"log"
"math"
"net/http"
@@ -75,17 +77,45 @@ func ErrorFrom(c *gin.Context, err error) bool {
return false
}
- statusCode, status := infraerrors.ToHTTP(err)
+ normalizedErr := normalizeHTTPError(c, err)
+ statusCode, status := infraerrors.ToHTTP(normalizedErr)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
- log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
+ log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(normalizedErr.Error()))
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
+func normalizeHTTPError(c *gin.Context, err error) error {
+ if err == nil {
+ return nil
+ }
+ if c == nil || c.Request == nil {
+ return err
+ }
+ if isClientCanceledError(c.Request.Context(), err) {
+ return infraerrors.ClientClosed("CLIENT_CLOSED", "client closed request").WithCause(err)
+ }
+ return err
+}
+
+func isClientCanceledError(reqCtx context.Context, err error) bool {
+ if reqCtx == nil {
+ return false
+ }
+ // 只有请求上下文本身被取消时,才认为是客户端断开;
+ // 避免将服务端主动 cancel 导致的 context.Canceled 误归为 499。
+ if errors.Is(err, context.Canceled) && errors.Is(reqCtx.Err(), context.Canceled) {
+ return true
+ }
+
+ // Some drivers can surface deadline errors after the request context was already canceled.
+ return errors.Is(err, context.DeadlineExceeded) && errors.Is(reqCtx.Err(), context.Canceled)
+}
+
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go
index 0debce5fd..64918ca8b 100644
--- a/backend/internal/pkg/response/response_test.go
+++ b/backend/internal/pkg/response/response_test.go
@@ -3,8 +3,10 @@
package response
import (
+ "context"
"encoding/json"
"errors"
+ "fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -107,11 +109,12 @@ func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
- name string
- err error
- wantWritten bool
- wantHTTPCode int
- wantBody Response
+ name string
+ err error
+ cancelRequestContext bool
+ wantWritten bool
+ wantHTTPCode int
+ wantBody Response
}{
{
name: "nil_error",
@@ -184,12 +187,75 @@ func TestErrorFrom(t *testing.T) {
Message: errors2.UnknownMessage,
},
},
+ {
+ name: "context_canceled_without_request_cancel_remains_500",
+ err: context.Canceled,
+ wantWritten: true,
+ wantHTTPCode: http.StatusInternalServerError,
+ wantBody: Response{
+ Code: http.StatusInternalServerError,
+ Message: errors2.UnknownMessage,
+ },
+ },
+ {
+ name: "context_canceled_maps_to_499",
+ err: context.Canceled,
+ cancelRequestContext: true,
+ wantWritten: true,
+ wantHTTPCode: 499,
+ wantBody: Response{
+ Code: 499,
+ Message: "client closed request",
+ Reason: "CLIENT_CLOSED",
+ },
+ },
+ {
+ name: "wrapped_context_canceled_maps_to_499",
+ err: fmt.Errorf("query aborted: %w", context.Canceled),
+ cancelRequestContext: true,
+ wantWritten: true,
+ wantHTTPCode: 499,
+ wantBody: Response{
+ Code: 499,
+ Message: "client closed request",
+ Reason: "CLIENT_CLOSED",
+ },
+ },
+ {
+ name: "deadline_exceeded_without_request_cancel_remains_500",
+ err: context.DeadlineExceeded,
+ wantWritten: true,
+ wantHTTPCode: http.StatusInternalServerError,
+ wantBody: Response{
+ Code: http.StatusInternalServerError,
+ Message: errors2.UnknownMessage,
+ },
+ },
+ {
+ name: "deadline_exceeded_with_request_canceled_maps_to_499",
+ err: context.DeadlineExceeded,
+ cancelRequestContext: true,
+ wantWritten: true,
+ wantHTTPCode: 499,
+ wantBody: Response{
+ Code: 499,
+ Message: "client closed request",
+ Reason: "CLIENT_CLOSED",
+ },
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ if tt.cancelRequestContext {
+ ctx, cancel := context.WithCancel(req.Context())
+ cancel()
+ req = req.WithContext(ctx)
+ }
+ c.Request = req
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go
index 314a6d3c9..5f4e13f54 100644
--- a/backend/internal/pkg/usagestats/usage_log_types.go
+++ b/backend/internal/pkg/usagestats/usage_log_types.go
@@ -78,16 +78,6 @@ type ModelStat struct {
ActualCost float64 `json:"actual_cost"` // 实际扣除
}
-// GroupStat represents usage statistics for a single group
-type GroupStat struct {
- GroupID int64 `json:"group_id"`
- GroupName string `json:"group_name"`
- Requests int64 `json:"requests"`
- TotalTokens int64 `json:"total_tokens"`
- Cost float64 `json:"cost"` // 标准计费
- ActualCost float64 `json:"actual_cost"` // 实际扣除
-}
-
// UserUsageTrendPoint represents user usage trend data point
type UserUsageTrendPoint struct {
Date string `json:"date"`
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 0669cbbdd..3e922e7c3 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -1215,6 +1215,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Schedulable)
idx++
}
+ if updates.AutoPauseOnExpired != nil {
+ setClauses = append(setClauses, "auto_pause_on_expired = $"+itoa(idx))
+ args = append(args, *updates.AutoPauseOnExpired)
+ idx++
+ }
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 {
payload, err := json.Marshal(updates.Credentials)
diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go
index 4fbdae14f..447415ace 100644
--- a/backend/internal/repository/billing_cache.go
+++ b/backend/internal/repository/billing_cache.go
@@ -69,9 +69,20 @@ var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
+ return 2
+ end
+ local cur = tonumber(current)
+ local delta = tonumber(ARGV[1])
+ if cur == nil or delta == nil then
+ return -1
+ end
+ if delta < 0 then
+ return -2
+ end
+ if cur < delta then
return 0
end
- local newVal = tonumber(current) - tonumber(ARGV[1])
+ local newVal = cur - delta
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
@@ -130,12 +141,26 @@ func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
- _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
+ result, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Int64()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
return err
}
- return nil
+ switch result {
+ case 1:
+ return nil
+ case 2:
+ // 缓存 key 不存在(已过期),返回特定错误让调用方区分处理
+ return service.ErrBalanceCacheNotFound
+ case 0:
+ return service.ErrInsufficientBalance
+ case -1:
+ return fmt.Errorf("invalid cached balance for user %d", userID)
+ case -2:
+ return fmt.Errorf("invalid deduct amount for user %d", userID)
+ default:
+ return fmt.Errorf("unexpected deduct balance cache result for user %d: %d", userID, result)
+ }
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go
index 4b7377b12..dffe0b6e4 100644
--- a/backend/internal/repository/billing_cache_integration_test.go
+++ b/backend/internal/repository/billing_cache_integration_test.go
@@ -31,14 +31,15 @@ func (s *BillingCacheSuite) TestUserBalance() {
},
},
{
- name: "deduct_on_nonexistent_is_noop",
+ name: "deduct_on_nonexistent_returns_not_found",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
- require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
+ err := cache.DeductUserBalance(ctx, userID, 1)
+ require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound")
- _, err := rdb.Get(ctx, balanceKey).Result()
+ _, err = rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
@@ -278,8 +279,8 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}
}
-// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
-// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
+// TestDeductUserBalance_ErrorPropagation 验证修复:
+// Redis 真实错误应传播,key 不存在应返回 ErrBalanceCacheNotFound。
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
tests := []struct {
name string
@@ -287,11 +288,11 @@ func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
expectErr bool
}{
{
- name: "key_not_exists_returns_nil",
+ name: "key_not_exists_returns_ErrBalanceCacheNotFound",
fn: func(ctx context.Context, cache service.BillingCache) {
- // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
+ // key 不存在时,Lua 脚本返回 2,应返回 ErrBalanceCacheNotFound
err := cache.DeductUserBalance(ctx, 99999, 1.0)
- require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
+ require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound")
},
},
{
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go
index b754bd55e..77764881e 100644
--- a/backend/internal/repository/claude_oauth_service.go
+++ b/backend/internal/repository/claude_oauth_service.go
@@ -11,7 +11,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
- "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
@@ -29,14 +28,11 @@ func NewClaudeOAuthClient() service.ClaudeOAuthClient {
type claudeOAuthService struct {
baseURL string
tokenURL string
- clientFactory func(proxyURL string) (*req.Client, error)
+ clientFactory func(proxyURL string) *req.Client
}
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
- client, err := s.clientFactory(proxyURL)
- if err != nil {
- return "", fmt.Errorf("create HTTP client: %w", err)
- }
+ client := s.clientFactory(proxyURL)
var orgs []struct {
UUID string `json:"uuid"`
@@ -92,10 +88,7 @@ func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey
}
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
- client, err := s.clientFactory(proxyURL)
- if err != nil {
- return "", fmt.Errorf("create HTTP client: %w", err)
- }
+ client := s.clientFactory(proxyURL)
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
@@ -172,10 +165,7 @@ func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKe
}
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
- client, err := s.clientFactory(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create HTTP client: %w", err)
- }
+ client := s.clientFactory(proxyURL)
// Parse code which may contain state in format "authCode#state"
authCode := code
@@ -233,10 +223,7 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod
}
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
- client, err := s.clientFactory(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create HTTP client: %w", err)
- }
+ client := s.clientFactory(proxyURL)
reqBody := map[string]any{
"grant_type": "refresh_token",
@@ -266,20 +253,16 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro
return &tokenResp, nil
}
-func createReqClient(proxyURL string) (*req.Client, error) {
+func createReqClient(proxyURL string) *req.Client {
// 禁用 CookieJar,确保每次授权都是干净的会话
client := req.C().
SetTimeout(60 * time.Second).
ImpersonateChrome().
SetCookieJar(nil) // 禁用 CookieJar
- trimmed, _, err := proxyurl.Parse(proxyURL)
- if err != nil {
- return nil, err
- }
- if trimmed != "" {
- client.SetProxyURL(trimmed)
+ if strings.TrimSpace(proxyURL) != "" {
+ client.SetProxyURL(strings.TrimSpace(proxyURL))
}
- return client, nil
+ return client
}
diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go
index c63830338..7395c6d82 100644
--- a/backend/internal/repository/claude_oauth_service_test.go
+++ b/backend/internal/repository/claude_oauth_service_test.go
@@ -91,7 +91,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
- s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
+ s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
@@ -169,7 +169,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.baseURL = "http://in-process"
- s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
+ s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
@@ -276,7 +276,7 @@ func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
- s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
+ s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
@@ -372,7 +372,7 @@ func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
require.True(s.T(), ok, "type assertion failed")
s.client = client
s.client.tokenURL = "http://in-process/token"
- s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
+ s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) }
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go
index f6054828c..1198f4725 100644
--- a/backend/internal/repository/claude_usage_service.go
+++ b/backend/internal/repository/claude_usage_service.go
@@ -83,7 +83,7 @@ func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *se
AllowPrivateHosts: s.allowPrivateHosts,
})
if err != nil {
- return nil, fmt.Errorf("create http client failed: %w", err)
+ client = &http.Client{Timeout: 30 * time.Second}
}
resp, err = client.Do(req)
diff --git a/backend/internal/repository/claude_usage_service_test.go b/backend/internal/repository/claude_usage_service_test.go
index cbd0b6d3e..2e10f3e5b 100644
--- a/backend/internal/repository/claude_usage_service_test.go
+++ b/backend/internal/repository/claude_usage_service_test.go
@@ -50,7 +50,7 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
allowPrivateHosts: true,
}
- resp, err := s.fetcher.FetchUsage(context.Background(), "at", "")
+ resp, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
require.NoError(s.T(), err, "FetchUsage")
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
@@ -112,17 +112,6 @@ func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
require.Error(s.T(), err, "expected error for cancelled context")
}
-func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() {
- s.fetcher = &claudeUsageService{
- usageURL: "http://example.com",
- allowPrivateHosts: true,
- }
-
- _, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
- require.Error(s.T(), err)
- require.ErrorContains(s.T(), err, "create http client failed")
-}
-
func TestClaudeUsageServiceSuite(t *testing.T) {
suite.Run(t, new(ClaudeUsageServiceSuite))
}
diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go
index 58291b665..0f193c7dc 100644
--- a/backend/internal/repository/gateway_cache.go
+++ b/backend/internal/repository/gateway_cache.go
@@ -2,14 +2,20 @@ package repository
import (
"context"
+ "encoding/json"
"fmt"
+ "strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/cespare/xxhash/v2"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
+const openAIWSSessionLastResponsePrefix = "openai_ws_session_last_response:"
+const openAIWSResponsePendingToolCallsPrefix = "openai_ws_response_pending_tool_calls:"
type gatewayCache struct {
rdb *redis.Client
@@ -25,6 +31,20 @@ func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
+func buildOpenAIWSSessionLastResponseKey(groupID int64, sessionHash string) string {
+ return fmt.Sprintf("%s%d:%s", openAIWSSessionLastResponsePrefix, groupID, sessionHash)
+}
+
+func buildOpenAIWSResponsePendingToolCallsKey(groupID int64, responseID string) string {
+ id := strings.TrimSpace(responseID)
+ if id == "" {
+ return ""
+ }
+ return openAIWSResponsePendingToolCallsPrefix +
+ strconv.FormatInt(groupID, 10) + ":" +
+ strconv.FormatUint(xxhash.Sum64String(id), 16)
+}
+
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
@@ -51,3 +71,78 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
+
+func (c *gatewayCache) SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error {
+ key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash)
+ return c.rdb.Set(ctx, key, responseID, ttl).Err()
+}
+
+func (c *gatewayCache) GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error) {
+ key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash)
+ return c.rdb.Get(ctx, key).Result()
+}
+
+func (c *gatewayCache) DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error {
+ key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+func (c *gatewayCache) SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error {
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ if key == "" {
+ return nil
+ }
+ normalizedCallIDs := normalizeOpenAIWSResponsePendingToolCallIDs(callIDs)
+ if len(normalizedCallIDs) == 0 {
+ return c.rdb.Del(ctx, key).Err()
+ }
+ raw, err := json.Marshal(normalizedCallIDs)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, raw, ttl).Err()
+}
+
+func (c *gatewayCache) GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) {
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ if key == "" {
+ return nil, nil
+ }
+ raw, err := c.rdb.Get(ctx, key).Bytes()
+ if err != nil {
+ return nil, err
+ }
+ var callIDs []string
+ if err := json.Unmarshal(raw, &callIDs); err != nil {
+ return nil, err
+ }
+ return normalizeOpenAIWSResponsePendingToolCallIDs(callIDs), nil
+}
+
+func (c *gatewayCache) DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error {
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ if key == "" {
+ return nil
+ }
+ return c.rdb.Del(ctx, key).Err()
+}
+
+func normalizeOpenAIWSResponsePendingToolCallIDs(callIDs []string) []string {
+ if len(callIDs) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(callIDs))
+ normalized := make([]string, 0, len(callIDs))
+ for _, callID := range callIDs {
+ id := strings.TrimSpace(callID)
+ if id == "" {
+ continue
+ }
+ if _, ok := seen[id]; ok {
+ continue
+ }
+ seen[id] = struct{}{}
+ normalized = append(normalized, id)
+ }
+ return normalized
+}
diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go
index 0eebc33f6..093b45ca4 100644
--- a/backend/internal/repository/gateway_cache_integration_test.go
+++ b/backend/internal/repository/gateway_cache_integration_test.go
@@ -3,6 +3,7 @@
package repository
import (
+ "context"
"errors"
"testing"
"time"
@@ -104,6 +105,49 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
+func (s *GatewayCacheSuite) TestSetAndGetOpenAIWSResponsePendingToolCalls() { // Round-trip: set normalizes/dedupes, get returns stored IDs, groups are isolated, TTL is applied.
+	type responsePendingToolCallsCache interface { // structural assertion: the cache must expose the pending-tool-calls API
+		SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error
+		GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error)
+	}
+	cache, ok := s.cache.(responsePendingToolCallsCache)
+	require.True(s.T(), ok, "gateway cache should implement pending tool calls cache")
+
+	responseID := "resp_pending_integration_1"
+	groupID := int64(1)
+	ttl := 2 * time.Minute
+	require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_1", "call_2", "call_1", " "}, ttl)) // duplicate and blank IDs should be dropped by normalization
+
+	callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID)
+	require.NoError(s.T(), err)
+	require.ElementsMatch(s.T(), []string{"call_1", "call_2"}, callIDs)
+	_, err = cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID+1, responseID) // same response ID under another group must miss
+	require.True(s.T(), errors.Is(err, redis.Nil), "pending tool calls should be isolated by group")
+
+	key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+	remainingTTL, ttlErr := s.rdb.TTL(s.ctx, key).Result()
+	require.NoError(s.T(), ttlErr)
+	s.AssertTTLWithin(remainingTTL, 1*time.Second, ttl) // remaining TTL should be close to the requested 2 minutes
+}
+
+func (s *GatewayCacheSuite) TestDeleteOpenAIWSResponsePendingToolCalls() { // Delete removes the entry; a subsequent get observes redis.Nil.
+	type responsePendingToolCallsCache interface { // structural assertion: set/get/delete must all be available
+		SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error
+		GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error)
+		DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error
+	}
+	cache, ok := s.cache.(responsePendingToolCallsCache)
+	require.True(s.T(), ok, "gateway cache should implement pending tool calls cache")
+
+	responseID := "resp_pending_integration_2"
+	groupID := int64(1)
+	require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_3"}, time.Minute))
+	require.NoError(s.T(), cache.DeleteOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID))
+
+	_, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID)
+	require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
+}
+
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}
diff --git a/backend/internal/repository/gemini_oauth_client.go b/backend/internal/repository/gemini_oauth_client.go
index eb14f3134..8b7fe625c 100644
--- a/backend/internal/repository/gemini_oauth_client.go
+++ b/backend/internal/repository/gemini_oauth_client.go
@@ -26,10 +26,7 @@ func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
}
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
- client, err := createGeminiReqClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create HTTP client: %w", err)
- }
+ client := createGeminiReqClient(proxyURL)
// Use different OAuth clients based on oauthType:
// - code_assist: always use built-in Gemini CLI OAuth client (public)
@@ -75,10 +72,7 @@ func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, c
}
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
- client, err := createGeminiReqClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create HTTP client: %w", err)
- }
+ client := createGeminiReqClient(proxyURL)
oauthCfgInput := geminicli.OAuthConfig{
ClientID: c.cfg.Gemini.OAuth.ClientID,
@@ -117,7 +111,7 @@ func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refresh
return &tokenResp, nil
}
-func createGeminiReqClient(proxyURL string) (*req.Client, error) {
+func createGeminiReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 60 * time.Second,
diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go
index b5bc64972..4f63280d5 100644
--- a/backend/internal/repository/geminicli_codeassist_client.go
+++ b/backend/internal/repository/geminicli_codeassist_client.go
@@ -26,11 +26,7 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo
}
var out geminicli.LoadCodeAssistResponse
- client, err := createGeminiCliReqClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create HTTP client: %w", err)
- }
- resp, err := client.R().
+ resp, err := createGeminiCliReqClient(proxyURL).R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -70,11 +66,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
var out geminicli.OnboardUserResponse
- client, err := createGeminiCliReqClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create HTTP client: %w", err)
- }
- resp, err := client.R().
+ resp, err := createGeminiCliReqClient(proxyURL).R().
SetContext(ctx).
SetHeader("Authorization", "Bearer "+accessToken).
SetHeader("Content-Type", "application/json").
@@ -106,7 +98,7 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken
return &out, nil
}
-func createGeminiCliReqClient(proxyURL string) (*req.Client, error) {
+func createGeminiCliReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 30 * time.Second,
diff --git a/backend/internal/repository/github_release_service.go b/backend/internal/repository/github_release_service.go
index ad1f22e39..28efe914a 100644
--- a/backend/internal/repository/github_release_service.go
+++ b/backend/internal/repository/github_release_service.go
@@ -5,10 +5,8 @@ import (
"encoding/json"
"fmt"
"io"
- "log/slog"
"net/http"
"os"
- "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
@@ -26,19 +24,13 @@ type githubReleaseClientError struct {
// NewGitHubReleaseClient 创建 GitHub Release 客户端
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
-// 代理配置失败时行为由 allowDirectOnProxyError 控制:
-// - false(默认):返回错误占位客户端,禁止回退到直连
-// - true:回退到直连(仅限管理员显式开启)
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
- // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据,
- // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
- if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
- slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err)
+ if proxyURL != "" && !allowDirectOnProxyError {
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
sharedClient = &http.Client{Timeout: 30 * time.Second}
@@ -50,8 +42,7 @@ func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) servi
ProxyURL: proxyURL,
})
if err != nil {
- if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
- slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err)
+ if proxyURL != "" && !allowDirectOnProxyError {
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
}
downloadClient = &http.Client{Timeout: 10 * time.Minute}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 4edc85340..e9b4902ac 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -127,38 +127,6 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
- // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
- if groupIn.DailyLimitUSD != nil {
- builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
- } else {
- builder = builder.ClearDailyLimitUsd()
- }
- if groupIn.WeeklyLimitUSD != nil {
- builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD)
- } else {
- builder = builder.ClearWeeklyLimitUsd()
- }
- if groupIn.MonthlyLimitUSD != nil {
- builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD)
- } else {
- builder = builder.ClearMonthlyLimitUsd()
- }
- if groupIn.ImagePrice1K != nil {
- builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K)
- } else {
- builder = builder.ClearImagePrice1k()
- }
- if groupIn.ImagePrice2K != nil {
- builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K)
- } else {
- builder = builder.ClearImagePrice2k()
- }
- if groupIn.ImagePrice4K != nil {
- builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K)
- } else {
- builder = builder.ClearImagePrice4k()
- }
-
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go
index a4674c1a3..a9df13229 100644
--- a/backend/internal/repository/http_upstream.go
+++ b/backend/internal/repository/http_upstream.go
@@ -1,6 +1,7 @@
package repository
import (
+ "crypto/tls"
"errors"
"fmt"
"io"
@@ -14,7 +15,6 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -45,6 +45,16 @@ const (
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
defaultClientIdleTTLSeconds = 900
+ // OpenAI HTTP/2 代理回退策略默认值
+ defaultOpenAIHTTP2FallbackErrorThreshold = 2
+ defaultOpenAIHTTP2FallbackWindow = 60 * time.Second
+ defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute
+)
+
+const (
+ upstreamProtocolModeDefault = "default"
+ upstreamProtocolModeOpenAIH2 = "openai_h2"
+ upstreamProtocolModeOpenAIH1Fallback = "openai_h1_fallback"
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
@@ -59,14 +69,30 @@ type poolSettings struct {
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
+type openAIHTTP2Settings struct {
+ enabled bool
+ allowProxyFallbackToHTTP1 bool
+ fallbackErrorThreshold int
+ fallbackWindow time.Duration
+ fallbackTTL time.Duration
+}
+
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
- client *http.Client // HTTP 客户端实例
- proxyKey string // 代理标识(用于检测代理变更)
- poolKey string // 连接池配置标识(用于检测配置变更)
- lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
- inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
+ client *http.Client // HTTP 客户端实例
+ proxyKey string // 代理标识(用于检测代理变更)
+ poolKey string // 连接池配置标识(用于检测配置变更)
+ protocolMode string // 协议模式(default/openai_h2/openai_h1_fallback)
+ lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
+ inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
+}
+
+type openAIHTTP2FallbackState struct {
+ mu sync.Mutex
+ windowStart time.Time
+ errorCount int
+ fallbackUntil time.Time
}
// httpUpstreamService 通用 HTTP 上游服务
@@ -90,6 +116,8 @@ type httpUpstreamService struct {
cfg *config.Config // 全局配置
mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定
+ // OpenAI 走 HTTP 代理时的 H2->H1 回退状态(key=标准化 proxyKey)
+ openAIHTTP2Fallbacks sync.Map
}
// NewHTTPUpstream 创建通用 HTTP 上游服务
@@ -127,9 +155,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
+ profile := service.HTTPUpstreamProfileDefault
+ if req != nil {
+ profile = service.HTTPUpstreamProfileFromContext(req.Context())
+ }
// 获取或创建对应的客户端,并标记请求占用
- entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
+ entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile)
if err != nil {
return nil, err
}
@@ -137,11 +169,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
+ s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err)
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
+ s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey)
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE)在完全读取前不会被淘汰
@@ -236,13 +270,10 @@ func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID in
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
isolation := s.getIsolationMode()
- proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
- if err != nil {
- return nil, err
- }
+ proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
- cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
- poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
+ cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault)
+ poolKey := s.buildPoolKey(isolation, accountConcurrency, upstreamProtocolModeDefault) + ":tls"
now := time.Now()
nowUnix := now.UnixNano()
@@ -359,7 +390,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
- return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
+ return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault)
+}
+
+// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。
+func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) {
+ return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true)
}
// getOrCreateClient 获取或创建客户端
@@ -377,25 +413,25 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
-func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
- return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
+func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
+ entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false)
+ return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
-func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
+func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
- proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
- if err != nil {
- return nil, err
- }
+ proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
+ // 根据请求 profile(例如 OpenAI)选择协议模式
+ protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy)
// 构建缓存键(根据隔离策略不同)
- cacheKey := buildCacheKey(isolation, proxyKey, accountID)
+ cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode)
// 构建连接池配置键(用于检测配置变更)
- poolKey := s.buildPoolKey(isolation, accountConcurrency)
+ poolKey := s.buildPoolKey(isolation, accountConcurrency, protocolMode)
now := time.Now()
nowUnix := now.UnixNano()
@@ -439,7 +475,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
- transport, err := buildUpstreamTransport(settings, parsedProxy)
+ transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode)
if err != nil {
s.mu.Unlock()
return nil, fmt.Errorf("build transport: %w", err)
@@ -449,9 +485,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
- client: client,
- proxyKey: proxyKey,
- poolKey: poolKey,
+ client: client,
+ proxyKey: proxyKey,
+ poolKey: poolKey,
+ protocolMode: protocolMode,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
@@ -644,13 +681,17 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu
//
// 返回:
// - string: 配置键
-func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
+func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int, protocolMode string) string {
+ base := "default"
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
- return fmt.Sprintf("account:%d", accountConcurrency)
+ base = fmt.Sprintf("account:%d", accountConcurrency)
}
}
- return "default"
+ if protocolMode == "" || protocolMode == upstreamProtocolModeDefault {
+ return base
+ }
+ return base + "|proto:" + protocolMode
}
// buildCacheKey 构建客户端缓存键
@@ -668,15 +709,20 @@ func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
-func buildCacheKey(isolation, proxyKey string, accountID int64) string {
+func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string {
+ var base string
switch isolation {
case config.ConnectionPoolIsolationAccount:
- return fmt.Sprintf("account:%d", accountID)
+ base = fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
- return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
+ base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
- return fmt.Sprintf("proxy:%s", proxyKey)
+ base = fmt.Sprintf("proxy:%s", proxyKey)
+ }
+ if protocolMode != "" && protocolMode != upstreamProtocolModeDefault {
+ base += "|proto:" + protocolMode
}
+ return base
}
// normalizeProxyURL 标准化代理 URL
@@ -686,18 +732,17 @@ func buildCacheKey(isolation, proxyKey string, accountID int64) string {
// - raw: 原始代理 URL 字符串
//
// 返回:
-// - string: 标准化的代理键(空返回 "direct")
-// - *url.URL: 解析后的 URL(空返回 nil)
-// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连)
-func normalizeProxyURL(raw string) (string, *url.URL, error) {
- _, parsed, err := proxyurl.Parse(raw)
+// - string: 标准化的代理键(空或解析失败返回 "direct")
+// - *url.URL: 解析后的 URL(空或解析失败返回 nil)
+func normalizeProxyURL(raw string) (string, *url.URL) {
+ proxyURL := strings.TrimSpace(raw)
+ if proxyURL == "" {
+ return directProxyKey, nil
+ }
+ parsed, err := url.Parse(proxyURL)
if err != nil {
- return "", nil, err
- }
- if parsed == nil {
- return directProxyKey, nil, nil
+ return directProxyKey, nil
}
- // 规范化:小写 scheme/host,去除路径和查询参数
parsed.Scheme = strings.ToLower(parsed.Scheme)
parsed.Host = strings.ToLower(parsed.Host)
parsed.Path = ""
@@ -717,7 +762,200 @@ func normalizeProxyURL(raw string) (string, *url.URL, error) {
parsed.Host = hostname
}
}
- return parsed.String(), parsed, nil
+ return parsed.String(), parsed
+}
+
+func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings { // Effective OpenAI HTTP/2 policy: built-in defaults, overlaid with config values when a config is wired.
+	settings := openAIHTTP2Settings{
+		enabled: true,
+		allowProxyFallbackToHTTP1: true,
+		fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold,
+		fallbackWindow: defaultOpenAIHTTP2FallbackWindow,
+		fallbackTTL: defaultOpenAIHTTP2FallbackTTL,
+	}
+	if s == nil || s.cfg == nil { // no service/config available (e.g. tests): use the defaults above
+		return settings
+	}
+	cfg := s.cfg.Gateway.OpenAIHTTP2
+	settings.enabled = cfg.Enabled // NOTE(review): boolean flags are taken unconditionally from config (zero value = false) — confirm config defaults these to true
+	settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1
+	if cfg.FallbackErrorThreshold > 0 { // numeric tunables fall back to defaults when unset/non-positive
+		settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold
+	}
+	if cfg.FallbackWindowSeconds > 0 {
+		settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second
+	}
+	if cfg.FallbackTTLSeconds > 0 {
+		settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second
+	}
+	return settings
+}
+
+func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string { // Pick the protocol mode for a request: OpenAI traffic prefers H2, with per-proxy H1 fallback for flaky http(s) proxies.
+	if profile != service.HTTPUpstreamProfileOpenAI { // non-OpenAI traffic keeps the default transport behavior
+		return upstreamProtocolModeDefault
+	}
+	settings := s.resolveOpenAIHTTP2Settings()
+	if !settings.enabled {
+		return upstreamProtocolModeDefault
+	}
+	if parsedProxy == nil { // direct connection: H2 is always attempted
+		return upstreamProtocolModeOpenAIH2
+	}
+	scheme := strings.ToLower(parsedProxy.Scheme)
+	if scheme != "http" && scheme != "https" { // non-HTTP proxies (e.g. socks5) are not subject to the H1 fallback
+		return upstreamProtocolModeOpenAIH2
+	}
+	if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) { // this proxy recently failed H2: use H1 until the fallback TTL expires
+		return upstreamProtocolModeOpenAIH1Fallback
+	}
+	return upstreamProtocolModeOpenAIH2
+}
+
+func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool { // Reports whether the H2->H1 fallback window is currently active for this proxy key.
+	raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
+	if !ok { // no failures ever recorded for this proxy
+		return false
+	}
+	state, ok := raw.(*openAIHTTP2FallbackState)
+	if !ok || state == nil { // defensive: the map should only ever hold *openAIHTTP2FallbackState
+		return false
+	}
+	return state.isFallbackActive(time.Now())
+}
+
+func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState { // Fetch the per-proxy fallback state, creating it atomically on first use.
+	state := &openAIHTTP2FallbackState{}
+	actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state) // LoadOrStore makes concurrent first-callers converge on one shared instance
+	cached, ok := actual.(*openAIHTTP2FallbackState)
+	if !ok || cached == nil { // defensive: should be unreachable given the store above
+		return state
+	}
+	return cached
+}
+
+func isHTTPProxyKey(proxyKey string) bool { // True when the normalized proxy key denotes an http/https proxy (the only kind eligible for H1 fallback).
+	return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://")
+}
+
+func isOpenAIHTTP2CompatibilityError(err error) bool { // Heuristic: does the error text look like an HTTP/2 / ALPN incompatibility (vs. timeout, DNS, etc.)?
+	if err == nil {
+		return false
+	}
+	msg := strings.ToLower(err.Error())
+	if msg == "" {
+		return false
+	}
+	markers := []string{ // lowercase substrings commonly seen when a proxy mishandles HTTP/2 negotiation or framing
+		"http2",
+		"alpn",
+		"no application protocol",
+		"protocol error",
+		"stream error",
+		"goaway",
+		"refused_stream",
+		"frame too large",
+	}
+	for _, marker := range markers { // substring match on the lowercased message; intentionally loose
+		if strings.Contains(msg, marker) {
+			return true
+		}
+	}
+	return false
+}
+
+func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) { // Count an H2 failure for an OpenAI request over an http(s) proxy; may activate the per-proxy H1 fallback.
+	if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 { // only H2-mode OpenAI requests feed the fallback counter
+		return
+	}
+	settings := s.resolveOpenAIHTTP2Settings()
+	if !settings.enabled || !settings.allowProxyFallbackToHTTP1 { // feature or fallback disabled: nothing to record
+		return
+	}
+	if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) { // ignore direct/socks traffic and errors unrelated to H2 compatibility
+		return
+	}
+	state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
+	activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL)
+	if activated { // log once per activation, not once per failure
+		slog.Warn("openai_http2_proxy_fallback_activated",
+			"proxy", proxyKey,
+			"fallback_until", until.Format(time.RFC3339))
+	}
+}
+
+func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) { // A successful H2 request resets the failure window; it does not cancel an already-active fallback.
+	if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
+		return
+	}
+	if !isHTTPProxyKey(proxyKey) {
+		return
+	}
+	raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
+	if !ok { // no prior failures recorded: nothing to reset
+		return
+	}
+	state, ok := raw.(*openAIHTTP2FallbackState)
+	if !ok || state == nil { // defensive type check, mirrors isOpenAIHTTP2FallbackActive
+		return
+	}
+	state.resetErrorWindow()
+}
+
+func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool { // True while now precedes fallbackUntil; lazily clears an expired deadline under the lock.
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.fallbackUntil.IsZero() { // never activated (or already cleared)
+		return false
+	}
+	if now.Before(s.fallbackUntil) {
+		return true
+	}
+	s.fallbackUntil = time.Time{} // expired: clear so later checks short-circuit on IsZero
+	return false
+}
+
+func (s *openAIHTTP2FallbackState) resetErrorWindow() { // Clear the failure-counting window; leaves any active fallback deadline untouched.
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.windowStart = time.Time{}
+	s.errorCount = 0
+}
+
+func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) { // Count one failure in a sliding window; returns (true, deadline) when this failure activates the fallback.
+	if threshold <= 0 { // guard against zero/negative tuning values from callers
+		threshold = defaultOpenAIHTTP2FallbackErrorThreshold
+	}
+	if window <= 0 {
+		window = defaultOpenAIHTTP2FallbackWindow
+	}
+	if ttl <= 0 {
+		ttl = defaultOpenAIHTTP2FallbackTTL
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) { // fallback already active: don't re-activate or re-log
+		return false, s.fallbackUntil
+	}
+	if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) { // previous fallback expired: clear it before counting anew
+		s.fallbackUntil = time.Time{}
+	}
+
+	if s.windowStart.IsZero() || now.Sub(s.windowStart) > window { // start a fresh counting window (this failure will count as 1)
+		s.windowStart = now
+		s.errorCount = 0
+	}
+	s.errorCount++
+	if s.errorCount < threshold { // below threshold: stay on H2 for now
+		return false, time.Time{}
+	}
+
+	s.fallbackUntil = now.Add(ttl) // threshold reached: activate the fallback and reset counters
+	s.windowStart = time.Time{}
+	s.errorCount = 0
+	return true, s.fallbackUntil
}
// defaultPoolSettings 获取默认连接池配置
@@ -779,7 +1017,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
-func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
+func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
@@ -787,6 +1025,14 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
+ switch protocolMode {
+ case upstreamProtocolModeOpenAIH2:
+ transport.ForceAttemptHTTP2 = true
+ case upstreamProtocolModeOpenAIH1Fallback:
+ // 显式禁用 HTTP/2,确保代理不兼容场景回退到 HTTP/1.1。
+ transport.ForceAttemptHTTP2 = false
+ transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
+ }
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
return nil, err
}
diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go
index 89892b3b6..ebeee640e 100644
--- a/backend/internal/repository/http_upstream_benchmark_test.go
+++ b/backend/internal/repository/http_upstream_benchmark_test.go
@@ -45,7 +45,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
- transport, err := buildUpstreamTransport(settings, parsedProxy)
+ transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeDefault)
if err != nil {
b.Fatalf("创建 Transport 失败: %v", err)
}
@@ -59,10 +59,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
// 模拟优化后的行为,从缓存获取客户端
b.Run("复用", func(b *testing.B) {
// 预热:确保客户端已缓存
- entry, err := svc.getOrCreateClient(proxyURL, 1, 1)
- if err != nil {
- b.Fatalf("getOrCreateClient: %v", err)
- }
+ entry := svc.getOrCreateClient(proxyURL, 1, 1)
client := entry.client
b.ResetTimer() // 重置计时器,排除预热时间
for i := 0; i < b.N; i++ {
diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go
index b3268463a..35ee9adde 100644
--- a/backend/internal/repository/http_upstream_test.go
+++ b/backend/internal/repository/http_upstream_test.go
@@ -1,13 +1,19 @@
package repository
import (
+ "context"
+ "errors"
"io"
"net/http"
+ "net/url"
+ "strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
+ "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -44,7 +50,7 @@ func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
// 验证未配置时使用 300 秒默认值
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
svc := s.newService()
- entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
+ entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
@@ -55,27 +61,25 @@ func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
svc := s.newService()
- entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
+ entry := svc.getOrCreateClient("", 0, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
}
-// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误
-// 验证解析失败时拒绝回退到直连模式
-func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() {
+// TestGetOrCreateClient_InvalidURLFallsBackToDirect 测试无效代理 URL 回退
+// 验证解析失败时回退到直连模式
+func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLFallsBackToDirect() {
svc := s.newService()
- _, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false)
- require.Error(s.T(), err, "expected error for invalid proxy URL")
+ entry := svc.getOrCreateClient("://bad-proxy-url", 1, 1)
+ require.Equal(s.T(), directProxyKey, entry.proxyKey, "expected direct proxy fallback")
}
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
// 验证等价地址能够映射到同一缓存键
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
- key1, _, err1 := normalizeProxyURL("http://proxy.local:8080")
- require.NoError(s.T(), err1)
- key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/")
- require.NoError(s.T(), err2)
+ key1, _ := normalizeProxyURL("http://proxy.local:8080")
+ key2, _ := normalizeProxyURL("http://proxy.local:8080/")
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
}
@@ -116,6 +120,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
+func (s *HTTPUpstreamSuite) TestDo_RequestErrorPath() {
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/unreachable", nil)
+ require.NoError(s.T(), err)
+
+ resp, doErr := svc.Do(req, "", 1, 1)
+ require.Nil(s.T(), resp)
+ require.Error(s.T(), doErr)
+}
+
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
@@ -173,8 +187,8 @@ func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一代理,不同账户
- entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3)
- entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3)
+ entry1 := svc.getOrCreateClient("http://proxy.local:8080", 1, 3)
+ entry2 := svc.getOrCreateClient("http://proxy.local:8080", 2, 3)
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
}
@@ -185,8 +199,8 @@ func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
svc := s.newService()
// 同一账户,不同代理
- entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
- entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
+ entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
+ entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
}
@@ -197,8 +211,8 @@ func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 同一账户,先后使用不同代理
- entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
- entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
+ entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 3)
+ entry2 := svc.getOrCreateClient("http://proxy-b:8080", 1, 3)
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
@@ -210,7 +224,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
svc := s.newService()
// 账户并发数为 12
- entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12)
+ entry := svc.getOrCreateClient("", 1, 12)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
// 连接池参数应与并发数一致
@@ -230,7 +244,7 @@ func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
}
svc := s.newService()
// 账户并发数为 0,应使用全局配置
- entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0)
+ entry := svc.getOrCreateClient("", 1, 0)
transport, ok := entry.client.Transport.(*http.Transport)
require.True(s.T(), ok, "expected *http.Transport")
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
@@ -247,12 +261,12 @@ func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
}
svc := s.newService()
// 创建两个客户端,设置不同的最后使用时间
- entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1)
- entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1)
+ entry1 := svc.getOrCreateClient("http://proxy-a:8080", 1, 1)
+ entry2 := svc.getOrCreateClient("http://proxy-b:8080", 2, 1)
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
// 创建第三个客户端,触发淘汰
- _ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1)
+ _ = svc.getOrCreateClient("http://proxy-c:8080", 3, 1)
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
@@ -266,29 +280,446 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
}
svc := s.newService()
- entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1)
+ entry1 := svc.getOrCreateClient("", 1, 1)
// 设置为很久之前使用,但有活跃请求
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
// 创建新客户端,触发淘汰检查
- _, _ = svc.getOrCreateClient("", 2, 1)
+ _ = svc.getOrCreateClient("", 2, 1)
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_UsesHTTP2TransportForHTTPProxy() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+
+ entry, err := svc.getClientEntry("http://proxy.local:8080", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH2, entry.protocolMode)
+
+ transport, ok := entry.client.Transport.(*http.Transport)
+ require.True(s.T(), ok, "expected *http.Transport")
+ require.True(s.T(), transport.ForceAttemptHTTP2, "OpenAI profile should prefer HTTP/2")
+ require.Nil(s.T(), transport.TLSNextProto, "HTTP/2 mode should not force-disable TLSNextProto")
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_FallbackToHTTP11WhenProxyMarkedIncompatible() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+
+ state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL)
+ state.mu.Lock()
+ state.fallbackUntil = time.Now().Add(3 * time.Minute)
+ state.mu.Unlock()
+
+ entry, err := svc.getClientEntry(proxyURL, 1, 1, service.HTTPUpstreamProfileOpenAI, false, false)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, entry.protocolMode)
+
+ transport, ok := entry.client.Transport.(*http.Transport)
+ require.True(s.T(), ok, "expected *http.Transport")
+ require.False(s.T(), transport.ForceAttemptHTTP2, "fallback mode must disable HTTP/2 force-attempt")
+ require.NotNil(s.T(), transport.TLSNextProto, "fallback mode must disable HTTP/2 negotiation")
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordHTTP2ErrorActivatesFallback() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+ h2Err := errors.New("http2: stream error")
+
+ svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err)
+ require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "first error should not activate fallback")
+
+ svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err)
+ require.True(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "second error in window should activate fallback")
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordNonHTTP2ErrorDoesNotActivateFallback() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 1,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+
+ svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, errors.New("dial tcp: i/o timeout"))
+ require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL))
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_DisabledDelegatesToDo() {
+ upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = io.WriteString(w, "ok")
+ }))
+ s.T().Cleanup(upstream.Close)
+
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-disabled", nil)
+ require.NoError(s.T(), err)
+
+ resp, err := svc.DoWithTLS(req, "", 1, 1, false)
+ require.NoError(s.T(), err)
+ defer func() { _ = resp.Body.Close() }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(s.T(), readErr)
+ require.Equal(s.T(), "ok", string(body))
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledHTTPRequestSuccess() {
+ upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = io.WriteString(w, "tls-enabled")
+ }))
+ s.T().Cleanup(upstream.Close)
+
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-enabled", nil)
+ require.NoError(s.T(), err)
+
+ resp, err := svc.DoWithTLS(req, "", 9, 1, true)
+ require.NoError(s.T(), err)
+ defer func() { _ = resp.Body.Close() }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(s.T(), readErr)
+ require.Equal(s.T(), "tls-enabled", string(body))
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledRequestError() {
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/tls-error", nil)
+ require.NoError(s.T(), err)
+
+ resp, doErr := svc.DoWithTLS(req, "", 9, 1, true)
+ require.Nil(s.T(), resp)
+ require.Error(s.T(), doErr)
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_ValidateRequestHostFailure() {
+ s.cfg.Security.URLAllowlist.Enabled = true
+ s.cfg.Security.URLAllowlist.AllowPrivateHosts = false
+ svc := s.newService()
+
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/test", nil)
+ require.NoError(s.T(), err)
+
+ resp, doErr := svc.DoWithTLS(req, "", 1, 1, true)
+ require.Nil(s.T(), resp)
+ require.Error(s.T(), doErr)
+}
+
+func (s *HTTPUpstreamSuite) TestShouldValidateResolvedIPAndValidateRequestHost() {
+ svc := s.newService()
+ require.False(s.T(), svc.shouldValidateResolvedIP())
+ require.NoError(s.T(), svc.validateRequestHost(nil))
+
+ s.cfg.Security.URLAllowlist.Enabled = true
+ s.cfg.Security.URLAllowlist.AllowPrivateHosts = false
+ require.True(s.T(), svc.shouldValidateResolvedIP())
+ require.Error(s.T(), svc.validateRequestHost(nil))
+
+ req, err := http.NewRequest(http.MethodGet, "http:///nohost", nil)
+ require.NoError(s.T(), err)
+ require.Error(s.T(), svc.validateRequestHost(req))
+}
+
+func (s *HTTPUpstreamSuite) TestRedirectCheckerStopsAfterLimit() {
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
+ require.NoError(s.T(), err)
+
+ via := make([]*http.Request, 10)
+ require.Error(s.T(), svc.redirectChecker(req, via))
+}
+
+func (s *HTTPUpstreamSuite) TestRedirectCheckerValidatesRequestHost() {
+ s.cfg.Security.URLAllowlist.Enabled = true
+ s.cfg.Security.URLAllowlist.AllowPrivateHosts = false
+ svc := s.newService()
+
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
+ require.NoError(s.T(), err)
+ require.Error(s.T(), svc.redirectChecker(req, nil))
+}
+
+func (s *HTTPUpstreamSuite) TestShouldReuseEntryAndEvictBranches() {
+ svc := s.newService()
+ entry := &upstreamClientEntry{
+ proxyKey: "proxy-a",
+ poolKey: "pool-a",
+ }
+ require.False(s.T(), svc.shouldReuseEntry(nil, config.ConnectionPoolIsolationAccount, "proxy-a", "pool-a"))
+ require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationAccount, "proxy-b", "pool-a"))
+ require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-a", "pool-b"))
+ require.True(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-x", "pool-a"))
+
+ s.cfg.Gateway.MaxUpstreamClients = 2
+ svc.clients["k1"] = &upstreamClientEntry{inFlight: 1}
+ svc.clients["k2"] = &upstreamClientEntry{inFlight: 1}
+ require.False(s.T(), svc.evictOldestIdleLocked())
+ require.False(s.T(), svc.evictOverLimitLocked())
+}
+
+func (s *HTTPUpstreamSuite) TestBuildCacheKeyAndIsolationMode() {
+ svc := s.newService()
+ require.Equal(s.T(), "account:1", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, ""))
+ require.Equal(s.T(), "account:2|proxy:px", buildCacheKey(config.ConnectionPoolIsolationAccountProxy, "px", 2, ""))
+ require.Equal(s.T(), "proxy:direct", buildCacheKey(config.ConnectionPoolIsolationProxy, "direct", 3, ""))
+ require.Equal(s.T(), "account:1|proto:openai_h2", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, "openai_h2"))
+
+ s.cfg.Gateway.ConnectionPoolIsolation = "invalid"
+ require.Equal(s.T(), config.ConnectionPoolIsolationAccountProxy, svc.getIsolationMode())
+ s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationProxy
+ require.Equal(s.T(), config.ConnectionPoolIsolationProxy, svc.getIsolationMode())
+}
+
+func (s *HTTPUpstreamSuite) TestResolveProtocolModeAndSettingsBranches() {
+ svc := s.newService()
+ s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ }
+ parsedHTTPProxy, err := url.Parse("http://proxy.local:8080")
+ require.NoError(s.T(), err)
+ parsedSOCKSProxy, err := url.Parse("socks5://proxy.local:1080")
+ require.NoError(s.T(), err)
+
+ require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileDefault, "direct", nil))
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "direct", nil))
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "socks5://proxy.local:1080", parsedSOCKSProxy))
+
+ state := svc.getOrCreateOpenAIHTTP2FallbackState("http://proxy.local:8080")
+ state.mu.Lock()
+ state.fallbackUntil = time.Now().Add(10 * time.Second)
+ state.mu.Unlock()
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy))
+
+ s.cfg.Gateway.OpenAIHTTP2.Enabled = false
+ require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy))
+}
+
+func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_ReusesAndRebuildsOnProxyChange() {
+ s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccount
+ svc := s.newService()
+ profile := &tlsfingerprint.Profile{Name: "tls-profile"}
+
+ entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false)
+ require.NoError(s.T(), err)
+ entry2, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false)
+ require.NoError(s.T(), err)
+ require.Same(s.T(), entry1, entry2)
+
+ entry3, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 1, 1, profile, false, false)
+ require.NoError(s.T(), err)
+ require.NotSame(s.T(), entry1, entry3)
+}
+
+func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_OverLimitReturnsError() {
+ s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccountProxy
+ s.cfg.Gateway.MaxUpstreamClients = 1
+ svc := s.newService()
+ profile := &tlsfingerprint.Profile{Name: "tls-profile"}
+
+ entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, true, true)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), entry1)
+
+ entry2, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 2, 1, profile, true, true)
+ require.ErrorIs(s.T(), err, errUpstreamClientLimitReached)
+ require.Nil(s.T(), entry2)
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateHelpers() {
+ var state openAIHTTP2FallbackState
+ now := time.Now()
+
+ active, until := state.recordFailure(now, 1, time.Minute, time.Minute)
+ require.True(s.T(), active)
+ require.False(s.T(), until.IsZero())
+ require.True(s.T(), state.isFallbackActive(now))
+ require.False(s.T(), state.isFallbackActive(now.Add(2*time.Minute)))
+
+ state.recordFailure(now, 3, time.Minute, time.Minute)
+ state.recordFailure(now.Add(10*time.Second), 3, time.Minute, time.Minute)
+ state.resetErrorWindow()
+ require.Equal(s.T(), 0, state.errorCount)
+ require.True(s.T(), state.windowStart.IsZero())
+
+ // 在 fallback 活跃期间再次失败,不应重复激活。
+ state.fallbackUntil = now.Add(time.Minute)
+ activated, _ := state.recordFailure(now.Add(5*time.Second), 1, time.Minute, time.Minute)
+ require.False(s.T(), activated)
+}
+
+func (s *HTTPUpstreamSuite) TestRecordOpenAIHTTP2SuccessResetsWindow() {
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+ state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL)
+ state.mu.Lock()
+ state.errorCount = 5
+ state.windowStart = time.Now()
+ state.mu.Unlock()
+
+ svc.recordOpenAIHTTP2Success(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL)
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ require.Equal(s.T(), 0, state.errorCount)
+ require.True(s.T(), state.windowStart.IsZero())
+}
+
+func (s *HTTPUpstreamSuite) TestDo_OpenAIProxySuccessResetsHTTP2ErrorWindow() {
+ seen := make(chan struct{}, 1)
+ proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ select {
+ case seen <- struct{}{}:
+ default:
+ }
+ _, _ = io.WriteString(w, "proxied")
+ }))
+ s.T().Cleanup(proxySrv.Close)
+
+ s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ }
+ svc := s.newService()
+ proxyKey, _ := normalizeProxyURL(proxySrv.URL)
+ state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
+ state.mu.Lock()
+ state.windowStart = time.Now()
+ state.errorCount = 3
+ state.fallbackUntil = time.Time{}
+ state.mu.Unlock()
+
+ req, err := http.NewRequest(http.MethodGet, "http://example.com/reset-window", nil)
+ require.NoError(s.T(), err)
+ req = req.WithContext(service.WithHTTPUpstreamProfile(context.Background(), service.HTTPUpstreamProfileOpenAI))
+ resp, doErr := svc.Do(req, proxySrv.URL, 1, 1)
+ require.NoError(s.T(), doErr)
+ defer func() { _ = resp.Body.Close() }()
+ _, _ = io.ReadAll(resp.Body)
+
+ select {
+ case <-seen:
+ default:
+ require.Fail(s.T(), "expected proxy to receive request")
+ }
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ require.Equal(s.T(), 0, state.errorCount)
+ require.True(s.T(), state.windowStart.IsZero())
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateMapTypeSafety() {
+ svc := s.newService()
+ svc.openAIHTTP2Fallbacks.Store("x", "bad-type")
+ require.False(s.T(), svc.isOpenAIHTTP2FallbackActive("x"))
+ state := svc.getOrCreateOpenAIHTTP2FallbackState("x")
+ require.NotNil(s.T(), state)
+}
+
+func (s *HTTPUpstreamSuite) TestBuildUpstreamTransport_ModeSwitchingAndProxyErrors() {
+ settings := defaultPoolSettings(s.cfg)
+ parsedProxy, err := url.Parse("http://proxy.local:8080")
+ require.NoError(s.T(), err)
+
+ h2Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH2)
+ require.NoError(s.T(), err)
+ require.True(s.T(), h2Transport.ForceAttemptHTTP2)
+
+ h1Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH1Fallback)
+ require.NoError(s.T(), err)
+ require.False(s.T(), h1Transport.ForceAttemptHTTP2)
+ require.NotNil(s.T(), h1Transport.TLSNextProto)
+
+ badProxy, err := url.Parse("ftp://proxy.local:21")
+ require.NoError(s.T(), err)
+ _, badErr := buildUpstreamTransport(settings, badProxy, upstreamProtocolModeDefault)
+ require.Error(s.T(), badErr)
+}
+
+func (s *HTTPUpstreamSuite) TestBuildUpstreamTransportWithTLSFingerprintBranches() {
+ settings := defaultPoolSettings(s.cfg)
+ profile := &tlsfingerprint.Profile{Name: "test-profile"}
+
+ transportDirect, err := buildUpstreamTransportWithTLSFingerprint(settings, nil, profile)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), transportDirect.DialTLSContext)
+
+ httpProxy, err := url.Parse("http://proxy.local:8080")
+ require.NoError(s.T(), err)
+ transportHTTPProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, httpProxy, profile)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), transportHTTPProxy.DialTLSContext)
+
+ socksProxy, err := url.Parse("socks5://proxy.local:1080")
+ require.NoError(s.T(), err)
+ transportSOCKSProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, socksProxy, profile)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), transportSOCKSProxy.DialTLSContext)
+
+ unsupportedProxy, err := url.Parse("ftp://proxy.local:21")
+ require.NoError(s.T(), err)
+ _, unsupportedErr := buildUpstreamTransportWithTLSFingerprint(settings, unsupportedProxy, profile)
+ require.Error(s.T(), unsupportedErr)
+}
+
+func (s *HTTPUpstreamSuite) TestWrapTrackedBody_NilAndCloseOnce() {
+ require.Nil(s.T(), wrapTrackedBody(nil, nil))
+
+ closed := int32(0)
+ readCloser := io.NopCloser(strings.NewReader("x"))
+ wrapped := wrapTrackedBody(readCloser, func() {
+ atomic.AddInt32(&closed, 1)
+ })
+ require.NotNil(s.T(), wrapped)
+ _ = wrapped.Close()
+ _ = wrapped.Close()
+ require.Equal(s.T(), int32(1), atomic.LoadInt32(&closed))
+}
+
// TestHTTPUpstreamSuite 运行测试套件
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
}
-// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误
-func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry {
- t.Helper()
- entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency)
- require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency)
- return entry
-}
-
// hasEntry 检查客户端是否存在于缓存中
// 辅助函数,用于验证淘汰逻辑
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go
index 6152dd7a5..c49865479 100644
--- a/backend/internal/repository/identity_cache.go
+++ b/backend/internal/repository/identity_cache.go
@@ -12,7 +12,7 @@ import (
const (
fingerprintKeyPrefix = "fingerprint:"
- fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期
+ fingerprintTTL = 24 * time.Hour
maskedSessionKeyPrefix = "masked_session:"
maskedSessionTTL = 15 * time.Minute
)
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index dca0b612f..3e155971b 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -23,10 +23,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
- client, err := createOpenAIReqClient(proxyURL)
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
- }
+ client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -77,10 +74,7 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
- client, err := createOpenAIReqClient(proxyURL)
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
- }
+ client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -108,7 +102,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
return &tokenResp, nil
}
-func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
+func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
diff --git a/backend/internal/repository/pricing_service.go b/backend/internal/repository/pricing_service.go
index ee8e1749f..07d796b8c 100644
--- a/backend/internal/repository/pricing_service.go
+++ b/backend/internal/repository/pricing_service.go
@@ -4,7 +4,6 @@ import (
"context"
"fmt"
"io"
- "log/slog"
"net/http"
"strings"
"time"
@@ -17,37 +16,14 @@ type pricingRemoteClient struct {
httpClient *http.Client
}
-// pricingRemoteClientError 代理初始化失败时的错误占位客户端
-// 所有请求直接返回初始化错误,禁止回退到直连
-type pricingRemoteClientError struct {
- err error
-}
-
-func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) {
- return nil, c.err
-}
-
-func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) {
- return "", c.err
-}
-
// NewPricingRemoteClient 创建定价数据远程客户端
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
-// 代理配置失败时行为由 allowDirectOnProxyError 控制:
-// - false(默认):返回错误占位客户端,禁止回退到直连
-// - true:回退到直连(仅限管理员显式开启)
-func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient {
- // 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据,
- // 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
+func NewPricingRemoteClient(proxyURL string) service.PricingRemoteClient {
sharedClient, err := httpclient.GetClient(httpclient.Options{
Timeout: 30 * time.Second,
ProxyURL: proxyURL,
})
if err != nil {
- if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
- slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err)
- return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
- }
sharedClient = &http.Client{Timeout: 30 * time.Second}
}
return &pricingRemoteClient{
diff --git a/backend/internal/repository/pricing_service_test.go b/backend/internal/repository/pricing_service_test.go
index ef2f214b0..6ea112117 100644
--- a/backend/internal/repository/pricing_service_test.go
+++ b/backend/internal/repository/pricing_service_test.go
@@ -19,7 +19,7 @@ type PricingServiceSuite struct {
func (s *PricingServiceSuite) SetupTest() {
s.ctx = context.Background()
- client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient)
+ client, ok := NewPricingRemoteClient("").(*pricingRemoteClient)
require.True(s.T(), ok, "type assertion failed")
s.client = client
}
@@ -140,22 +140,6 @@ func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
require.Error(s.T(), err)
}
-func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) {
- client := NewPricingRemoteClient("://bad", false)
- _, ok := client.(*pricingRemoteClientError)
- require.True(t, ok, "should return error client when proxy is invalid and fallback disabled")
-
- _, err := client.FetchPricingJSON(context.Background(), "http://example.com")
- require.Error(t, err)
- require.Contains(t, err.Error(), "proxy client init failed")
-}
-
-func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) {
- client := NewPricingRemoteClient("://bad", true)
- _, ok := client.(*pricingRemoteClient)
- require.True(t, ok, "should fallback to direct client when allowed")
-}
-
func TestPricingServiceSuite(t *testing.T) {
suite.Run(t, new(PricingServiceSuite))
}
diff --git a/backend/internal/repository/proxy_probe_service.go b/backend/internal/repository/proxy_probe_service.go
index b4aeab718..54de28972 100644
--- a/backend/internal/repository/proxy_probe_service.go
+++ b/backend/internal/repository/proxy_probe_service.go
@@ -66,6 +66,7 @@ func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*s
ProxyURL: proxyURL,
Timeout: defaultProxyProbeTimeout,
InsecureSkipVerify: s.insecureSkipVerify,
+ ProxyStrict: true,
ValidateResolvedIP: s.validateResolvedIP,
AllowPrivateHosts: s.allowPrivateHosts,
})
diff --git a/backend/internal/repository/req_client_pool.go b/backend/internal/repository/req_client_pool.go
index 79b24396d..af71a7ee3 100644
--- a/backend/internal/repository/req_client_pool.go
+++ b/backend/internal/repository/req_client_pool.go
@@ -6,8 +6,6 @@ import (
"sync"
"time"
- "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
-
"github.com/imroc/req/v3"
)
@@ -35,11 +33,11 @@ var sharedReqClients sync.Map
// getSharedReqClient 获取共享的 req 客户端实例
// 性能优化:相同配置复用同一客户端,避免重复创建
-func getSharedReqClient(opts reqClientOptions) (*req.Client, error) {
+func getSharedReqClient(opts reqClientOptions) *req.Client {
key := buildReqClientKey(opts)
if cached, ok := sharedReqClients.Load(key); ok {
if c, ok := cached.(*req.Client); ok {
- return c, nil
+ return c
}
}
@@ -50,19 +48,15 @@ func getSharedReqClient(opts reqClientOptions) (*req.Client, error) {
if opts.Impersonate {
client = client.ImpersonateChrome()
}
- trimmed, _, err := proxyurl.Parse(opts.ProxyURL)
- if err != nil {
- return nil, err
- }
- if trimmed != "" {
- client.SetProxyURL(trimmed)
+ if strings.TrimSpace(opts.ProxyURL) != "" {
+ client.SetProxyURL(strings.TrimSpace(opts.ProxyURL))
}
actual, _ := sharedReqClients.LoadOrStore(key, client)
if c, ok := actual.(*req.Client); ok {
- return c, nil
+ return c
}
- return client, nil
+ return client
}
func buildReqClientKey(opts reqClientOptions) string {
diff --git a/backend/internal/repository/req_client_pool_test.go b/backend/internal/repository/req_client_pool_test.go
index 9067d0129..904ed4d6e 100644
--- a/backend/internal/repository/req_client_pool_test.go
+++ b/backend/internal/repository/req_client_pool_test.go
@@ -26,13 +26,11 @@ func TestGetSharedReqClient_ForceHTTP2SeparatesCache(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: time.Second,
}
- clientDefault, err := getSharedReqClient(base)
- require.NoError(t, err)
+ clientDefault := getSharedReqClient(base)
force := base
force.ForceHTTP2 = true
- clientForce, err := getSharedReqClient(force)
- require.NoError(t, err)
+ clientForce := getSharedReqClient(force)
require.NotSame(t, clientDefault, clientForce)
require.NotEqual(t, buildReqClientKey(base), buildReqClientKey(force))
@@ -44,10 +42,8 @@ func TestGetSharedReqClient_ReuseCachedClient(t *testing.T) {
ProxyURL: "http://proxy.local:8080",
Timeout: 2 * time.Second,
}
- first, err := getSharedReqClient(opts)
- require.NoError(t, err)
- second, err := getSharedReqClient(opts)
- require.NoError(t, err)
+ first := getSharedReqClient(opts)
+ second := getSharedReqClient(opts)
require.Same(t, first, second)
}
@@ -60,8 +56,7 @@ func TestGetSharedReqClient_IgnoresNonClientCache(t *testing.T) {
key := buildReqClientKey(opts)
sharedReqClients.Store(key, "invalid")
- client, err := getSharedReqClient(opts)
- require.NoError(t, err)
+ client := getSharedReqClient(opts)
require.NotNil(t, client)
loaded, ok := sharedReqClients.Load(key)
@@ -76,45 +71,20 @@ func TestGetSharedReqClient_ImpersonateAndProxy(t *testing.T) {
Timeout: 4 * time.Second,
Impersonate: true,
}
- client, err := getSharedReqClient(opts)
- require.NoError(t, err)
+ client := getSharedReqClient(opts)
require.NotNil(t, client)
require.Equal(t, "http://proxy.local:8080|4s|true|false", buildReqClientKey(opts))
}
-func TestGetSharedReqClient_InvalidProxyURL(t *testing.T) {
- sharedReqClients = sync.Map{}
- opts := reqClientOptions{
- ProxyURL: "://missing-scheme",
- Timeout: time.Second,
- }
- _, err := getSharedReqClient(opts)
- require.Error(t, err)
- require.Contains(t, err.Error(), "invalid proxy URL")
-}
-
-func TestGetSharedReqClient_ProxyURLMissingHost(t *testing.T) {
- sharedReqClients = sync.Map{}
- opts := reqClientOptions{
- ProxyURL: "http://",
- Timeout: time.Second,
- }
- _, err := getSharedReqClient(opts)
- require.Error(t, err)
- require.Contains(t, err.Error(), "proxy URL missing host")
-}
-
func TestCreateOpenAIReqClient_Timeout120Seconds(t *testing.T) {
sharedReqClients = sync.Map{}
- client, err := createOpenAIReqClient("http://proxy.local:8080")
- require.NoError(t, err)
+ client := createOpenAIReqClient("http://proxy.local:8080")
require.Equal(t, 120*time.Second, client.GetClient().Timeout)
}
func TestCreateGeminiReqClient_ForceHTTP2Disabled(t *testing.T) {
sharedReqClients = sync.Map{}
- client, err := createGeminiReqClient("http://proxy.local:8080")
- require.NoError(t, err)
+ client := createGeminiReqClient("http://proxy.local:8080")
require.Equal(t, "", forceHTTPVersion(t, client))
}
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index ff40e97d6..7306b0672 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -218,6 +218,275 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return true, nil
}
+func (r *usageLogRepository) usageSQLExecutor(ctx context.Context) sqlExecutor {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return tx.Client()
+ }
+ return r.sql
+}
+
+func (r *usageLogRepository) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if fn == nil {
+ return nil
+ }
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return fn(ctx)
+ }
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx); err != nil {
+ _ = tx.Rollback()
+ return err
+ }
+ return tx.Commit()
+}
+
+func (r *usageLogRepository) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*service.UsageBillingEntry, error) {
+ query := `
+ SELECT
+ id,
+ usage_log_id,
+ user_id,
+ api_key_id,
+ subscription_id,
+ billing_type,
+ applied,
+ delta_usd,
+ status,
+ attempt_count,
+ next_retry_at,
+ updated_at,
+ created_at,
+ last_error
+ FROM billing_usage_entries
+ WHERE usage_log_id = $1
+ `
+ rows, err := r.usageSQLExecutor(ctx).QueryContext(ctx, query, usageLogID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, service.ErrUsageBillingEntryNotFound
+ }
+ entry, err := scanUsageBillingEntry(rows)
+ if err != nil {
+ return nil, err
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return entry, nil
+}
+
+func (r *usageLogRepository) UpsertUsageBillingEntry(ctx context.Context, entry *service.UsageBillingEntry) (*service.UsageBillingEntry, bool, error) {
+ if entry == nil {
+ return nil, false, nil
+ }
+
+ insertQuery := `
+ INSERT INTO billing_usage_entries (
+ usage_log_id,
+ user_id,
+ api_key_id,
+ subscription_id,
+ billing_type,
+ applied,
+ delta_usd,
+ status,
+ attempt_count,
+ next_retry_at,
+ updated_at
+ ) VALUES (
+ $1, $2, $3, $4, $5, FALSE, $6, $7, 0, NOW(), NOW()
+ )
+ ON CONFLICT (usage_log_id) DO NOTHING
+ RETURNING
+ id,
+ usage_log_id,
+ user_id,
+ api_key_id,
+ subscription_id,
+ billing_type,
+ applied,
+ delta_usd,
+ status,
+ attempt_count,
+ next_retry_at,
+ updated_at,
+ created_at,
+ last_error
+ `
+
+ exec := r.usageSQLExecutor(ctx)
+ rows, err := exec.QueryContext(
+ ctx,
+ insertQuery,
+ entry.UsageLogID,
+ entry.UserID,
+ entry.APIKeyID,
+ nullInt64(entry.SubscriptionID),
+ entry.BillingType,
+ entry.DeltaUSD,
+ service.UsageBillingEntryStatusPending,
+ )
+ if err != nil {
+ return nil, false, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if rows.Next() {
+ created, scanErr := scanUsageBillingEntry(rows)
+ if scanErr != nil {
+ return nil, false, scanErr
+ }
+ return created, true, nil
+ }
+ if err = rows.Err(); err != nil {
+ return nil, false, err
+ }
+
+ existing, err := r.GetUsageBillingEntryByUsageLogID(ctx, entry.UsageLogID)
+ if err != nil {
+ return nil, false, err
+ }
+ return existing, false, nil
+}
+
+func (r *usageLogRepository) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error {
+ query := `
+ UPDATE billing_usage_entries
+ SET
+ applied = TRUE,
+ status = $2,
+ last_error = NULL,
+ next_retry_at = NOW(),
+ updated_at = NOW()
+ WHERE id = $1
+ `
+ res, err := r.usageSQLExecutor(ctx).ExecContext(ctx, query, entryID, service.UsageBillingEntryStatusApplied)
+ if err != nil {
+ return err
+ }
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrUsageBillingEntryNotFound
+ }
+ return nil
+}
+
+func (r *usageLogRepository) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error {
+ query := `
+ UPDATE billing_usage_entries
+ SET
+ applied = FALSE,
+ status = $2,
+ next_retry_at = $3,
+ last_error = $4,
+ updated_at = NOW()
+ WHERE id = $1
+ `
+ res, err := r.usageSQLExecutor(ctx).ExecContext(
+ ctx,
+ query,
+ entryID,
+ service.UsageBillingEntryStatusPending,
+ nextRetryAt,
+ nullString(&lastError),
+ )
+ if err != nil {
+ return err
+ }
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrUsageBillingEntryNotFound
+ }
+ return nil
+}
+
+func (r *usageLogRepository) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]service.UsageBillingEntry, error) {
+ if limit <= 0 {
+ return nil, nil
+ }
+ staleAt := time.Now().Add(-processingStaleAfter)
+ query := `
+ WITH candidates AS (
+ SELECT id
+ FROM billing_usage_entries
+ WHERE applied = FALSE
+ AND (
+ (status = $1 AND next_retry_at <= NOW())
+ OR (status = $2 AND updated_at <= $3)
+ )
+ ORDER BY id
+ LIMIT $4
+ FOR UPDATE SKIP LOCKED
+ )
+ UPDATE billing_usage_entries b
+ SET
+ status = $2,
+ attempt_count = b.attempt_count + 1,
+ updated_at = NOW(),
+ last_error = NULL
+ FROM candidates c
+ WHERE b.id = c.id
+ RETURNING
+ b.id,
+ b.usage_log_id,
+ b.user_id,
+ b.api_key_id,
+ b.subscription_id,
+ b.billing_type,
+ b.applied,
+ b.delta_usd,
+ b.status,
+ b.attempt_count,
+ b.next_retry_at,
+ b.updated_at,
+ b.created_at,
+ b.last_error
+ `
+
+ rows, err := r.usageSQLExecutor(ctx).QueryContext(
+ ctx,
+ query,
+ service.UsageBillingEntryStatusPending,
+ service.UsageBillingEntryStatusProcessing,
+ staleAt,
+ limit,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ entries := make([]service.UsageBillingEntry, 0, limit)
+ for rows.Next() {
+ item, scanErr := scanUsageBillingEntry(rows)
+ if scanErr != nil {
+ return nil, scanErr
+ }
+ entries = append(entries, *item)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return entries, nil
+}
+
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)
@@ -1863,77 +2132,6 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
return results, nil
}
-// GetGroupStatsWithFilters returns group usage statistics with optional filters
-func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []usagestats.GroupStat, err error) {
- query := `
- SELECT
- COALESCE(ul.group_id, 0) as group_id,
- COALESCE(g.name, '') as group_name,
- COUNT(*) as requests,
- COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens,
- COALESCE(SUM(ul.total_cost), 0) as cost,
- COALESCE(SUM(ul.actual_cost), 0) as actual_cost
- FROM usage_logs ul
- LEFT JOIN groups g ON g.id = ul.group_id
- WHERE ul.created_at >= $1 AND ul.created_at < $2
- `
-
- args := []any{startTime, endTime}
- if userID > 0 {
- query += fmt.Sprintf(" AND ul.user_id = $%d", len(args)+1)
- args = append(args, userID)
- }
- if apiKeyID > 0 {
- query += fmt.Sprintf(" AND ul.api_key_id = $%d", len(args)+1)
- args = append(args, apiKeyID)
- }
- if accountID > 0 {
- query += fmt.Sprintf(" AND ul.account_id = $%d", len(args)+1)
- args = append(args, accountID)
- }
- if groupID > 0 {
- query += fmt.Sprintf(" AND ul.group_id = $%d", len(args)+1)
- args = append(args, groupID)
- }
- query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream)
- if billingType != nil {
- query += fmt.Sprintf(" AND ul.billing_type = $%d", len(args)+1)
- args = append(args, int16(*billingType))
- }
- query += " GROUP BY ul.group_id, g.name ORDER BY total_tokens DESC"
-
- rows, err := r.sql.QueryContext(ctx, query, args...)
- if err != nil {
- return nil, err
- }
- defer func() {
- if closeErr := rows.Close(); closeErr != nil && err == nil {
- err = closeErr
- results = nil
- }
- }()
-
- results = make([]usagestats.GroupStat, 0)
- for rows.Next() {
- var row usagestats.GroupStat
- if err := rows.Scan(
- &row.GroupID,
- &row.GroupName,
- &row.Requests,
- &row.TotalTokens,
- &row.Cost,
- &row.ActualCost,
- ); err != nil {
- return nil, err
- }
- results = append(results, row)
- }
- if err := rows.Err(); err != nil {
- return nil, err
- }
- return results, nil
-}
-
// GetGlobalStats gets usage statistics for all users within a time range
func (r *usageLogRepository) GetGlobalStats(ctx context.Context, startTime, endTime time.Time) (*UsageStats, error) {
query := `
@@ -2590,6 +2788,52 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
return log, nil
}
+func scanUsageBillingEntry(scanner interface{ Scan(...any) error }) (*service.UsageBillingEntry, error) {
+ var (
+ subscriptionID sql.NullInt64
+ nextRetryAt time.Time
+ updatedAt time.Time
+ createdAt time.Time
+ lastError sql.NullString
+ status int16
+ entry service.UsageBillingEntry
+ )
+
+ if err := scanner.Scan(
+ &entry.ID,
+ &entry.UsageLogID,
+ &entry.UserID,
+ &entry.APIKeyID,
+ &subscriptionID,
+ &entry.BillingType,
+ &entry.Applied,
+ &entry.DeltaUSD,
+ &status,
+ &entry.AttemptCount,
+ &nextRetryAt,
+ &updatedAt,
+ &createdAt,
+ &lastError,
+ ); err != nil {
+ return nil, err
+ }
+
+ if subscriptionID.Valid {
+ v := subscriptionID.Int64
+ entry.SubscriptionID = &v
+ }
+ entry.Status = service.UsageBillingEntryStatus(status)
+ entry.NextRetryAt = nextRetryAt
+ entry.UpdatedAt = updatedAt
+ entry.CreatedAt = createdAt
+ if lastError.Valid {
+ msg := lastError.String
+ entry.LastError = &msg
+ }
+
+ return &entry, nil
+}
+
func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
results := make([]TrendDataPoint, 0)
for rows.Next() {
diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go
index 2e35e0a00..eb8ce3fba 100644
--- a/backend/internal/repository/wire.go
+++ b/backend/internal/repository/wire.go
@@ -34,7 +34,7 @@ func ProvideGitHubReleaseClient(cfg *config.Config) service.GitHubReleaseClient
// ProvidePricingRemoteClient 创建定价数据远程客户端
// 从配置中读取代理设置,支持国内服务器通过代理访问 GitHub 上的定价数据
func ProvidePricingRemoteClient(cfg *config.Config) service.PricingRemoteClient {
- return NewPricingRemoteClient(cfg.Update.ProxyURL, cfg.Security.ProxyFallback.AllowDirectOnError)
+ return NewPricingRemoteClient(cfg.Update.ProxyURL)
}
// ProvideSessionLimitCache 创建会话限制缓存
@@ -79,8 +79,6 @@ var ProviderSet = wire.NewSet(
NewTimeoutCounterCache,
ProvideConcurrencyCache,
ProvideSessionLimitCache,
- NewRPMCache,
- NewUserMsgQueueCache,
NewDashboardCache,
NewEmailCache,
NewIdentityCache,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index 63b6cf282..f5efd97fa 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -86,15 +86,6 @@ func TestAPIContracts(t *testing.T) {
"last_used_at": null,
"quota": 0,
"quota_used": 0,
- "rate_limit_5h": 0,
- "rate_limit_1d": 0,
- "rate_limit_7d": 0,
- "usage_5h": 0,
- "usage_1d": 0,
- "usage_7d": 0,
- "window_5h_start": null,
- "window_1d_start": null,
- "window_7d_start": null,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
@@ -135,15 +126,6 @@ func TestAPIContracts(t *testing.T) {
"last_used_at": null,
"quota": 0,
"quota_used": 0,
- "rate_limit_5h": 0,
- "rate_limit_1d": 0,
- "rate_limit_7d": 0,
- "usage_5h": 0,
- "usage_1d": 0,
- "usage_7d": 0,
- "window_5h_start": null,
- "window_1d_start": null,
- "window_7d_start": null,
"expires_at": null,
"created_at": "2025-01-02T03:04:05Z",
"updated_at": "2025-01-02T03:04:05Z"
@@ -517,7 +499,6 @@ func TestAPIContracts(t *testing.T) {
"doc_url": "https://docs.example.com",
"default_concurrency": 5,
"default_balance": 1.25,
- "default_subscriptions": [],
"enable_model_fallback": false,
"fallback_model_anthropic": "claude-3-5-sonnet-20241022",
"fallback_model_antigravity": "gemini-2.5-pro",
@@ -531,8 +512,7 @@ func TestAPIContracts(t *testing.T) {
"hide_ccs_import_button": false,
"purchase_subscription_enabled": false,
"purchase_subscription_url": "",
- "min_claude_code_version": "",
- "custom_menu_items": []
+ "custom_menu_items": null
}
}`,
},
@@ -640,12 +620,12 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
- adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -1510,6 +1490,18 @@ func (r *stubApiKeyRepo) IncrementQuotaUsed(ctx context.Context, id int64, amoun
return 0, errors.New("not implemented")
}
+func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
+ return errors.New("not implemented")
+}
+
+func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
+ return nil, errors.New("not implemented")
+}
+
func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
key, ok := r.byID[id]
if !ok {
@@ -1524,16 +1516,6 @@ func (r *stubApiKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt ti
return nil
}
-func (r *stubApiKeyRepo) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
- return nil
-}
-func (r *stubApiKeyRepo) ResetRateLimitWindows(ctx context.Context, id int64) error {
- return nil
-}
-func (r *stubApiKeyRepo) GetRateLimitData(ctx context.Context, id int64) (*service.APIKeyRateLimitData, error) {
- return nil, nil
-}
-
type stubUsageLogRepo struct {
userLogs map[int64][]service.UsageLog
}
@@ -1610,10 +1592,6 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
- return nil, errors.New("not implemented")
-}
-
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index 033a5b778..7640ab2ae 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -19,7 +19,7 @@ func TestAdminAuthJWTValidatesTokenVersion(t *testing.T) {
gin.SetMode(gin.TestMode)
cfg := &config.Config{JWT: config.JWTConfig{Secret: "test-secret", ExpireHour: 1}}
- authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authService := service.NewAuthService(nil, nil, nil, cfg, nil, nil, nil, nil, nil)
admin := &service.User{
ID: 1,
diff --git a/backend/internal/server/middleware/jwt_auth_test.go b/backend/internal/server/middleware/jwt_auth_test.go
index f8839cfe8..bc3209584 100644
--- a/backend/internal/server/middleware/jwt_auth_test.go
+++ b/backend/internal/server/middleware/jwt_auth_test.go
@@ -40,7 +40,7 @@ func newJWTTestEnv(users map[int64]*service.User) (*gin.Engine, *service.AuthSer
cfg.JWT.AccessTokenExpireMinutes = 60
userRepo := &stubJWTUserRepo{users: users}
- authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
+ authSvc := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil)
userSvc := service.NewUserService(userRepo, nil, nil)
mw := NewJWTAuthMiddleware(authSvc, userSvc)
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index c36c36a0a..4d991ea47 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -26,6 +26,9 @@ func RegisterAdminRoutes(
// 分组管理
registerGroupRoutes(admin, h)
+ // API Key 管理
+ registerAdminAPIKeyRoutes(admin, h)
+
// 账号管理
registerAccountRoutes(admin, h)
@@ -75,16 +78,6 @@ func RegisterAdminRoutes(
// 错误透传规则管理
registerErrorPassthroughRoutes(admin, h)
-
- // API Key 管理
- registerAdminAPIKeyRoutes(admin, h)
- }
-}
-
-func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
- apiKeys := admin.Group("/api-keys")
- {
- apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
}
}
@@ -96,6 +89,7 @@ func registerOpsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
ops.GET("/user-concurrency", h.Admin.Ops.GetUserConcurrencyStats)
ops.GET("/account-availability", h.Admin.Ops.GetAccountAvailability)
ops.GET("/realtime-traffic", h.Admin.Ops.GetRealtimeTrafficSummary)
+ ops.GET("/openai-ws-v2/passthrough-metrics", h.Admin.Ops.GetOpenAIWSV2PassthroughMetrics)
// Alerts (rules + events)
ops.GET("/alert-rules", h.Admin.Ops.ListAlertRules)
@@ -227,6 +221,13 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ apiKeys := admin.Group("/api-keys")
+ {
+ apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
+ }
+}
+
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts")
{
@@ -386,6 +387,18 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
+ // 批量编辑模板库(服务端共享)
+ adminSettings.GET("/bulk-edit-templates", h.Admin.Setting.ListBulkEditTemplates)
+ adminSettings.POST("/bulk-edit-templates", h.Admin.Setting.UpsertBulkEditTemplate)
+ adminSettings.DELETE("/bulk-edit-templates/:template_id", h.Admin.Setting.DeleteBulkEditTemplate)
+ adminSettings.GET(
+ "/bulk-edit-templates/:template_id/versions",
+ h.Admin.Setting.ListBulkEditTemplateVersions,
+ )
+ adminSettings.POST(
+ "/bulk-edit-templates/:template_id/rollback",
+ h.Admin.Setting.RollbackBulkEditTemplate,
+ )
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 81e91aeb0..8d5b281b0 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -853,19 +853,24 @@ func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool {
}
const (
- OpenAIWSIngressModeOff = "off"
- OpenAIWSIngressModeShared = "shared"
- OpenAIWSIngressModeDedicated = "dedicated"
+ OpenAIWSIngressModeOff = "off"
+ OpenAIWSIngressModeShared = "shared"
+ OpenAIWSIngressModeDedicated = "dedicated"
+ OpenAIWSIngressModeCtxPool = "ctx_pool"
+ OpenAIWSIngressModePassthrough = "passthrough"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
- case OpenAIWSIngressModeShared:
- return OpenAIWSIngressModeShared
- case OpenAIWSIngressModeDedicated:
- return OpenAIWSIngressModeDedicated
+ case OpenAIWSIngressModeCtxPool:
+ return OpenAIWSIngressModeCtxPool
+ case OpenAIWSIngressModePassthrough:
+ return OpenAIWSIngressModePassthrough
+ case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+ // Deprecated: shared/dedicated 已废弃,平滑迁移到 ctx_pool
+ return OpenAIWSIngressModeCtxPool
default:
return ""
}
@@ -875,16 +880,16 @@ func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
return normalized
}
- return OpenAIWSIngressModeShared
+ return OpenAIWSIngressModeOff
}
-// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
+// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool/passthrough)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
-// 4. defaultMode(非法时回退 shared)
+// 4. defaultMode(非法时回退 off)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
@@ -919,7 +924,8 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return "", false
}
if enabled {
- return OpenAIWSIngressModeShared, true
+ // 兼容旧 enabled 字段:开启时至少落到 ctx_pool。
+ return OpenAIWSIngressModeCtxPool, true
}
return OpenAIWSIngressModeOff, true
}
@@ -1295,12 +1301,6 @@ func parseExtraFloat64(value any) float64 {
}
// parseExtraInt 从 extra 字段解析 int 值
-// ParseExtraInt 从 extra 字段的 any 值解析为 int。
-// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。
-func ParseExtraInt(value any) int {
- return parseExtraInt(value)
-}
-
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:
diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go
index a85c68ec5..2fe62483c 100644
--- a/backend/internal/service/account_openai_passthrough_test.go
+++ b/backend/internal/service/account_openai_passthrough_test.go
@@ -206,30 +206,63 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
- t.Run("default fallback to shared", func(t *testing.T) {
+ t.Run("default fallback to off", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
- require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
- require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
- t.Run("oauth mode field has highest priority", func(t *testing.T) {
+ t.Run("unsupported mode field falls back to enabled flag", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
- "openai_oauth_responses_websockets_v2_enabled": false,
+ "openai_oauth_responses_websockets_v2_enabled": true,
"responses_websockets_v2_enabled": false,
},
}
- require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ })
+
+ t.Run("ctx_pool mode field is recognized", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ })
+
+ t.Run("passthrough mode field is recognized", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ })
+
+ t.Run("legacy enabled maps to ctx_pool when default is off", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
- t.Run("legacy enabled maps to shared", func(t *testing.T) {
+ t.Run("legacy enabled ignores unsupported default and maps to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -237,7 +270,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
- require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
@@ -249,7 +282,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
- require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("non openai always off", func(t *testing.T) {
@@ -260,7 +293,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
- require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
}
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index 18a70c5cc..964261ddc 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -73,15 +73,16 @@ type AccountRepository interface {
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
- Name *string
- ProxyID *int64
- Concurrency *int
- Priority *int
- RateMultiplier *float64
- Status *string
- Schedulable *bool
- Credentials map[string]any
- Extra map[string]any
+ Name *string
+ ProxyID *int64
+ Concurrency *int
+ Priority *int
+ RateMultiplier *float64
+ Status *string
+ Schedulable *bool
+ AutoPauseOnExpired *bool
+ Credentials map[string]any
+ Extra map[string]any
}
// CreateAccountRequest 创建账号请求
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index 6dee6c133..13a138567 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -37,7 +37,6 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
- GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7e6982d36..6f7c3c05f 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -43,8 +43,6 @@ type AdminService interface {
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
-
- // API Key management (admin)
AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
// Account management
@@ -224,17 +222,18 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
- AccountIDs []int64
- Name string
- ProxyID *int64
- Concurrency *int
- Priority *int
- RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
- Status string
- Schedulable *bool
- GroupIDs *[]int64
- Credentials map[string]any
- Extra map[string]any
+ AccountIDs []int64
+ Name string
+ ProxyID *int64
+ Concurrency *int
+ Priority *int
+ RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
+ Status string
+ Schedulable *bool
+ AutoPauseOnExpired *bool
+ GroupIDs *[]int64
+ Credentials map[string]any
+ Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -250,9 +249,9 @@ type BulkUpdateAccountResult struct {
// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
type AdminUpdateAPIKeyGroupIDResult struct {
APIKey *APIKey
- AutoGrantedGroupAccess bool // true if a new exclusive group permission was auto-added
- GrantedGroupID *int64 // the group ID that was auto-granted
- GrantedGroupName string // the group name that was auto-granted
+ AutoGrantedGroupAccess bool
+ GrantedGroupID *int64
+ GrantedGroupName string
}
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
@@ -420,8 +419,6 @@ type adminServiceImpl struct {
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
entClient *dbent.Client // 用于开启数据库事务
- settingService *SettingService
- defaultSubAssigner DefaultSubscriptionAssigner
}
type userGroupRateBatchReader interface {
@@ -447,8 +444,6 @@ func NewAdminService(
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
entClient *dbent.Client,
- settingService *SettingService,
- defaultSubAssigner DefaultSubscriptionAssigner,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -464,8 +459,6 @@ func NewAdminService(
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
entClient: entClient,
- settingService: settingService,
- defaultSubAssigner: defaultSubAssigner,
}
}
@@ -550,27 +543,9 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
if err := s.userRepo.Create(ctx, user); err != nil {
return nil, err
}
- s.assignDefaultSubscriptions(ctx, user.ID)
return user, nil
}
-func (s *adminServiceImpl) assignDefaultSubscriptions(ctx context.Context, userID int64) {
- if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
- return
- }
- items := s.settingService.GetDefaultSubscriptions(ctx)
- for _, item := range items {
- if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
- UserID: userID,
- GroupID: item.GroupID,
- ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
- }); err != nil {
- logger.LegacyPrintf("service.admin", "failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
- }
- }
-}
-
func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error) {
user, err := s.userRepo.GetByID(ctx, id)
if err != nil {
@@ -1225,8 +1200,8 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
-// AdminUpdateAPIKeyGroupID 管理员修改 API Key 分组绑定
-// groupID: nil=不修改, 指向0=解绑, 指向正整数=绑定到目标分组
+// AdminUpdateAPIKeyGroupID allows admins to update API key-group binding.
+// groupID: nil=unchanged, 0=unbind, >0=bind to target group.
func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
if err != nil {
@@ -1234,22 +1209,17 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
}
if groupID == nil {
- // nil 表示不修改,直接返回
return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
}
-
if *groupID < 0 {
return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
}
result := &AdminUpdateAPIKeyGroupIDResult{}
-
if *groupID == 0 {
- // 0 表示解绑分组(不修改 user_allowed_groups,避免影响用户其他 Key)
apiKey.GroupID = nil
apiKey.Group = nil
} else {
- // 验证目标分组存在且状态为 active
group, err := s.groupRepo.GetByID(ctx, *groupID)
if err != nil {
return nil, err
@@ -1257,7 +1227,6 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
if group.Status != StatusActive {
return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
}
- // 订阅类型分组:不允许通过此 API 直接绑定,需通过订阅管理流程
if group.IsSubscriptionType() {
return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
}
@@ -1266,7 +1235,6 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
apiKey.GroupID = &gid
apiKey.Group = group
- // 专属标准分组:使用事务保证「添加分组权限」与「更新 API Key」的原子性
if group.IsExclusive {
opCtx := ctx
var tx *dbent.Tx
@@ -1297,8 +1265,6 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
result.AutoGrantedGroupAccess = true
result.GrantedGroupID = &gid
result.GrantedGroupName = group.Name
-
- // 失效认证缓存(在事务提交后执行)
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
@@ -1308,12 +1274,9 @@ func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID i
}
}
- // 非专属分组 / 解绑:无需事务,单步更新即可
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
-
- // 失效认证缓存
if s.authCacheInvalidator != nil {
s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
@@ -1384,6 +1347,11 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
return nil, errors.New("base_url 必须以 http:// 或 https:// 开头")
}
}
+ if len(input.Extra) > 0 {
+ if err := validateOpenAIWSModeExtraValues(input.Extra); err != nil {
+ return nil, err
+ }
+ }
account := &Account{
Name: input.Name,
@@ -1458,6 +1426,9 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
account.Credentials = input.Credentials
}
if len(input.Extra) > 0 {
+ if err := validateOpenAIWSModeExtraValues(input.Extra); err != nil {
+ return nil, err
+ }
account.Extra = input.Extra
}
if input.ProxyID != nil {
@@ -1543,6 +1514,104 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return updated, nil
}
+var openAIBulkScopedExtraKeys = map[string]struct{}{
+ "openai_passthrough": {},
+ "openai_oauth_passthrough": {},
+ "openai_oauth_responses_websockets_v2_mode": {},
+ "openai_oauth_responses_websockets_v2_enabled": {},
+ "openai_apikey_responses_websockets_v2_mode": {},
+ "openai_apikey_responses_websockets_v2_enabled": {},
+ "codex_cli_only": {},
+}
+
+func hasOpenAIBulkScopedExtraField(extra map[string]any) bool {
+ if len(extra) == 0 {
+ return false
+ }
+ for key := range extra {
+ if _, ok := openAIBulkScopedExtraKeys[key]; ok {
+ return true
+ }
+ }
+ return false
+}
+
+// openaiWSModeExtraKeys 列出所有控制 OpenAI WS v2 mode 的 extra 键名。
+var openaiWSModeExtraKeys = []string{
+ "openai_oauth_responses_websockets_v2_mode",
+ "openai_apikey_responses_websockets_v2_mode",
+}
+
+// validateOpenAIWSModeExtraValues 校验 extra 中 WS v2 mode 字段的值域,
+// 只允许 off / ctx_pool / passthrough。
+func validateOpenAIWSModeExtraValues(extra map[string]any) error {
+ for _, key := range openaiWSModeExtraKeys {
+ raw, exists := extra[key]
+ if !exists {
+ continue
+ }
+ val, ok := raw.(string)
+ if !ok {
+ return infraerrors.BadRequest(
+ "INVALID_OPENAI_WS_MODE",
+ fmt.Sprintf("%s must be a string, got %T", key, raw),
+ )
+ }
+ switch strings.TrimSpace(strings.ToLower(val)) {
+ case OpenAIWSIngressModeOff, OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
+ // valid
+ default:
+ return infraerrors.BadRequest(
+ "INVALID_OPENAI_WS_MODE",
+ fmt.Sprintf("%s must be one of off, ctx_pool, passthrough; got %q", key, val),
+ )
+ }
+ }
+ return nil
+}
+
+func validateOpenAIBulkScopedAccounts(accountsByID map[int64]*Account, accountIDs []int64) error {
+ var expectedType string
+
+ for _, accountID := range accountIDs {
+ account := accountsByID[accountID]
+ if account == nil {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_ACCOUNT_MISSING",
+ fmt.Sprintf("account %d not found for OpenAI scoped bulk update", accountID),
+ )
+ }
+
+ if account.Platform != PlatformOpenAI {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_PLATFORM_MISMATCH",
+ "OpenAI scoped bulk fields require all selected accounts to be OpenAI",
+ )
+ }
+
+ if account.Type != AccountTypeOAuth && account.Type != AccountTypeAPIKey {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_TYPE_UNSUPPORTED",
+ "OpenAI scoped bulk fields only support oauth or apikey account types",
+ )
+ }
+
+ if expectedType == "" {
+ expectedType = account.Type
+ continue
+ }
+
+ if account.Type != expectedType {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_TYPE_MISMATCH",
+ "OpenAI scoped bulk fields require all selected accounts to have the same type",
+ )
+ }
+ }
+
+ return nil
+}
+
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
@@ -1562,32 +1631,48 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
+ needOpenAIScopeCheck := hasOpenAIBulkScopedExtraField(input.Extra)
+ needAccountSnapshot := needMixedChannelCheck || needOpenAIScopeCheck
- // 预加载账号平台信息(混合渠道检查需要)。
+	// 预加载账号平台信息(混合渠道检查或 OpenAI 范围校验需要)。
platformByID := map[int64]string{}
- if needMixedChannelCheck {
+ accountByID := map[int64]*Account{}
+ groupAccountsByID := map[int64][]Account{}
+ groupNameByID := map[int64]string{}
+ if needAccountSnapshot {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
- }
- for _, account := range accounts {
- if account != nil {
- platformByID[account.ID] = account.Platform
+ } else {
+ for _, account := range accounts {
+ if account != nil {
+ accountByID[account.ID] = account
+ platformByID[account.ID] = account.Platform
+ }
}
}
}
- // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。
+ if needOpenAIScopeCheck {
+ if err := validateOpenAIBulkScopedAccounts(accountByID, input.AccountIDs); err != nil {
+ return nil, err
+ }
+ }
+
+ // 校验 WS v2 mode 字段值域
+ if len(input.Extra) > 0 {
+ if err := validateOpenAIWSModeExtraValues(input.Extra); err != nil {
+ return nil, err
+ }
+ }
+
if needMixedChannelCheck {
- for _, accountID := range input.AccountIDs {
- platform := platformByID[accountID]
- if platform == "" {
- continue
- }
- if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
- return nil, err
- }
+ loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
+ if err != nil {
+ return nil, err
}
+ groupAccountsByID = loadedAccounts
+ groupNameByID = loadedNames
}
if input.RateMultiplier != nil {
@@ -1622,6 +1707,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Schedulable != nil {
repoUpdates.Schedulable = input.Schedulable
}
+ if input.AutoPauseOnExpired != nil {
+ repoUpdates.AutoPauseOnExpired = input.AutoPauseOnExpired
+ }
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
@@ -1631,8 +1719,34 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
+ platform := ""
if input.GroupIDs != nil {
+ // 检查混合渠道风险(除非用户已确认)
+ if !input.SkipMixedChannelCheck {
+ platform = platformByID[accountID]
+ if platform == "" {
+ account, err := s.accountRepo.GetByID(ctx, accountID)
+ if err != nil {
+ entry.Success = false
+ entry.Error = err.Error()
+ result.Failed++
+ result.FailedIDs = append(result.FailedIDs, accountID)
+ result.Results = append(result.Results, entry)
+ continue
+ }
+ platform = account.Platform
+ }
+ if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
+ entry.Success = false
+ entry.Error = err.Error()
+ result.Failed++
+ result.FailedIDs = append(result.FailedIDs, accountID)
+ result.Results = append(result.Results, entry)
+ continue
+ }
+ }
+
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -1641,6 +1755,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry)
continue
}
+ if !input.SkipMixedChannelCheck && platform != "" {
+ updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
+ }
}
entry.Success = true
@@ -2028,6 +2145,7 @@ func (s *adminServiceImpl) CheckProxyQuality(ctx context.Context, id int64) (*Pr
ProxyURL: proxyURL,
Timeout: proxyQualityRequestTimeout,
ResponseHeaderTimeout: proxyQualityResponseHeaderTimeout,
+ ProxyStrict: true,
})
if err != nil {
result.Items = append(result.Items, ProxyQualityCheckItem{
@@ -2311,6 +2429,41 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
+func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
+ accountsByGroup := make(map[int64][]Account)
+ groupNameByID := make(map[int64]string)
+ if len(groupIDs) == 0 {
+ return accountsByGroup, groupNameByID, nil
+ }
+
+ seen := make(map[int64]struct{}, len(groupIDs))
+ for _, groupID := range groupIDs {
+ if groupID <= 0 {
+ continue
+ }
+ if _, ok := seen[groupID]; ok {
+ continue
+ }
+ seen[groupID] = struct{}{}
+
+ accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
+ if err != nil {
+ return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
+ }
+ accountsByGroup[groupID] = accounts
+
+ group, err := s.groupRepo.GetByID(ctx, groupID)
+ if err != nil {
+ continue
+ }
+ if group != nil {
+ groupNameByID[groupID] = group.Name
+ }
+ }
+
+ return accountsByGroup, groupNameByID, nil
+}
+
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
@@ -2340,6 +2493,71 @@ func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs [
return nil
}
+func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
+ currentPlatform := getAccountPlatform(currentAccountPlatform)
+ if currentPlatform == "" {
+ return nil
+ }
+
+ for _, groupID := range groupIDs {
+ accounts := accountsByGroup[groupID]
+ for _, account := range accounts {
+ if currentAccountID > 0 && account.ID == currentAccountID {
+ continue
+ }
+
+ otherPlatform := getAccountPlatform(account.Platform)
+ if otherPlatform == "" {
+ continue
+ }
+
+ if currentPlatform != otherPlatform {
+ groupName := fmt.Sprintf("Group %d", groupID)
+ if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
+ groupName = name
+ }
+
+ return &MixedChannelError{
+ GroupID: groupID,
+ GroupName: groupName,
+ CurrentPlatform: currentPlatform,
+ OtherPlatform: otherPlatform,
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
+ if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
+ return
+ }
+ for _, groupID := range groupIDs {
+ if groupID <= 0 {
+ continue
+ }
+ accounts := accountsByGroup[groupID]
+ found := false
+ for i := range accounts {
+ if accounts[i].ID != accountID {
+ continue
+ }
+ accounts[i].Platform = platform
+ found = true
+ break
+ }
+ if !found {
+ accounts = append(accounts, Account{
+ ID: accountID,
+ Platform: platform,
+ })
+ }
+ accountsByGroup[groupID] = accounts
+ }
+}
+
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go
index 4845d87c1..3fb14cae7 100644
--- a/backend/internal/service/admin_service_bulk_update_test.go
+++ b/backend/internal/service/admin_service_bulk_update_test.go
@@ -12,23 +12,25 @@ import (
type accountRepoStubForBulkUpdate struct {
accountRepoStub
- bulkUpdateErr error
- bulkUpdateIDs []int64
- bindGroupErrByID map[int64]error
- bindGroupsCalls []int64
- getByIDsAccounts []*Account
- getByIDsErr error
- getByIDsCalled bool
- getByIDsIDs []int64
- getByIDAccounts map[int64]*Account
- getByIDErrByID map[int64]error
- getByIDCalled []int64
- listByGroupData map[int64][]Account
- listByGroupErr map[int64]error
+ bulkUpdateErr error
+ bulkUpdateIDs []int64
+ bulkUpdatePayload AccountBulkUpdate
+ bindGroupErrByID map[int64]error
+ bindGroupsCalls []int64
+ getByIDsAccounts []*Account
+ getByIDsErr error
+ getByIDsCalled bool
+ getByIDsIDs []int64
+ getByIDAccounts map[int64]*Account
+ getByIDErrByID map[int64]error
+ getByIDCalled []int64
+ listByGroupData map[int64][]Account
+ listByGroupErr map[int64]error
}
-func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
+func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
+ s.bulkUpdatePayload = updates
if s.bulkUpdateErr != nil {
return 0, s.bulkUpdateErr
}
@@ -139,34 +141,134 @@ func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T)
require.Contains(t, err.Error(), "group repository not configured")
}
-// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
-// that the global pre-check detects a conflict with existing group members and returns an
-// error before any DB write is performed.
-func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) {
+func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
- {ID: 1, Platform: PlatformAntigravity},
+ {ID: 1, Platform: PlatformAnthropic},
+ {ID: 2, Platform: PlatformAntigravity},
},
- // Group 10 already contains an Anthropic account.
listByGroupData: map[int64][]Account{
- 10: {{ID: 99, Platform: PlatformAnthropic}},
+ 10: {},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
- groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}},
+ groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
- AccountIDs: []int64{1},
+ AccountIDs: []int64{1, 2},
GroupIDs: &groupIDs,
}
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.Equal(t, 1, result.Success)
+ require.Equal(t, 1, result.Failed)
+ require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
+ require.ElementsMatch(t, []int64{2}, result.FailedIDs)
+ require.Len(t, result.Results, 2)
+ require.Contains(t, result.Results[1].Error, "mixed channel")
+ require.Equal(t, []int64{1}, repo.bindGroupsCalls)
+}
+
+func TestAdminService_BulkUpdateAccounts_ForwardsAutoPauseOnExpired(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{}
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ autoPause := true
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{101},
+ AutoPauseOnExpired: &autoPause,
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.Equal(t, 1, result.Success)
+ require.NotNil(t, repo.bulkUpdatePayload.AutoPauseOnExpired)
+ require.True(t, *repo.bulkUpdatePayload.AutoPauseOnExpired)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMixedTypes(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"openai_passthrough": true},
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "same type")
+ require.Empty(t, repo.bulkUpdateIDs)
+ require.True(t, repo.getByIDsCalled)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsNonOpenAIPlatform(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 2, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"openai_oauth_responses_websockets_v2_mode": "shared"},
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "OpenAI")
+ require.Empty(t, repo.bulkUpdateIDs)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraAllowsSameTypeOpenAI(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"codex_cli_only": true},
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.Equal(t, 2, result.Success)
+ require.ElementsMatch(t, []int64{1, 2}, repo.bulkUpdateIDs)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMissingAccount(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"openai_passthrough": true},
+ }
+
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
- require.Contains(t, err.Error(), "mixed channel")
- // No BindGroups should have been called since the check runs before any write.
- require.Empty(t, repo.bindGroupsCalls)
+ require.Contains(t, err.Error(), "not found")
+ require.Empty(t, repo.bulkUpdateIDs)
}
diff --git a/backend/internal/service/admin_service_create_openai_ws_mode_test.go b/backend/internal/service/admin_service_create_openai_ws_mode_test.go
new file mode 100644
index 000000000..d8d170998
--- /dev/null
+++ b/backend/internal/service/admin_service_create_openai_ws_mode_test.go
@@ -0,0 +1,61 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+type accountRepoStubForCreateModeValidation struct {
+ AccountRepository
+ createCalled bool
+ createErr error
+}
+
+func (s *accountRepoStubForCreateModeValidation) Create(_ context.Context, account *Account) error {
+ s.createCalled = true
+ if s.createErr != nil {
+ return s.createErr
+ }
+ if account != nil && account.ID == 0 {
+ account.ID = 1
+ }
+ return nil
+}
+
+func TestAdminService_CreateAccount_RejectsInvalidOpenAIWSMode(t *testing.T) {
+ repo := &accountRepoStubForCreateModeValidation{}
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ account, err := svc.CreateAccount(context.Background(), &CreateAccountInput{
+ Name: "ws-mode-invalid",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{"openai_apikey_responses_websockets_v2_mode": "shared"},
+ SkipDefaultGroupBind: true,
+ })
+ require.Nil(t, account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "INVALID_OPENAI_WS_MODE")
+ require.False(t, repo.createCalled)
+}
+
+func TestAdminService_CreateAccount_AcceptsValidOpenAIWSMode(t *testing.T) {
+ repo := &accountRepoStubForCreateModeValidation{}
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ account, err := svc.CreateAccount(context.Background(), &CreateAccountInput{
+ Name: "ws-mode-valid",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{"api_key": "sk-test"},
+ Extra: map[string]any{"openai_apikey_responses_websockets_v2_mode": "passthrough"},
+ SkipDefaultGroupBind: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, account)
+ require.True(t, repo.createCalled)
+ require.Equal(t, int64(1), account.ID)
+}
diff --git a/backend/internal/service/admin_service_create_user_test.go b/backend/internal/service/admin_service_create_user_test.go
index c5b1e38d3..a0fe4d87b 100644
--- a/backend/internal/service/admin_service_create_user_test.go
+++ b/backend/internal/service/admin_service_create_user_test.go
@@ -7,7 +7,6 @@ import (
"errors"
"testing"
- "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
@@ -66,32 +65,3 @@ func TestAdminService_CreateUser_CreateError(t *testing.T) {
require.ErrorIs(t, err, createErr)
require.Empty(t, repo.created)
}
-
-func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
- repo := &userRepoStub{nextID: 21}
- assigner := &defaultSubscriptionAssignerStub{}
- cfg := &config.Config{
- Default: config.DefaultConfig{
- UserBalance: 0,
- UserConcurrency: 1,
- },
- }
- settingService := NewSettingService(&settingRepoStub{values: map[string]string{
- SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
- }}, cfg)
- svc := &adminServiceImpl{
- userRepo: repo,
- settingService: settingService,
- defaultSubAssigner: assigner,
- }
-
- _, err := svc.CreateUser(context.Background(), &CreateUserInput{
- Email: "new-user@test.com",
- Password: "password",
- })
- require.NoError(t, err)
- require.Len(t, assigner.calls, 1)
- require.Equal(t, int64(21), assigner.calls[0].UserID)
- require.Equal(t, int64(5), assigner.calls[0].GroupID)
- require.Equal(t, 30, assigner.calls[0].ValidityDays)
-}
diff --git a/backend/internal/service/admin_service_openai_ws_mode_validation_test.go b/backend/internal/service/admin_service_openai_ws_mode_validation_test.go
new file mode 100644
index 000000000..820342f3b
--- /dev/null
+++ b/backend/internal/service/admin_service_openai_ws_mode_validation_test.go
@@ -0,0 +1,52 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestValidateOpenAIWSModeExtraValues(t *testing.T) {
+ t.Parallel()
+
+ t.Run("accepts supported mode values", func(t *testing.T) {
+ t.Parallel()
+
+ err := validateOpenAIWSModeExtraValues(map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": " passthrough ",
+ "openai_apikey_responses_websockets_v2_mode": "CTX_POOL",
+ })
+ require.NoError(t, err)
+ })
+
+ t.Run("accepts missing mode keys", func(t *testing.T) {
+ t.Parallel()
+
+ err := validateOpenAIWSModeExtraValues(map[string]any{
+ "codex_cli_only": true,
+ })
+ require.NoError(t, err)
+ })
+
+ t.Run("rejects invalid mode value", func(t *testing.T) {
+ t.Parallel()
+
+ err := validateOpenAIWSModeExtraValues(map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": "shared",
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "INVALID_OPENAI_WS_MODE")
+ require.Contains(t, err.Error(), "off, ctx_pool, passthrough")
+ })
+
+ t.Run("rejects non-string mode value", func(t *testing.T) {
+ t.Parallel()
+
+ err := validateOpenAIWSModeExtraValues(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": true,
+ })
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "INVALID_OPENAI_WS_MODE")
+ require.Contains(t, err.Error(), "must be a string")
+ })
+}
diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go
index 5f6691be2..b67c7fafa 100644
--- a/backend/internal/service/antigravity_oauth_service.go
+++ b/backend/internal/service/antigravity_oauth_service.go
@@ -112,10 +112,7 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig
}
}
- client, err := antigravity.NewClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create antigravity client failed: %w", err)
- }
+ client := antigravity.NewClient(proxyURL)
// 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
@@ -170,10 +167,7 @@ func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken
time.Sleep(backoff)
}
- client, err := antigravity.NewClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create antigravity client failed: %w", err)
- }
+ client := antigravity.NewClient(proxyURL)
tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil {
now := time.Now()
@@ -215,10 +209,7 @@ func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refr
}
// 获取用户信息(email)
- client, err := antigravity.NewClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create antigravity client failed: %w", err)
- }
+ client := antigravity.NewClient(proxyURL)
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
@@ -318,10 +309,7 @@ func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, ac
time.Sleep(backoff)
}
- client, err := antigravity.NewClient(proxyURL)
- if err != nil {
- return "", fmt.Errorf("create antigravity client failed: %w", err)
- }
+ client := antigravity.NewClient(proxyURL)
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go
index e950ec1d9..07eb563d0 100644
--- a/backend/internal/service/antigravity_quota_fetcher.go
+++ b/backend/internal/service/antigravity_quota_fetcher.go
@@ -2,7 +2,6 @@ package service
import (
"context"
- "fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
@@ -32,10 +31,7 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
- client, err := antigravity.NewClient(proxyURL)
- if err != nil {
- return nil, fmt.Errorf("create antigravity client failed: %w", err)
- }
+ client := antigravity.NewClient(proxyURL)
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go
index fe3a0f258..eae7bd539 100644
--- a/backend/internal/service/auth_service.go
+++ b/backend/internal/service/auth_service.go
@@ -56,20 +56,15 @@ type JWTClaims struct {
// AuthService 认证服务
type AuthService struct {
- userRepo UserRepository
- redeemRepo RedeemCodeRepository
- refreshTokenCache RefreshTokenCache
- cfg *config.Config
- settingService *SettingService
- emailService *EmailService
- turnstileService *TurnstileService
- emailQueueService *EmailQueueService
- promoService *PromoService
- defaultSubAssigner DefaultSubscriptionAssigner
-}
-
-type DefaultSubscriptionAssigner interface {
- AssignOrExtendSubscription(ctx context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error)
+ userRepo UserRepository
+ redeemRepo RedeemCodeRepository
+ refreshTokenCache RefreshTokenCache
+ cfg *config.Config
+ settingService *SettingService
+ emailService *EmailService
+ turnstileService *TurnstileService
+ emailQueueService *EmailQueueService
+ promoService *PromoService
}
// NewAuthService 创建认证服务实例
@@ -83,19 +78,17 @@ func NewAuthService(
turnstileService *TurnstileService,
emailQueueService *EmailQueueService,
promoService *PromoService,
- defaultSubAssigner DefaultSubscriptionAssigner,
) *AuthService {
return &AuthService{
- userRepo: userRepo,
- redeemRepo: redeemRepo,
- refreshTokenCache: refreshTokenCache,
- cfg: cfg,
- settingService: settingService,
- emailService: emailService,
- turnstileService: turnstileService,
- emailQueueService: emailQueueService,
- promoService: promoService,
- defaultSubAssigner: defaultSubAssigner,
+ userRepo: userRepo,
+ redeemRepo: redeemRepo,
+ refreshTokenCache: refreshTokenCache,
+ cfg: cfg,
+ settingService: settingService,
+ emailService: emailService,
+ turnstileService: turnstileService,
+ emailQueueService: emailQueueService,
+ promoService: promoService,
}
}
@@ -195,7 +188,6 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw
logger.LegacyPrintf("service.auth", "[Auth] Database error creating user: %v", err)
return "", nil, ErrServiceUnavailable
}
- s.assignDefaultSubscriptions(ctx, user.ID)
// 标记邀请码为已使用(如果使用了邀请码)
if invitationRedeemCode != nil {
@@ -485,7 +477,6 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -581,7 +572,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
}
} else {
user = newUser
- s.assignDefaultSubscriptions(ctx, user.ID)
}
} else {
logger.LegacyPrintf("service.auth", "[Auth] Database error during oauth login: %v", err)
@@ -607,23 +597,6 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema
return tokenPair, user, nil
}
-func (s *AuthService) assignDefaultSubscriptions(ctx context.Context, userID int64) {
- if s.settingService == nil || s.defaultSubAssigner == nil || userID <= 0 {
- return
- }
- items := s.settingService.GetDefaultSubscriptions(ctx)
- for _, item := range items {
- if _, _, err := s.defaultSubAssigner.AssignOrExtendSubscription(ctx, &AssignSubscriptionInput{
- UserID: userID,
- GroupID: item.GroupID,
- ValidityDays: item.ValidityDays,
- Notes: "auto assigned by default user subscriptions setting",
- }); err != nil {
- logger.LegacyPrintf("service.auth", "[Auth] Failed to assign default subscription: user_id=%d group_id=%d err=%v", userID, item.GroupID, err)
- }
- }
-}
-
// ValidateToken 验证JWT token并返回用户声明
func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) {
// 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。
diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go
index 1999e759e..93659743f 100644
--- a/backend/internal/service/auth_service_register_test.go
+++ b/backend/internal/service/auth_service_register_test.go
@@ -56,21 +56,6 @@ type emailCacheStub struct {
err error
}
-type defaultSubscriptionAssignerStub struct {
- calls []AssignSubscriptionInput
- err error
-}
-
-func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
- if input != nil {
- s.calls = append(s.calls, *input)
- }
- if s.err != nil {
- return nil, false, s.err
- }
- return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
-}
-
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
@@ -138,7 +123,6 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E
nil,
nil,
nil, // promoService
- nil, // defaultSubAssigner
)
}
@@ -397,23 +381,3 @@ func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
}
-
-func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
- repo := &userRepoStub{nextID: 42}
- assigner := &defaultSubscriptionAssignerStub{}
- service := newAuthService(repo, map[string]string{
- SettingKeyRegistrationEnabled: "true",
- SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
- }, nil)
- service.defaultSubAssigner = assigner
-
- _, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
- require.NoError(t, err)
- require.NotNil(t, user)
- require.Len(t, assigner.calls, 2)
- require.Equal(t, int64(42), assigner.calls[0].UserID)
- require.Equal(t, int64(11), assigner.calls[0].GroupID)
- require.Equal(t, 30, assigner.calls[0].ValidityDays)
- require.Equal(t, int64(12), assigner.calls[1].GroupID)
- require.Equal(t, 7, assigner.calls[1].ValidityDays)
-}
diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go
index 36cb1e065..7dd9edca8 100644
--- a/backend/internal/service/auth_service_turnstile_register_test.go
+++ b/backend/internal/service/auth_service_turnstile_register_test.go
@@ -52,7 +52,6 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
turnstileService,
nil, // emailQueueService
nil, // promoService
- nil, // defaultSubAssigner
)
}
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index e055c0f78..34038f5b2 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "errors"
"fmt"
"strconv"
"sync"
@@ -193,7 +194,7 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
}
case cacheWriteDeductBalance:
if s.cache != nil {
- if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
+ if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil && !errors.Is(err, ErrBalanceCacheNotFound) {
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
@@ -335,7 +336,13 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
if s.cache == nil {
return nil
}
- return s.cache.DeductUserBalance(ctx, userID, amount)
+ err := s.cache.DeductUserBalance(ctx, userID, amount)
+ if errors.Is(err, ErrBalanceCacheNotFound) {
+ // 缓存 key 不存在(已过期),无法原子扣减,不阻塞主流程。
+ // 下次 GetUserBalance 将从数据库回源重建缓存。
+ return nil
+ }
+ return err
}
// QueueDeductBalance 异步扣减余额缓存
diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go
index f71098b16..d3a4d119b 100644
--- a/backend/internal/service/claude_code_validator.go
+++ b/backend/internal/service/claude_code_validator.go
@@ -4,7 +4,6 @@ import (
"context"
"net/http"
"regexp"
- "strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
@@ -18,9 +17,6 @@ var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
- // 带捕获组的版本提取正则
- claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
-
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
@@ -274,55 +270,3 @@ func IsClaudeCodeClient(ctx context.Context) bool {
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
}
-
-// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
-// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
-func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
- matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
- if len(matches) >= 2 {
- return matches[1]
- }
- return ""
-}
-
-// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
-func SetClaudeCodeVersion(ctx context.Context, version string) context.Context {
- return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version)
-}
-
-// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号
-func GetClaudeCodeVersion(ctx context.Context) string {
- if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok {
- return v
- }
- return ""
-}
-
-// CompareVersions 比较两个 semver 版本号
-// 返回: -1 (a < b), 0 (a == b), 1 (a > b)
-func CompareVersions(a, b string) int {
- aParts := parseSemver(a)
- bParts := parseSemver(b)
- for i := 0; i < 3; i++ {
- if aParts[i] < bParts[i] {
- return -1
- }
- if aParts[i] > bParts[i] {
- return 1
- }
- }
- return 0
-}
-
-// parseSemver 解析 semver 版本号为 [major, minor, patch]
-func parseSemver(v string) [3]int {
- v = strings.TrimPrefix(v, "v")
- parts := strings.Split(v, ".")
- result := [3]int{0, 0, 0}
- for i := 0; i < len(parts) && i < 3; i++ {
- if parsed, err := strconv.Atoi(parts[i]); err == nil {
- result[i] = parsed
- }
- }
- return result
-}
diff --git a/backend/internal/service/claude_code_validator_test.go b/backend/internal/service/claude_code_validator_test.go
index f87c56e83..a4cd18866 100644
--- a/backend/internal/service/claude_code_validator_test.go
+++ b/backend/internal/service/claude_code_validator_test.go
@@ -56,51 +56,3 @@ func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
ok := validator.Validate(req, nil)
require.True(t, ok)
}
-
-func TestExtractVersion(t *testing.T) {
- v := NewClaudeCodeValidator()
- tests := []struct {
- ua string
- want string
- }{
- {"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
- {"claude-cli/1.0.0", "1.0.0"},
- {"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
- {"curl/8.0.0", ""}, // 非 Claude CLI
- {"", ""}, // 空字符串
- {"claude-cli/", ""}, // 无版本号
- {"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号
- }
- for _, tt := range tests {
- got := v.ExtractVersion(tt.ua)
- require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua)
- }
-}
-
-func TestCompareVersions(t *testing.T) {
- tests := []struct {
- a, b string
- want int
- }{
- {"2.1.0", "2.1.0", 0}, // 相等
- {"2.1.1", "2.1.0", 1}, // patch 更大
- {"2.0.0", "2.1.0", -1}, // minor 更小
- {"3.0.0", "2.99.99", 1}, // major 更大
- {"1.0.0", "2.0.0", -1}, // major 更小
- {"0.0.1", "0.0.0", 1}, // patch 差异
- {"", "1.0.0", -1}, // 空字符串 vs 正常版本
- {"v2.1.0", "2.1.0", 0}, // v 前缀处理
- }
- for _, tt := range tests {
- got := CompareVersions(tt.a, tt.b)
- require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b)
- }
-}
-
-func TestSetGetClaudeCodeVersion(t *testing.T) {
- ctx := context.Background()
- require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string")
-
- ctx = SetClaudeCodeVersion(ctx, "2.1.63")
- require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx))
-}
diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go
index f6cab204d..996c829cd 100644
--- a/backend/internal/service/claude_token_provider.go
+++ b/backend/internal/service/claude_token_provider.go
@@ -15,6 +15,20 @@ const (
claudeLockWaitTime = 200 * time.Millisecond
)
+func waitClaudeLockRetry(ctx context.Context, wait time.Duration) error {
+ if wait <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(wait)
+ defer timer.Stop()
+ select {
+ case <-timer.C:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type ClaudeTokenCache = GeminiTokenCache
@@ -168,7 +182,9 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
- time.Sleep(claudeLockWaitTime)
+ if waitErr := waitClaudeLockRetry(ctx, claudeLockWaitTime); waitErr != nil {
+ return "", waitErr
+ }
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go
index 3e21f6f4a..09f3a31e8 100644
--- a/backend/internal/service/claude_token_provider_test.go
+++ b/backend/internal/service/claude_token_provider_test.go
@@ -800,6 +800,34 @@ func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
require.NotEmpty(t, token)
}
+func TestClaudeTokenProvider_Real_LockFailedWait_ContextCanceled(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.lockAcquired = false
+
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 3001,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "fallback-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ start := time.Now()
+ token, err := provider.GetAccessToken(ctx, account)
+ elapsed := time.Since(start)
+
+ require.ErrorIs(t, err, context.Canceled)
+ require.Empty(t, token)
+ require.Less(t, elapsed, claudeLockWaitTime/2, "context canceled should short-circuit lock wait")
+}
+
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
diff --git a/backend/internal/service/crs_sync_service.go b/backend/internal/service/crs_sync_service.go
index 6a9167400..040b2357b 100644
--- a/backend/internal/service/crs_sync_service.go
+++ b/backend/internal/service/crs_sync_service.go
@@ -221,7 +221,7 @@ func (s *CRSSyncService) fetchCRSExport(ctx context.Context, baseURL, username,
AllowPrivateHosts: s.cfg.Security.URLAllowlist.AllowPrivateHosts,
})
if err != nil {
- return nil, fmt.Errorf("create http client failed: %w", err)
+ client = &http.Client{Timeout: 20 * time.Second}
}
adminToken, err := crsLogin(ctx, client, normalizedURL, username, password)
diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go
index 2af43386b..4528def3d 100644
--- a/backend/internal/service/dashboard_service.go
+++ b/backend/internal/service/dashboard_service.go
@@ -140,14 +140,6 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
-func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
- stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
- if err != nil {
- return nil, fmt.Errorf("get group stats with filters: %w", err)
- }
- return stats, nil
-}
-
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go
index b525c0fae..b0d4d6da8 100644
--- a/backend/internal/service/data_management_service.go
+++ b/backend/internal/service/data_management_service.go
@@ -63,11 +63,11 @@ func NewDataManagementService() *DataManagementService {
}
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
- _ = dialTimeout
path := strings.TrimSpace(socketPath)
if path == "" {
path = DefaultDataManagementAgentSocketPath
}
+ _ = dialTimeout
return &DataManagementService{
socketPath: path,
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index df2130027..ce5710945 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -118,9 +118,8 @@ const (
SettingKeyCustomMenuItems = "custom_menu_items" // 自定义菜单项(JSON 数组)
// 默认配置
- SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
- SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
- SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON)
+ SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
+ SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
@@ -128,6 +127,9 @@ const (
// Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
+ // Bulk edit template library (JSON)
+ SettingKeyBulkEditTemplateLibrary = "bulk_edit_template_library_v1"
+
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
@@ -194,13 +196,6 @@ const (
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
-
- // =========================
- // Claude Code Version Check
- // =========================
-
- // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
- SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 02f9a6a3a..0385ab562 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -520,7 +520,6 @@ type GatewayService struct {
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
- rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
@@ -550,7 +549,6 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
- rpmCache RPMCache,
digestStore *DigestSessionStore,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
@@ -576,7 +574,6 @@ func NewGatewayService(
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
- rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
@@ -1157,7 +1154,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
@@ -1233,10 +1229,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredWindowCost++
continue
}
- // RPM 检查(非粘性会话路径)
- if !s.isAccountSchedulableForRPM(ctx, account, false) {
- continue
- }
routingCandidates = append(routingCandidates, account)
}
@@ -1260,9 +1252,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
- s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
-
- s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
+ s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1416,9 +1406,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
- s.isAccountSchedulableForWindowCost(ctx, account, true) &&
-
- s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
+ s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1484,10 +1472,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
- // RPM 检查(非粘性会话路径)
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
candidates = append(candidates, acc)
}
@@ -2174,88 +2158,6 @@ checkSchedulability:
return true
}
-// rpmPrefetchContextKey is the context key for prefetched RPM counts.
-type rpmPrefetchContextKeyType struct{}
-
-var rpmPrefetchContextKey = rpmPrefetchContextKeyType{}
-
-func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) {
- if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok {
- count, found := v[accountID]
- return count, found
- }
- return 0, false
-}
-
-// withRPMPrefetch 批量预取所有候选账号的 RPM 计数
-func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context {
- if s.rpmCache == nil {
- return ctx
- }
-
- var ids []int64
- for i := range accounts {
- if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 {
- ids = append(ids, accounts[i].ID)
- }
- }
- if len(ids) == 0 {
- return ctx
- }
-
- counts, err := s.rpmCache.GetRPMBatch(ctx, ids)
- if err != nil {
- return ctx // 失败开放
- }
- return context.WithValue(ctx, rpmPrefetchContextKey, counts)
-}
-
-// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度
-// 仅适用于 Anthropic OAuth/SetupToken 账号
-func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool {
- if !account.IsAnthropicOAuthOrSetupToken() {
- return true
- }
- baseRPM := account.GetBaseRPM()
- if baseRPM <= 0 {
- return true
- }
-
- // 尝试从预取缓存获取
- var currentRPM int
- if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok {
- currentRPM = count
- } else if s.rpmCache != nil {
- if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil {
- currentRPM = count
- }
- // 失败开放:GetRPM 错误时允许调度
- }
-
- schedulability := account.CheckRPMSchedulability(currentRPM)
- switch schedulability {
- case WindowCostSchedulable:
- return true
- case WindowCostStickyOnly:
- return isSticky
- case WindowCostNotSchedulable:
- return false
- }
- return true
-}
-
-// IncrementAccountRPM increments the RPM counter for the given account.
-// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口,
-// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit
-// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。
-func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error {
- if s.rpmCache == nil {
- return nil
- }
- _, err := s.rpmCache.IncrementRPM(ctx, accountID)
- return err
-}
-
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
@@ -2450,7 +2352,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
//
// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。
-// 因此这里采用"组内分区 + 分区内 shuffle"的方式:
+// 因此这里采用“组内分区 + 分区内 shuffle”的方式:
// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前;
// - 再分别在各段内随机打散,避免热点。
func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
@@ -2590,7 +2492,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2613,10 +2515,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
accountsLoaded = true
- // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2644,12 +2542,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -2700,7 +2592,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
return account, nil
}
}
@@ -2721,10 +2613,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
- // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
@@ -2743,12 +2631,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -2818,7 +2700,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
@@ -2839,10 +2721,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
accountsLoaded = true
- // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2874,12 +2752,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -2930,7 +2802,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -2949,10 +2821,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
- // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
@@ -2975,12 +2843,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -5332,7 +5194,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
- // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。
+ // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
@@ -6522,7 +6384,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
- logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
+ return fmt.Errorf("create usage log: %w", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
@@ -6531,7 +6393,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
- shouldBill := inserted || err != nil
+ shouldBill := inserted
// 根据计费类型执行扣费
if isSubscriptionBilling {
@@ -6720,7 +6582,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
- logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
+ return fmt.Errorf("create usage log: %w", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
@@ -6729,7 +6591,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
return nil
}
- shouldBill := inserted || err != nil
+ shouldBill := inserted
// 根据计费类型执行扣费
if isSubscriptionBilling {
diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go
index 08a74a372..e866bdc36 100644
--- a/backend/internal/service/gemini_oauth_service.go
+++ b/backend/internal/service/gemini_oauth_service.go
@@ -1045,7 +1045,7 @@ func fetchProjectIDFromResourceManager(ctx context.Context, accessToken, proxyUR
ValidateResolvedIP: true,
})
if err != nil {
- return "", fmt.Errorf("create http client failed: %w", err)
+ client = &http.Client{Timeout: 30 * time.Second}
}
resp, err := client.Do(req)
diff --git a/backend/internal/service/http_upstream_profile.go b/backend/internal/service/http_upstream_profile.go
new file mode 100644
index 000000000..fef1dc191
--- /dev/null
+++ b/backend/internal/service/http_upstream_profile.go
@@ -0,0 +1,41 @@
+package service
+
+import "context"
+
+// HTTPUpstreamProfile 标识上游 HTTP 请求的协议策略分类。
+type HTTPUpstreamProfile string
+
+const (
+ HTTPUpstreamProfileDefault HTTPUpstreamProfile = ""
+ HTTPUpstreamProfileOpenAI HTTPUpstreamProfile = "openai"
+)
+
+type httpUpstreamProfileContextKey struct{}
+
+// WithHTTPUpstreamProfile 在请求上下文中注入上游协议策略分类。
+func WithHTTPUpstreamProfile(ctx context.Context, profile HTTPUpstreamProfile) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if profile == HTTPUpstreamProfileDefault {
+ return ctx
+ }
+ return context.WithValue(ctx, httpUpstreamProfileContextKey{}, profile)
+}
+
+// HTTPUpstreamProfileFromContext 从请求上下文中解析上游协议策略分类。
+func HTTPUpstreamProfileFromContext(ctx context.Context) HTTPUpstreamProfile {
+ if ctx == nil {
+ return HTTPUpstreamProfileDefault
+ }
+ profile, ok := ctx.Value(httpUpstreamProfileContextKey{}).(HTTPUpstreamProfile)
+ if !ok {
+ return HTTPUpstreamProfileDefault
+ }
+ switch profile {
+ case HTTPUpstreamProfileOpenAI:
+ return profile
+ default:
+ return HTTPUpstreamProfileDefault
+ }
+}
diff --git a/backend/internal/service/http_upstream_profile_test.go b/backend/internal/service/http_upstream_profile_test.go
new file mode 100644
index 000000000..0ed14d93d
--- /dev/null
+++ b/backend/internal/service/http_upstream_profile_test.go
@@ -0,0 +1,33 @@
+package service
+
+import (
+ "context"
+ "testing"
+)
+
+func TestWithHTTPUpstreamProfile_DefaultKeepsContext(t *testing.T) {
+ ctx := context.Background()
+ got := WithHTTPUpstreamProfile(ctx, HTTPUpstreamProfileDefault)
+ if got != ctx {
+ t.Fatalf("expected default profile to keep original context")
+ }
+}
+
+func TestWithHTTPUpstreamProfile_TODOContextSetsProfile(t *testing.T) {
+ ctx := WithHTTPUpstreamProfile(context.TODO(), HTTPUpstreamProfileOpenAI)
+ if ctx == nil {
+ t.Fatalf("expected non-nil context")
+ }
+ if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileOpenAI {
+ t.Fatalf("expected profile %q, got %q", HTTPUpstreamProfileOpenAI, profile)
+ }
+}
+
+func TestHTTPUpstreamProfileFromContext_UnknownValueFallsBackDefault(t *testing.T) {
+ type badKey struct{}
+ ctx := context.WithValue(context.Background(), httpUpstreamProfileContextKey{}, HTTPUpstreamProfile("unknown"))
+ ctx = context.WithValue(ctx, badKey{}, "x")
+ if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileDefault {
+ t.Fatalf("expected default profile, got %q", profile)
+ }
+}
diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go
index f3130c91c..dc59010d7 100644
--- a/backend/internal/service/identity_service.go
+++ b/backend/internal/service/identity_service.go
@@ -46,7 +46,6 @@ type Fingerprint struct {
StainlessArch string
StainlessRuntime string
StainlessRuntimeVersion string
- UpdatedAt int64 `json:",omitempty"` // Unix timestamp,用于判断是否需要续期TTL
}
// IdentityCache defines cache operations for identity service
@@ -79,26 +78,14 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 尝试从缓存获取指纹
cached, err := s.cache.GetFingerprint(ctx, accountID)
if err == nil && cached != nil {
- needWrite := false
-
// 检查客户端的user-agent是否是更新版本
clientUA := headers.Get("User-Agent")
if clientUA != "" && isNewerVersion(clientUA, cached.UserAgent) {
- // 版本升级:merge 语义 — 仅更新请求中实际携带的字段,保留缓存值
- // 避免缺失的头被硬编码默认值覆盖(如新 CLI 版本 + 旧 SDK 默认值的不一致)
- mergeHeadersIntoFingerprint(cached, headers)
- needWrite = true
- logger.LegacyPrintf("service.identity", "Updated fingerprint for account %d: %s (merge update)", accountID, clientUA)
- } else if time.Since(time.Unix(cached.UpdatedAt, 0)) > 24*time.Hour {
- // 距上次写入超过24小时,续期TTL
- needWrite = true
- }
-
- if needWrite {
- cached.UpdatedAt = time.Now().Unix()
- if err := s.cache.SetFingerprint(ctx, accountID, cached); err != nil {
- logger.LegacyPrintf("service.identity", "Warning: failed to refresh fingerprint for account %d: %v", accountID, err)
- }
+ // 更新user-agent
+ cached.UserAgent = clientUA
+ // 保存更新后的指纹
+ _ = s.cache.SetFingerprint(ctx, accountID, cached)
+ logger.LegacyPrintf("service.identity", "Updated fingerprint user-agent for account %d: %s", accountID, clientUA)
}
return cached, nil
}
@@ -108,9 +95,8 @@ func (s *IdentityService) GetOrCreateFingerprint(ctx context.Context, accountID
// 生成随机ClientID
fp.ClientID = generateClientID()
- fp.UpdatedAt = time.Now().Unix()
- // 保存到缓存(7天TTL,每24小时自动续期)
+ // 保存到缓存(永不过期)
if err := s.cache.SetFingerprint(ctx, accountID, fp); err != nil {
logger.LegacyPrintf("service.identity", "Warning: failed to cache fingerprint for account %d: %v", accountID, err)
}
@@ -141,31 +127,6 @@ func (s *IdentityService) createFingerprintFromHeaders(headers http.Header) *Fin
return fp
}
-// mergeHeadersIntoFingerprint 将请求头中实际存在的字段合并到现有指纹中(用于版本升级场景)
-// 关键语义:请求中有的字段 → 用新值覆盖;缺失的头 → 保留缓存中的已有值
-// 与 createFingerprintFromHeaders 的区别:后者用于首次创建,缺失头回退到 defaultFingerprint;
-// 本函数用于升级更新,缺失头保留缓存值,避免将已知的真实值退化为硬编码默认值
-func mergeHeadersIntoFingerprint(fp *Fingerprint, headers http.Header) {
- // User-Agent:版本升级的触发条件,一定存在
- if ua := headers.Get("User-Agent"); ua != "" {
- fp.UserAgent = ua
- }
- // X-Stainless-* 头:仅在请求中实际携带时才更新,否则保留缓存值
- mergeHeader(headers, "X-Stainless-Lang", &fp.StainlessLang)
- mergeHeader(headers, "X-Stainless-Package-Version", &fp.StainlessPackageVersion)
- mergeHeader(headers, "X-Stainless-OS", &fp.StainlessOS)
- mergeHeader(headers, "X-Stainless-Arch", &fp.StainlessArch)
- mergeHeader(headers, "X-Stainless-Runtime", &fp.StainlessRuntime)
- mergeHeader(headers, "X-Stainless-Runtime-Version", &fp.StainlessRuntimeVersion)
-}
-
-// mergeHeader 如果请求头中存在该字段则更新目标值,否则保留原值
-func mergeHeader(headers http.Header, key string, target *string) {
- if v := headers.Get(key); v != "" {
- *target = v
- }
-}
-
// getHeaderOrDefault 获取header值,如果不存在则返回默认值
func getHeaderOrDefault(headers http.Header, key, defaultValue string) string {
if v := headers.Get(key); v != "" {
@@ -410,25 +371,8 @@ func parseUserAgentVersion(ua string) (major, minor, patch int, ok bool) {
return major, minor, patch, true
}
-// extractProduct 提取 User-Agent 中 "/" 前的产品名
-// 例如:claude-cli/2.1.22 (external, cli) -> "claude-cli"
-func extractProduct(ua string) string {
- if idx := strings.Index(ua, "/"); idx > 0 {
- return strings.ToLower(ua[:idx])
- }
- return ""
-}
-
// isNewerVersion 比较版本号,判断newUA是否比cachedUA更新
-// 要求产品名一致(防止浏览器 UA 如 Mozilla/5.0 误判为更新版本)
func isNewerVersion(newUA, cachedUA string) bool {
- // 校验产品名一致性
- newProduct := extractProduct(newUA)
- cachedProduct := extractProduct(cachedUA)
- if newProduct == "" || cachedProduct == "" || newProduct != cachedProduct {
- return false
- }
-
newMajor, newMinor, newPatch, newOk := parseUserAgentVersion(newUA)
cachedMajor, cachedMinor, cachedPatch, cachedOk := parseUserAgentVersion(cachedUA)
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 99013ce55..785d66b88 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -43,34 +43,42 @@ type OpenAIAccountScheduleDecision struct {
}
type OpenAIAccountSchedulerMetricsSnapshot struct {
- SelectTotal int64
- StickyPreviousHitTotal int64
- StickySessionHitTotal int64
- LoadBalanceSelectTotal int64
- AccountSwitchTotal int64
- SchedulerLatencyMsTotal int64
- SchedulerLatencyMsAvg float64
- StickyHitRatio float64
- AccountSwitchRate float64
- LoadSkewAvg float64
- RuntimeStatsAccountCount int
+ SelectTotal int64
+ StickyPreviousHitTotal int64
+ StickySessionHitTotal int64
+ LoadBalanceSelectTotal int64
+ AccountSwitchTotal int64
+ SchedulerLatencyMsTotal int64
+ SchedulerLatencyMsAvg float64
+ StickyHitRatio float64
+ AccountSwitchRate float64
+ LoadSkewAvg float64
+ RuntimeStatsAccountCount int
+ CircuitBreakerOpenTotal int64
+ CircuitBreakerRecoverTotal int64
+ StickyReleaseErrorTotal int64
+ StickyReleaseCircuitOpenTotal int64
}
type OpenAIAccountScheduler interface {
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
- ReportResult(accountID int64, success bool, firstTokenMs *int)
+ ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64)
ReportSwitch()
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
type openAIAccountSchedulerMetrics struct {
- selectTotal atomic.Int64
- stickyPreviousHitTotal atomic.Int64
- stickySessionHitTotal atomic.Int64
- loadBalanceSelectTotal atomic.Int64
- accountSwitchTotal atomic.Int64
- latencyMsTotal atomic.Int64
- loadSkewMilliTotal atomic.Int64
+ selectTotal atomic.Int64
+ stickyPreviousHitTotal atomic.Int64
+ stickySessionHitTotal atomic.Int64
+ loadBalanceSelectTotal atomic.Int64
+ accountSwitchTotal atomic.Int64
+ latencyMsTotal atomic.Int64
+ loadSkewMilliTotal atomic.Int64
+ circuitBreakerOpenTotal atomic.Int64
+ circuitBreakerRecoverTotal atomic.Int64
+ stickyReleaseErrorTotal atomic.Int64
+ stickyReleaseCircuitOpenTotal atomic.Int64
}
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
@@ -99,19 +107,439 @@ func (m *openAIAccountSchedulerMetrics) recordSwitch() {
}
type openAIAccountRuntimeStats struct {
- accounts sync.Map
- accountCount atomic.Int64
+ accounts sync.Map
+ circuitBreakers sync.Map // accountID → *accountCircuitBreaker
+ accountCount atomic.Int64
+ cleanupCounter atomic.Int64 // report call counter for periodic cleanup
+}
+
+// ---------------------------------------------------------------------------
+// Account-level Circuit Breaker (three-state: CLOSED → OPEN → HALF_OPEN)
+// ---------------------------------------------------------------------------
+
+const (
+ circuitBreakerStateClosed int32 = 0
+ circuitBreakerStateOpen int32 = 1
+ circuitBreakerStateHalfOpen int32 = 2
+
+ // Defaults (used when config values are zero/unset)
+ defaultCircuitBreakerFailThreshold = 5
+ defaultCircuitBreakerCooldownSec = 30
+ defaultCircuitBreakerHalfOpenMax = 2
+)
+
+type accountCircuitBreaker struct {
+ state atomic.Int32 // circuitBreakerState*
+ consecutiveFails atomic.Int32
+ lastFailureNano atomic.Int64 // time.Now().UnixNano()
+ halfOpenInFlight atomic.Int32 // current in-flight probes (decremented by release)
+ halfOpenAdmitted atomic.Int32 // total probes admitted this half-open cycle (never decremented by release)
+ halfOpenSuccess atomic.Int32
+}
+
+// allow returns true if the circuit breaker allows a request to pass through.
+func (cb *accountCircuitBreaker) allow(cooldown time.Duration, halfOpenMax int) bool {
+ switch cb.state.Load() {
+ case circuitBreakerStateClosed:
+ return true
+ case circuitBreakerStateOpen:
+ lastFail := time.Unix(0, cb.lastFailureNano.Load())
+ if time.Since(lastFail) <= cooldown {
+ return false
+ }
+ // Cooldown elapsed — attempt transition to HALF_OPEN.
+ // Reset counters before CAS. NOTE(review): concurrent callers that lose the
+ // CAS may re-zero permits the winner already admitted — bounded over-admit race.
+ cb.halfOpenInFlight.Store(0)
+ cb.halfOpenAdmitted.Store(0)
+ cb.halfOpenSuccess.Store(0)
+ cb.state.CompareAndSwap(circuitBreakerStateOpen, circuitBreakerStateHalfOpen)
+ // Either we transitioned or another goroutine did; fall through to
+ // HALF_OPEN gate below.
+ return cb.allowHalfOpen(halfOpenMax)
+ case circuitBreakerStateHalfOpen:
+ return cb.allowHalfOpen(halfOpenMax)
+ default:
+ return true
+ }
+}
+
+func (cb *accountCircuitBreaker) isHalfOpen() bool {
+ if cb == nil {
+ return false
+ }
+ return cb.state.Load() == circuitBreakerStateHalfOpen
+}
+
+// releaseHalfOpenPermit releases one HALF_OPEN probe permit when a candidate
+// passed filtering but was not actually selected to execute a request.
+func (cb *accountCircuitBreaker) releaseHalfOpenPermit() {
+ if cb == nil || cb.state.Load() != circuitBreakerStateHalfOpen {
+ return
+ }
+ for {
+ cur := cb.halfOpenInFlight.Load()
+ if cur <= 0 {
+ return
+ }
+ if cb.halfOpenInFlight.CompareAndSwap(cur, cur-1) {
+ return
+ }
+ }
+}
+
+func (cb *accountCircuitBreaker) allowHalfOpen(halfOpenMax int) bool {
+ for {
+ cur := cb.halfOpenInFlight.Load()
+ if int(cur) >= halfOpenMax {
+ return false
+ }
+ if cb.halfOpenInFlight.CompareAndSwap(cur, cur+1) {
+ cb.halfOpenAdmitted.Add(1)
+ return true
+ }
+ }
+}
+
+// recordSuccess is called when a request succeeds.
+func (cb *accountCircuitBreaker) recordSuccess() {
+ cb.consecutiveFails.Store(0)
+ if cb.state.Load() == circuitBreakerStateHalfOpen {
+ newSucc := cb.halfOpenSuccess.Add(1)
+ // Compare against halfOpenAdmitted (total probes ever admitted in
+ // this half-open cycle). Unlike halfOpenInFlight, this is never
+ // decremented by releaseHalfOpenPermit, so the recovery threshold
+ // remains stable regardless of candidate filtering outcomes.
+ admitted := cb.halfOpenAdmitted.Load()
+ if newSucc >= admitted && admitted > 0 {
+ if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateClosed) {
+ cb.halfOpenInFlight.Store(0)
+ cb.halfOpenAdmitted.Store(0)
+ cb.halfOpenSuccess.Store(0)
+ }
+ }
+ }
+}
+
+// recordFailure is called when a request fails.
+func (cb *accountCircuitBreaker) recordFailure(threshold int) {
+ cb.lastFailureNano.Store(time.Now().UnixNano())
+ newFails := cb.consecutiveFails.Add(1)
+
+ switch cb.state.Load() {
+ case circuitBreakerStateClosed:
+ if int(newFails) >= threshold {
+ cb.state.CompareAndSwap(circuitBreakerStateClosed, circuitBreakerStateOpen)
+ }
+ case circuitBreakerStateHalfOpen:
+ if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateOpen) {
+ cb.halfOpenInFlight.Store(0)
+ cb.halfOpenAdmitted.Store(0)
+ cb.halfOpenSuccess.Store(0)
+ }
+ }
+}
+
+// isOpen returns true if the circuit breaker is currently in OPEN state.
+func (cb *accountCircuitBreaker) isOpen() bool {
+ return cb.state.Load() == circuitBreakerStateOpen
+}
+
+// stateString returns a human-readable state name.
+func (cb *accountCircuitBreaker) stateString() string {
+ switch cb.state.Load() {
+ case circuitBreakerStateClosed:
+ return "CLOSED"
+ case circuitBreakerStateOpen:
+ return "OPEN"
+ case circuitBreakerStateHalfOpen:
+ return "HALF_OPEN"
+ default:
+ return "UNKNOWN"
+ }
+}
+
+// loadCircuitBreaker returns the CB for accountID if it exists, or nil.
+// Use this on hot paths (e.g. candidate filtering) to avoid allocating CB
+// objects for accounts that have never received a report.
+func (s *openAIAccountRuntimeStats) loadCircuitBreaker(accountID int64) *accountCircuitBreaker {
+ if val, ok := s.circuitBreakers.Load(accountID); ok {
+ if cb, _ := val.(*accountCircuitBreaker); cb != nil {
+ return cb
+ }
+ }
+ return nil
+}
+
+func (s *openAIAccountRuntimeStats) getCircuitBreaker(accountID int64) *accountCircuitBreaker {
+ if val, ok := s.circuitBreakers.Load(accountID); ok {
+ if cb, _ := val.(*accountCircuitBreaker); cb != nil {
+ return cb
+ }
+ }
+ cb := &accountCircuitBreaker{}
+ actual, _ := s.circuitBreakers.LoadOrStore(accountID, cb)
+ if existing, _ := actual.(*accountCircuitBreaker); existing != nil {
+ return existing
+ }
+ return cb
+}
+
+func (s *openAIAccountRuntimeStats) isCircuitOpen(accountID int64) bool {
+ val, ok := s.circuitBreakers.Load(accountID)
+ if !ok {
+ return false
+ }
+ cb, _ := val.(*accountCircuitBreaker)
+ if cb == nil {
+ return false
+ }
+ return cb.isOpen()
+}
+
+// ---------------------------------------------------------------------------
+// Dual-EWMA: fast (α=0.5) reacts quickly to degradation; slow (α=0.1)
+// stabilises over many samples. The pessimistic envelope max(fast,slow) means
+// we *sense* errors fast but only *confirm* recovery slowly.
+// ---------------------------------------------------------------------------
+
+const (
+ dualEWMAAlphaFast = 0.5
+ dualEWMAAlphaSlow = 0.1
+
+ // Per-model TTFT defaults
+ defaultPerModelTTFTMaxModels = 100
+ defaultPerModelTTFTTTL = 30 * time.Minute
+)
+
+// dualEWMA tracks a [0,1] signal (e.g. error rate) with two speeds.
+type dualEWMA struct {
+ fastBits atomic.Uint64 // α = dualEWMAAlphaFast, reacts in ~3 requests
+ slowBits atomic.Uint64 // α = dualEWMAAlphaSlow, stabilises over ~50 requests
+ sampleCount atomic.Int64 // total samples received; used for cold-start guard
+}
+
+// dualEWMAMinSamples is the minimum number of samples required before the
+// EWMA error rate is considered reliable for decision-making (e.g. sticky
+// release). This prevents a single failure on a fresh account from yielding
+// an artificially high error rate.
+const dualEWMAMinSamples = 5
+
+func (d *dualEWMA) update(sample float64) {
+ updateEWMAAtomic(&d.fastBits, sample, dualEWMAAlphaFast)
+ updateEWMAAtomic(&d.slowBits, sample, dualEWMAAlphaSlow)
+ d.sampleCount.Add(1)
+}
+
+// isWarmedUp returns true when enough samples have been collected for the
+// EWMA value to be meaningful.
+func (d *dualEWMA) isWarmedUp() bool {
+ return d.sampleCount.Load() >= dualEWMAMinSamples
+}
+
+// value returns the pessimistic envelope: max(fast, slow).
+func (d *dualEWMA) value() float64 {
+ fast := math.Float64frombits(d.fastBits.Load())
+ slow := math.Float64frombits(d.slowBits.Load())
+ if fast >= slow {
+ return fast
+ }
+ return slow
+}
+
+func (d *dualEWMA) fastValue() float64 {
+ return math.Float64frombits(d.fastBits.Load())
+}
+
+func (d *dualEWMA) slowValue() float64 {
+ return math.Float64frombits(d.slowBits.Load())
+}
+
+// dualEWMATTFT is like dualEWMA but handles the NaN-initialised first-sample
+// case required by TTFT tracking.
+type dualEWMATTFT struct {
+ fastBits atomic.Uint64 // α = dualEWMAAlphaFast
+ slowBits atomic.Uint64 // α = dualEWMAAlphaSlow
+}
+
+// initNaN stores NaN into both channels. Called once at allocation time.
+func (d *dualEWMATTFT) initNaN() {
+ nanBits := math.Float64bits(math.NaN())
+ d.fastBits.Store(nanBits)
+ d.slowBits.Store(nanBits)
+}
+
+func (d *dualEWMATTFT) update(sample float64) {
+ sampleBits := math.Float64bits(sample)
+ // fast channel
+ for {
+ oldBits := d.fastBits.Load()
+ oldValue := math.Float64frombits(oldBits)
+ if math.IsNaN(oldValue) {
+ if d.fastBits.CompareAndSwap(oldBits, sampleBits) {
+ break
+ }
+ continue
+ }
+ newValue := dualEWMAAlphaFast*sample + (1-dualEWMAAlphaFast)*oldValue
+ if d.fastBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
+ break
+ }
+ }
+ // slow channel
+ for {
+ oldBits := d.slowBits.Load()
+ oldValue := math.Float64frombits(oldBits)
+ if math.IsNaN(oldValue) {
+ if d.slowBits.CompareAndSwap(oldBits, sampleBits) {
+ break
+ }
+ continue
+ }
+ newValue := dualEWMAAlphaSlow*sample + (1-dualEWMAAlphaSlow)*oldValue
+ if d.slowBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
+ break
+ }
+ }
+}
+
+// value returns (pessimistic TTFT, hasTTFT). If both channels are still NaN
+// the caller gets (0, false).
+func (d *dualEWMATTFT) value() (float64, bool) {
+ fast := math.Float64frombits(d.fastBits.Load())
+ slow := math.Float64frombits(d.slowBits.Load())
+ fastOK := !math.IsNaN(fast)
+ slowOK := !math.IsNaN(slow)
+ switch {
+ case fastOK && slowOK:
+ if fast >= slow {
+ return fast, true
+ }
+ return slow, true
+ case fastOK:
+ return fast, true
+ case slowOK:
+ return slow, true
+ default:
+ return 0, false
+ }
+}
+
+func (d *dualEWMATTFT) fastValue() float64 {
+ return math.Float64frombits(d.fastBits.Load())
+}
+
+func (d *dualEWMATTFT) slowValue() float64 {
+ return math.Float64frombits(d.slowBits.Load())
+}
+
+// ---------------------------------------------------------------------------
+// Load Trend Tracker (ring-buffer linear regression)
+// ---------------------------------------------------------------------------
+
+const loadTrendRingSize = 10
+
+// loadTrendTracker maintains a fixed-size ring buffer of (timestamp, loadRate)
+// samples and computes the ordinary-least-squares slope to predict whether
+// an account's load is rising, falling, or stable.
+type loadTrendTracker struct {
+ mu sync.Mutex
+ samples [loadTrendRingSize]float64 // ring buffer of loadRate values
+ times [loadTrendRingSize]int64 // timestamps in UnixNano
+ head int // next write position
+ count int // number of valid samples (0..loadTrendRingSize)
+}
+
+// record pushes a loadRate sample with the current wall-clock timestamp.
+func (t *loadTrendTracker) record(loadRate float64) {
+ t.recordAt(loadRate, time.Now().UnixNano())
+}
+
+// recordAt pushes a loadRate sample with an explicit timestamp (for testing).
+func (t *loadTrendTracker) recordAt(loadRate float64, tsNano int64) {
+ t.mu.Lock()
+ t.samples[t.head] = loadRate
+ t.times[t.head] = tsNano
+ t.head = (t.head + 1) % loadTrendRingSize
+ if t.count < loadTrendRingSize {
+ t.count++
+ }
+ t.mu.Unlock()
+}
+
+// slope computes the simple linear regression slope of loadRate over time.
+//
+// slope = (N*Sigma(xi*yi) - Sigma(xi)*Sigma(yi)) / (N*Sigma(xi^2) - (Sigma(xi))^2)
+//
+// where xi = seconds elapsed since the oldest sample, yi = loadRate.
+// Returns 0 if fewer than 2 samples are available or if all timestamps are
+// identical (degenerate case).
+func (t *loadTrendTracker) slope() float64 {
+ t.mu.Lock()
+ n := t.count
+ if n < 2 {
+ t.mu.Unlock()
+ return 0
+ }
+
+ // Copy data under lock; computation happens outside.
+ var localSamples [loadTrendRingSize]float64
+ var localTimes [loadTrendRingSize]int64
+ copy(localSamples[:], t.samples[:])
+ copy(localTimes[:], t.times[:])
+ head := t.head
+ t.mu.Unlock()
+
+ // Identify oldest entry index.
+ oldest := head // head points to the next write pos; for a full ring it's the oldest entry.
+ if n < loadTrendRingSize {
+ oldest = 0
+ }
+ t0 := localTimes[oldest]
+
+ var sumX, sumY, sumXY, sumX2 float64
+ for i := 0; i < n; i++ {
+ idx := (oldest + i) % loadTrendRingSize
+ xi := float64(localTimes[idx]-t0) / 1e9 // relative seconds
+ yi := localSamples[idx]
+ sumX += xi
+ sumY += yi
+ sumXY += xi * yi
+ sumX2 += xi * xi
+ }
+
+ nf := float64(n)
+ denom := nf*sumX2 - sumX*sumX
+ if denom == 0 {
+ // All timestamps identical — degenerate regression; report zero slope.
+ return 0
+ }
+ return (nf*sumXY - sumX*sumY) / denom
}
type openAIAccountRuntimeStat struct {
- errorRateEWMABits atomic.Uint64
- ttftEWMABits atomic.Uint64
+ errorRate dualEWMA
+ ttft dualEWMATTFT
+ modelTTFT sync.Map // key = model name (string), value = *dualEWMATTFT
+ modelTTFTLastUpdate sync.Map // key = model name (string), value = int64 (unix nano)
+ loadTrend loadTrendTracker
+ lastReportNano atomic.Int64 // last report timestamp for GC
}
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
return &openAIAccountRuntimeStats{}
}
+// loadExisting returns the stat for accountID if it exists, or nil.
+// Unlike loadOrCreate, this never allocates a new stat.
+func (s *openAIAccountRuntimeStats) loadExisting(accountID int64) *openAIAccountRuntimeStat {
+ if value, ok := s.accounts.Load(accountID); ok {
+ stat, _ := value.(*openAIAccountRuntimeStat)
+ return stat
+ }
+ return nil
+}
+
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
if value, ok := s.accounts.Load(accountID); ok {
stat, _ := value.(*openAIAccountRuntimeStat)
@@ -121,7 +549,7 @@ func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccount
}
stat := &openAIAccountRuntimeStat{}
- stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
+ stat.ttft.initNaN()
actual, loaded := s.accounts.LoadOrStore(accountID, stat)
if !loaded {
s.accountCount.Add(1)
@@ -134,6 +562,103 @@ func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccount
return stat
}
+// getOrCreateModelTTFT returns the per-model TTFT tracker, creating it if
+// it does not exist yet. Uses the LoadOrStore pattern for thread safety.
+func (stat *openAIAccountRuntimeStat) getOrCreateModelTTFT(model string) *dualEWMATTFT {
+ if val, ok := stat.modelTTFT.Load(model); ok {
+ if d, _ := val.(*dualEWMATTFT); d != nil {
+ return d
+ }
+ }
+ d := &dualEWMATTFT{}
+ d.initNaN()
+ actual, _ := stat.modelTTFT.LoadOrStore(model, d)
+ if existing, _ := actual.(*dualEWMATTFT); existing != nil {
+ return existing
+ }
+ return d
+}
+
+// reportModelTTFT updates both the per-model and global TTFT trackers.
+func (stat *openAIAccountRuntimeStat) reportModelTTFT(model string, sampleMs float64) {
+ if model == "" || sampleMs <= 0 {
+ return
+ }
+ d := stat.getOrCreateModelTTFT(model)
+ d.update(sampleMs)
+ stat.modelTTFTLastUpdate.Store(model, time.Now().UnixNano())
+ // Also update the global TTFT so that callers without a model still
+ // see a reasonable aggregate.
+ stat.ttft.update(sampleMs)
+}
+
+// modelTTFTValue returns the per-model TTFT value if a tracker exists and has
+// received at least one sample. Otherwise returns (0, false).
+func (stat *openAIAccountRuntimeStat) modelTTFTValue(model string) (float64, bool) {
+ if model == "" {
+ return 0, false
+ }
+ val, ok := stat.modelTTFT.Load(model)
+ if !ok {
+ return 0, false
+ }
+ d, _ := val.(*dualEWMATTFT)
+ if d == nil {
+ return 0, false
+ }
+ return d.value()
+}
+
+// cleanupStaleTTFT removes per-model TTFT entries that have not been updated
+// within ttl, and enforces a hard cap of maxModels entries. Oldest entries
+// are evicted first when the cap is exceeded.
+func (stat *openAIAccountRuntimeStat) cleanupStaleTTFT(ttl time.Duration, maxModels int) {
+ now := time.Now().UnixNano()
+ cutoff := now - int64(ttl)
+
+ // First pass: delete stale entries.
+ stat.modelTTFTLastUpdate.Range(func(key, value any) bool {
+ model, _ := key.(string)
+ ts, _ := value.(int64)
+ if ts < cutoff {
+ stat.modelTTFT.Delete(model)
+ stat.modelTTFTLastUpdate.Delete(model)
+ }
+ return true
+ })
+
+ if maxModels <= 0 {
+ return
+ }
+
+ // Second pass: count remaining entries and evict oldest if over limit.
+ type modelEntry struct {
+ model string
+ ts int64
+ }
+ var entries []modelEntry
+ stat.modelTTFTLastUpdate.Range(func(key, value any) bool {
+ model, _ := key.(string)
+ ts, _ := value.(int64)
+ entries = append(entries, modelEntry{model: model, ts: ts})
+ return true
+ })
+
+ if len(entries) <= maxModels {
+ return
+ }
+
+ // Sort by timestamp ascending (oldest first) and evict surplus.
+ sort.Slice(entries, func(i, j int) bool {
+ return entries[i].ts < entries[j].ts
+ })
+ evictCount := len(entries) - maxModels
+ for i := 0; i < evictCount; i++ {
+ stat.modelTTFT.Delete(entries[i].model)
+ stat.modelTTFTLastUpdate.Delete(entries[i].model)
+ }
+}
+
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
for {
oldBits := target.Load()
@@ -145,40 +670,77 @@ func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
}
}
-func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
+func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
+ s.reportWithOptions(
+ accountID,
+ success,
+ firstTokenMs,
+ defaultCircuitBreakerFailThreshold,
+ true,
+ model,
+ ttftMs,
+ true,
+ defaultPerModelTTFTMaxModels,
+ )
+}
+
+func (s *openAIAccountRuntimeStats) reportWithOptions(
+ accountID int64,
+ success bool,
+ firstTokenMs *int,
+ cbThreshold int,
+ updateCircuitBreaker bool,
+ model string,
+ ttftMs float64,
+ perModelTTFTEnabled bool,
+ perModelTTFTMaxModels int,
+) {
if s == nil || accountID <= 0 {
return
}
- const alpha = 0.2
stat := s.loadOrCreate(accountID)
+ stat.lastReportNano.Store(time.Now().UnixNano())
errorSample := 1.0
if success {
errorSample = 0.0
}
- updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
+ stat.errorRate.update(errorSample)
- if firstTokenMs != nil && *firstTokenMs > 0 {
- ttft := float64(*firstTokenMs)
- ttftBits := math.Float64bits(ttft)
- for {
- oldBits := stat.ttftEWMABits.Load()
- oldValue := math.Float64frombits(oldBits)
- if math.IsNaN(oldValue) {
- if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
- break
- }
- continue
- }
- newValue := alpha*ttft + (1-alpha)*oldValue
- if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
- break
- }
+ // Per-model TTFT tracking: reportModelTTFT updates both per-model and
+ // global TTFT, so skip the separate global update to avoid double-counting.
+ if perModelTTFTEnabled && model != "" && ttftMs > 0 {
+ stat.reportModelTTFT(model, ttftMs)
+ } else if firstTokenMs != nil && *firstTokenMs > 0 {
+ stat.ttft.update(float64(*firstTokenMs))
+ }
+
+ // Update circuit breaker state only when feature is enabled.
+ if updateCircuitBreaker {
+ cb := s.getCircuitBreaker(accountID)
+ if success {
+ cb.recordSuccess()
+ } else {
+ cb.recordFailure(cbThreshold)
+ }
+ }
+
+ // Periodic cleanup: every 100 reports.
+ cnt := s.cleanupCounter.Add(1)
+ if cnt%100 == 0 {
+ maxModels := defaultPerModelTTFTMaxModels
+ if perModelTTFTMaxModels > 0 {
+ maxModels = perModelTTFTMaxModels
}
+ stat.cleanupStaleTTFT(defaultPerModelTTFTTTL, maxModels)
+ }
+ // GC inactive accounts and orphaned circuit breakers: every 1000 reports.
+ if cnt%1000 == 0 {
+ s.gcInactiveAccounts(time.Hour)
}
}
-func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
+func (s *openAIAccountRuntimeStats) snapshot(accountID int64, model ...string) (errorRate float64, ttft float64, hasTTFT bool) {
if s == nil || accountID <= 0 {
return 0, 0, false
}
@@ -190,12 +752,17 @@ func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64
if stat == nil {
return 0, 0, false
}
- errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
- ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
- if math.IsNaN(ttftValue) {
- return errorRate, 0, false
+ errorRate = clamp01(stat.errorRate.value())
+
+ // Try per-model TTFT first; fallback to global.
+ if len(model) > 0 && model[0] != "" {
+ if mTTFT, mOK := stat.modelTTFTValue(model[0]); mOK {
+ return errorRate, mTTFT, true
+ }
}
- return errorRate, ttftValue, true
+
+ ttft, hasTTFT = stat.ttft.value()
+ return errorRate, ttft, hasTTFT
}
func (s *openAIAccountRuntimeStats) size() int {
@@ -205,6 +772,25 @@ func (s *openAIAccountRuntimeStats) size() int {
return int(s.accountCount.Load())
}
+// gcInactiveAccounts removes account stats and circuit breakers that have not
+// received any report for longer than maxIdle. This prevents unbounded growth
+// of the sync.Maps when accounts are created and then deleted/deactivated.
+func (s *openAIAccountRuntimeStats) gcInactiveAccounts(maxIdle time.Duration) {
+ if s == nil {
+ return
+ }
+ cutoff := time.Now().UnixNano() - int64(maxIdle)
+ s.accounts.Range(func(key, value any) bool {
+ stat, _ := value.(*openAIAccountRuntimeStat)
+ if stat == nil || stat.lastReportNano.Load() < cutoff {
+ s.accounts.Delete(key)
+ s.circuitBreakers.Delete(key)
+ s.accountCount.Add(-1)
+ }
+ return true
+ })
+}
+
type defaultOpenAIAccountScheduler struct {
service *OpenAIGatewayService
metrics openAIAccountSchedulerMetrics
@@ -331,6 +917,12 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
return nil, nil
}
+ // Conditional sticky: release binding if account is unhealthy or overloaded.
+ if s.shouldReleaseStickySession(accountID) {
+ _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
+ return nil, nil // Fall through to load balance
+ }
+
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
@@ -557,6 +1149,175 @@ func buildOpenAIWeightedSelectionOrder(
return order
}
+// selectP2COpenAICandidates selects candidates using Power-of-Two-Choices:
+// randomly pick 2 candidates, return the one with the higher score.
+// Repeat to build a full selection order for fallback.
+//
+// The RNG is seeded deterministically from req (deriveOpenAISelectionSeed), so
+// the produced order is stable for identical requests. The input slice is not
+// mutated: a private copy serves as the working pool, and the result is a
+// permutation of candidates.
+func selectP2COpenAICandidates(
+ candidates []openAIAccountCandidateScore,
+ req OpenAIAccountScheduleRequest,
+) []openAIAccountCandidateScore {
+ if len(candidates) <= 1 {
+ return append([]openAIAccountCandidateScore(nil), candidates...)
+ }
+
+ rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
+ pool := append([]openAIAccountCandidateScore(nil), candidates...)
+ order := make([]openAIAccountCandidateScore, 0, len(pool))
+
+ for len(pool) > 1 {
+ n := uint64(len(pool))
+ // Pick first random index.
+ idx1 := int(rng.nextUint64() % n)
+ // Pick second random index, distinct from the first.
+ // Drawing from [0, n-2] and shifting draws >= idx1 up by one maps the
+ // value onto [0, n-1] \ {idx1}, guaranteeing two distinct indices.
+ idx2 := int(rng.nextUint64() % (n - 1))
+ if idx2 >= idx1 {
+ idx2++
+ }
+
+ // Compare: take the candidate with the higher score.
+ winner := idx1
+ if isOpenAIAccountCandidateBetter(pool[idx2], pool[idx1]) {
+ winner = idx2
+ }
+
+ order = append(order, pool[winner])
+ // Remove winner from pool (swap with last element for O(1) removal).
+ pool[winner] = pool[len(pool)-1]
+ pool = pool[:len(pool)-1]
+ }
+ // Append the last remaining candidate.
+ order = append(order, pool[0])
+ return order
+}
+
+// ---------------------------------------------------------------------------
+// Softmax Temperature Sampling
+// ---------------------------------------------------------------------------
+
+// defaultSoftmaxTemperature is used whenever the configured temperature is
+// unset or non-positive.
+const defaultSoftmaxTemperature = 0.3
+
+// softmaxConfig holds the resolved softmax-sampling settings for the scheduler.
+type softmaxConfig struct {
+ // enabled toggles softmax temperature sampling in selectByLoadBalance.
+ enabled bool
+ // temperature controls exploration vs. exploitation; always > 0 after read.
+ temperature float64
+}
+
+// softmaxConfigRead reads softmax scheduler config with fallback defaults.
+// When the scheduler, service, or config is unavailable it returns the zero
+// value (disabled, temperature 0); otherwise a non-positive configured
+// temperature is replaced by defaultSoftmaxTemperature.
+func (s *defaultOpenAIAccountScheduler) softmaxConfigRead() softmaxConfig {
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return softmaxConfig{}
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ temp := wsCfg.SchedulerSoftmaxTemperature
+ if temp <= 0 {
+ temp = defaultSoftmaxTemperature
+ }
+ return softmaxConfig{
+ enabled: wsCfg.SchedulerSoftmaxEnabled,
+ temperature: temp,
+ }
+}
+
+// selectSoftmaxOpenAICandidates applies softmax temperature sampling to select
+// one candidate probabilistically, then returns the full list with the selected
+// candidate first and the rest sorted by descending probability.
+//
+// Algorithm (numerically stable):
+//
+// maxScore = max(score[i])
+// weights[i] = exp((score[i] - maxScore) / temperature)
+// probability[i] = weights[i] / sum(weights)
+//
+// A higher temperature yields more uniform selection (exploration); a lower
+// temperature concentrates probability on the highest-scored candidates
+// (exploitation).
+//
+// The returned slice is a permutation of candidates (nothing is dropped), so
+// callers can iterate it as a fallback order. rng is supplied by the caller;
+// selection is deterministic only if the rng is deterministically seeded.
+func selectSoftmaxOpenAICandidates(
+ candidates []openAIAccountCandidateScore,
+ temperature float64,
+ rng *openAISelectionRNG,
+) []openAIAccountCandidateScore {
+ if len(candidates) == 0 {
+ return nil
+ }
+ if len(candidates) == 1 {
+ return append([]openAIAccountCandidateScore(nil), candidates...)
+ }
+ if temperature <= 0 {
+ temperature = defaultSoftmaxTemperature
+ }
+
+ // Step 1: find max score for numerical stability.
+ maxScore := candidates[0].score
+ for i := 1; i < len(candidates); i++ {
+ if candidates[i].score > maxScore {
+ maxScore = candidates[i].score
+ }
+ }
+
+ // Step 2: compute softmax weights.
+ type indexedProb struct {
+ index int
+ prob float64
+ }
+ probs := make([]indexedProb, len(candidates))
+ sumWeights := 0.0
+ for i := range candidates {
+ w := math.Exp((candidates[i].score - maxScore) / temperature)
+ // Guard against NaN/Inf from degenerate inputs.
+ if math.IsNaN(w) || math.IsInf(w, 0) {
+ w = 0
+ }
+ probs[i] = indexedProb{index: i, prob: w}
+ sumWeights += w
+ }
+
+ // Normalise to probabilities. If sumWeights is zero (all weights collapsed
+ // to zero, which can happen with extreme negative scores), fall back to
+ // uniform distribution.
+ if sumWeights > 0 {
+ for i := range probs {
+ probs[i].prob /= sumWeights
+ }
+ } else {
+ uniform := 1.0 / float64(len(probs))
+ for i := range probs {
+ probs[i].prob = uniform
+ }
+ }
+
+ // Step 3: sample ONE candidate via CDF.
+ // NOTE(review): assumes rng.nextFloat64 yields a value in [0, 1) — confirm.
+ // The pre-set selectedIdx covers floating-point rounding where the running
+ // cumulative sum never quite reaches r.
+ r := rng.nextFloat64()
+ selectedIdx := probs[len(probs)-1].index // default to last if rounding issues
+ cumulative := 0.0
+ for _, ip := range probs {
+ cumulative += ip.prob
+ if cumulative >= r {
+ selectedIdx = ip.index
+ break
+ }
+ }
+
+ // Step 4: build result — selected candidate first, rest sorted by
+ // descending probability.
+ result := make([]openAIAccountCandidateScore, 0, len(candidates))
+ result = append(result, candidates[selectedIdx])
+
+ // Sort remaining by probability descending for fallback order.
+ remaining := make([]indexedProb, 0, len(probs)-1)
+ for _, ip := range probs {
+ if ip.index != selectedIdx {
+ remaining = append(remaining, ip)
+ }
+ }
+ sort.Slice(remaining, func(i, j int) bool {
+ return remaining[i].prob > remaining[j].prob
+ })
+ for _, ip := range remaining {
+ result = append(result, candidates[ip.index])
+ }
+
+ return result
+}
+
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
@@ -597,6 +1358,44 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
+ // Circuit breaker filtering: remove accounts with open CBs, but if that
+ // would empty the candidate pool, keep all accounts (graceful degradation).
+ cbEnabled, _, cbCooldown, cbHalfOpenMax := s.schedulerCircuitBreakerConfig()
+ heldHalfOpenPermits := make(map[int64]*accountCircuitBreaker)
+ releaseHalfOpenPermit := func(accountID int64) {
+ cb, ok := heldHalfOpenPermits[accountID]
+ if !ok || cb == nil {
+ return
+ }
+ cb.releaseHalfOpenPermit()
+ delete(heldHalfOpenPermits, accountID)
+ }
+ defer func() {
+ for accountID := range heldHalfOpenPermits {
+ releaseHalfOpenPermit(accountID)
+ }
+ }()
+ if cbEnabled {
+ healthy := make([]*Account, 0, len(filtered))
+ healthyLoadReq := make([]AccountWithConcurrency, 0, len(loadReq))
+ for i, account := range filtered {
+ cb := s.stats.loadCircuitBreaker(account.ID)
+ if cb == nil || cb.allow(cbCooldown, cbHalfOpenMax) {
+ healthy = append(healthy, account)
+ healthyLoadReq = append(healthyLoadReq, loadReq[i])
+ if cb.isHalfOpen() {
+ heldHalfOpenPermits[account.ID] = cb
+ }
+ }
+ }
+ if len(healthy) > 0 {
+ filtered = healthy
+ loadReq = healthyLoadReq
+ }
+ // else: all accounts are circuit-open; fall through with the
+ // original set to avoid returning "no accounts".
+ }
+
loadMap := map[int64]*AccountLoadInfo{}
if s.service.concurrencyService != nil {
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
@@ -604,8 +1403,16 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
+ trendEnabled, trendMaxSlope := s.service.openAIWSSchedulerTrendConfig()
+ perModelTTFTEnabled, _ := s.schedulerPerModelTTFTConfig()
+ requestedModelForStats := ""
+ if perModelTTFTEnabled {
+ requestedModelForStats = req.RequestedModel
+ }
+
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
maxWaiting := 1
+ maxConcurrency := 0
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
@@ -625,7 +1432,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
- errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
+ if account.Concurrency > maxConcurrency {
+ maxConcurrency = account.Concurrency
+ }
+ errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID, requestedModelForStats)
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
@@ -642,6 +1452,13 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
+
+ // Record current load rate sample for trend tracking.
+ if trendEnabled {
+ stat := s.stats.loadOrCreate(account.ID)
+ stat.loadTrend.record(loadRate)
+ }
+
candidates = append(candidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
@@ -659,8 +1476,33 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
+ // Base load factor from percentage utilization.
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
+ // Capacity-aware adjustment: accounts with more absolute headroom get a bonus.
+ if maxConcurrency > 0 && item.account.Concurrency > 0 {
+ remainingSlots := float64(item.account.Concurrency) * (1 - float64(item.loadInfo.LoadRate)/100.0)
+ capacityBonus := clamp01(remainingSlots / float64(maxConcurrency))
+ // Blend: 70% relative load + 30% capacity bonus
+ loadFactor = 0.7*loadFactor + 0.3*capacityBonus
+ }
+
+ // Trend adjustment: penalise accounts whose load is rising, reward those declining.
+ // trendAdj ranges [0, 1] where 0 = max rising slope, 1 = max falling/flat slope.
+ // loadFactor is blended: 70% base load + 30% trend influence.
+ if trendEnabled {
+ stat := s.stats.loadOrCreate(item.account.ID)
+ slope := stat.loadTrend.slope()
+ trendAdj := 1.0 - clamp01(slope/trendMaxSlope)
+ loadFactor *= (0.7 + 0.3*trendAdj)
+ }
+
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
+ // Queue depth relative to account's own capacity for capacity-aware blending.
+ if item.account.Concurrency > 0 {
+ relativeQueue := clamp01(float64(item.loadInfo.WaitingCount) / float64(item.account.Concurrency))
+ // Blend: 60% cross-account normalized + 40% self-relative
+ queueFactor = 0.6*queueFactor + 0.4*(1-relativeQueue)
+ }
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
@@ -674,23 +1516,40 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
weights.TTFT*ttftFactor
}
- topK := s.service.openAIWSLBTopK()
- if topK > len(candidates) {
- topK = len(candidates)
- }
- if topK <= 0 {
- topK = 1
+ var selectionOrder []openAIAccountCandidateScore
+ topK := 0
+ rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
+ smCfg := s.softmaxConfigRead()
+ p2cEnabled := s.service.openAIWSSchedulerP2CEnabled()
+ if smCfg.enabled && len(candidates) > 3 {
+ selectionOrder = selectSoftmaxOpenAICandidates(candidates, smCfg.temperature, &rng)
+ // topK = 0 signals softmax mode in metrics / decision struct.
+ } else if p2cEnabled {
+ selectionOrder = selectP2COpenAICandidates(candidates, req)
+ // topK = 0 signals P2C mode in metrics / decision struct.
+ } else {
+ topK = s.service.openAIWSLBTopK()
+ if topK > len(candidates) {
+ topK = len(candidates)
+ }
+ if topK <= 0 {
+ topK = 1
+ }
+ rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
+ selectionOrder = buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
}
- rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
- selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
if acquireErr != nil {
+ releaseHalfOpenPermit(candidate.account.ID)
return nil, len(candidates), topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
+ // Keep HALF_OPEN permit for the selected account; the outcome will be
+ // settled by ReportResult(success/failure) after the request finishes.
+ delete(heldHalfOpenPermits, candidate.account.ID)
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
}
@@ -700,10 +1559,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}
+ releaseHalfOpenPermit(candidate.account.ID)
}
cfg := s.service.schedulingConfig()
candidate := selectionOrder[0]
+ releaseHalfOpenPermit(candidate.account.ID)
return &AccountSelectionResult{
Account: candidate.account,
WaitPlan: &AccountWaitPlan{
@@ -726,11 +1587,54 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
-func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
+// ReportResult records the outcome of a finished request for accountID:
+// error-rate and TTFT signals always, and — when the circuit breaker feature
+// is enabled — breaker state, tracking OPEN/CLOSED transitions in the
+// scheduler metrics. firstTokenMs feeds the global TTFT signal; model/ttftMs
+// feed the optional per-model TTFT signal (see schedulerPerModelTTFTConfig).
+func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
 if s == nil || s.stats == nil {
 return
 }
- s.stats.report(accountID, success, firstTokenMs)
+ perModelTTFTEnabled, perModelTTFTMaxModels := s.schedulerPerModelTTFTConfig()
+ enabled, threshold, _, _ := s.schedulerCircuitBreakerConfig()
+ if !enabled {
+ // Circuit breaker disabled: only update runtime signals (error-rate/TTFT),
+ // do not mutate circuit breaker state.
+ s.stats.reportWithOptions(
+ accountID,
+ success,
+ firstTokenMs,
+ 0,
+ false,
+ model,
+ ttftMs,
+ perModelTTFTEnabled,
+ perModelTTFTMaxModels,
+ )
+ return
+ }
+
+ // Snapshot state before the update for metrics tracking.
+ // NOTE(review): assumes getCircuitBreaker always returns a non-nil breaker
+ // (creating one on first use); otherwise cb.state.Load() panics — confirm.
+ cb := s.stats.getCircuitBreaker(accountID)
+ stateBefore := cb.state.Load()
+
+ s.stats.reportWithOptions(
+ accountID,
+ success,
+ firstTokenMs,
+ threshold,
+ true,
+ model,
+ ttftMs,
+ perModelTTFTEnabled,
+ perModelTTFTMaxModels,
+ )
+
+ stateAfter := cb.state.Load()
+ // CLOSED/HALF_OPEN → OPEN: circuit tripped.
+ if stateBefore != circuitBreakerStateOpen && stateAfter == circuitBreakerStateOpen {
+ s.metrics.circuitBreakerOpenTotal.Add(1)
+ }
+ // OPEN/HALF_OPEN → CLOSED: circuit recovered.
+ if stateBefore != circuitBreakerStateClosed && stateAfter == circuitBreakerStateClosed {
+ s.metrics.circuitBreakerRecoverTotal.Add(1)
+ }
}
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
@@ -740,6 +1644,107 @@ func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
s.metrics.recordSwitch()
}
+// schedulerCircuitBreakerConfig reads CB config with fallback defaults.
+// When the scheduler/service/config is unavailable the breaker is reported as
+// disabled, but the default threshold/cooldown/half-open values are still
+// returned so callers always receive sane non-zero settings. Non-positive
+// configured values likewise fall back to the package defaults.
+func (s *defaultOpenAIAccountScheduler) schedulerCircuitBreakerConfig() (enabled bool, threshold int, cooldown time.Duration, halfOpenMax int) {
+ threshold = defaultCircuitBreakerFailThreshold
+ cooldown = time.Duration(defaultCircuitBreakerCooldownSec) * time.Second
+ halfOpenMax = defaultCircuitBreakerHalfOpenMax
+
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return false, threshold, cooldown, halfOpenMax
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ enabled = wsCfg.SchedulerCircuitBreakerEnabled
+ if wsCfg.SchedulerCircuitBreakerFailThreshold > 0 {
+ threshold = wsCfg.SchedulerCircuitBreakerFailThreshold
+ }
+ if wsCfg.SchedulerCircuitBreakerCooldownSec > 0 {
+ cooldown = time.Duration(wsCfg.SchedulerCircuitBreakerCooldownSec) * time.Second
+ }
+ if wsCfg.SchedulerCircuitBreakerHalfOpenMax > 0 {
+ halfOpenMax = wsCfg.SchedulerCircuitBreakerHalfOpenMax
+ }
+ return enabled, threshold, cooldown, halfOpenMax
+}
+
+// schedulerPerModelTTFTConfig reads per-model TTFT tracking config with
+// fallback defaults: disabled (and defaultPerModelTTFTMaxModels) when the
+// scheduler/service/config is unavailable; a non-positive configured model
+// cap falls back to the default as well.
+func (s *defaultOpenAIAccountScheduler) schedulerPerModelTTFTConfig() (enabled bool, maxModels int) {
+ maxModels = defaultPerModelTTFTMaxModels
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return false, maxModels
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ enabled = wsCfg.SchedulerPerModelTTFTEnabled
+ if wsCfg.SchedulerPerModelTTFTMaxModels > 0 {
+ maxModels = wsCfg.SchedulerPerModelTTFTMaxModels
+ }
+ return enabled, maxModels
+}
+
+// ---------------------------------------------------------------------------
+// Conditional Sticky Session Release
+// ---------------------------------------------------------------------------
+
+// defaultStickyReleaseErrorThreshold is the error-rate above which a sticky
+// binding is released when no (positive) threshold is configured.
+const defaultStickyReleaseErrorThreshold = 0.3
+
+// stickyReleaseConfig holds resolved conditional-sticky-release settings.
+type stickyReleaseConfig struct {
+ // enabled toggles the conditional release check in selectBySessionHash.
+ enabled bool
+ // errorThreshold is the error rate beyond which a binding is released.
+ errorThreshold float64
+}
+
+// stickyReleaseConfigRead reads conditional sticky release config with defaults.
+// Returns the zero value (disabled) when the scheduler/service/config is
+// unavailable; a non-positive configured threshold falls back to
+// defaultStickyReleaseErrorThreshold.
+func (s *defaultOpenAIAccountScheduler) stickyReleaseConfigRead() stickyReleaseConfig {
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return stickyReleaseConfig{}
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ threshold := wsCfg.StickyReleaseErrorThreshold
+ if threshold <= 0 {
+ threshold = defaultStickyReleaseErrorThreshold
+ }
+ return stickyReleaseConfig{
+ enabled: wsCfg.StickyReleaseEnabled,
+ errorThreshold: threshold,
+ }
+}
+
+// shouldReleaseStickySession checks whether a sticky binding should be
+// released because the account is unhealthy (circuit breaker open) or has a
+// high error rate. This runs BEFORE slot acquisition to avoid wasting
+// concurrency capacity on degraded accounts. Returns false whenever the
+// feature is disabled or the receiver is not fully initialised.
+func (s *defaultOpenAIAccountScheduler) shouldReleaseStickySession(accountID int64) bool {
+ if s == nil || s.stats == nil || s.service == nil {
+ return false
+ }
+
+ cfg := s.stickyReleaseConfigRead()
+ if !cfg.enabled {
+ return false
+ }
+
+ // Check 1: Circuit breaker is open -> immediate release.
+ // Only check if CB feature is actually enabled, because the default CB
+ // threshold (5) is very aggressive and may trip unexpectedly.
+ cbEnabled, _, _, _ := s.schedulerCircuitBreakerConfig()
+ if cbEnabled && s.stats.isCircuitOpen(accountID) {
+ s.metrics.stickyReleaseCircuitOpenTotal.Add(1)
+ return true
+ }
+
+ // Check 2: Error rate exceeds threshold -> immediate release.
+ // Guard against cold-start: the EWMA error rate is unreliable when
+ // fewer than dualEWMAMinSamples have been collected. loadExisting returns
+ // nil for accounts that never reported, so cold accounts are never released.
+ stat := s.stats.loadExisting(accountID)
+ if stat != nil && stat.errorRate.isWarmedUp() {
+ // snapshot without a model argument returns the global
+ // (model-agnostic) error rate.
+ errorRate, _, _ := s.stats.snapshot(accountID)
+ if errorRate > cfg.errorThreshold {
+ s.metrics.stickyReleaseErrorTotal.Add(1)
+ return true
+ }
+ }
+
+ return false
+}
+
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
if s == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
@@ -753,13 +1758,17 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
snapshot := OpenAIAccountSchedulerMetricsSnapshot{
- SelectTotal: selectTotal,
- StickyPreviousHitTotal: prevHit,
- StickySessionHitTotal: sessionHit,
- LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
- AccountSwitchTotal: switchTotal,
- SchedulerLatencyMsTotal: latencyTotal,
- RuntimeStatsAccountCount: s.stats.size(),
+ SelectTotal: selectTotal,
+ StickyPreviousHitTotal: prevHit,
+ StickySessionHitTotal: sessionHit,
+ LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
+ AccountSwitchTotal: switchTotal,
+ SchedulerLatencyMsTotal: latencyTotal,
+ RuntimeStatsAccountCount: s.stats.size(),
+ CircuitBreakerOpenTotal: s.metrics.circuitBreakerOpenTotal.Load(),
+ CircuitBreakerRecoverTotal: s.metrics.circuitBreakerRecoverTotal.Load(),
+ StickyReleaseErrorTotal: s.metrics.stickyReleaseErrorTotal.Load(),
+ StickyReleaseCircuitOpenTotal: s.metrics.stickyReleaseCircuitOpenTotal.Load(),
}
if selectTotal > 0 {
snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
@@ -820,12 +1829,12 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
})
}
-func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
+// ReportOpenAIAccountScheduleResult forwards a finished-request outcome to the
+// account scheduler; it is a no-op when no scheduler is configured. model and
+// ttftMs feed the optional per-model TTFT signal.
+func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
 scheduler := s.getOpenAIAccountScheduler()
 if scheduler == nil {
 return
 }
- scheduler.ReportResult(accountID, success, firstTokenMs)
+ scheduler.ReportResult(accountID, success, firstTokenMs, model, ttftMs)
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
@@ -855,7 +1864,14 @@ func (s *OpenAIGatewayService) openAIWSLBTopK() int {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 {
return s.cfg.Gateway.OpenAIWS.LBTopK
}
- return 7
+ return 999
+}
+
+// openAIWSSchedulerP2CEnabled reports whether Power-of-Two-Choices candidate
+// selection is enabled in the gateway OpenAI WS config; it returns false when
+// the service or its config is unavailable.
+func (s *OpenAIGatewayService) openAIWSSchedulerP2CEnabled() bool {
+ if s != nil && s.cfg != nil {
+ return s.cfg.Gateway.OpenAIWS.SchedulerP2CEnabled
+ }
+ return false
+}
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
@@ -885,6 +1901,24 @@ type GatewayOpenAIWSSchedulerScoreWeightsView struct {
TTFT float64
}
+// defaultSchedulerTrendMaxSlope is the normalization ceiling for the trend
+// slope. A slope of 5.0 means the account's load rate is increasing at 5
+// percentage points per second — a very steep rise.
+const defaultSchedulerTrendMaxSlope = 5.0
+
+// openAIWSSchedulerTrendConfig reads trend-prediction config with defaults.
+// maxSlope falls back to defaultSchedulerTrendMaxSlope whenever the configured
+// value is non-positive; enabled is false when the service/config is absent.
+// Note maxSlope is returned even when trend tracking is disabled.
+func (s *OpenAIGatewayService) openAIWSSchedulerTrendConfig() (enabled bool, maxSlope float64) {
+ maxSlope = defaultSchedulerTrendMaxSlope
+ if s == nil || s.cfg == nil {
+ return false, maxSlope
+ }
+ enabled = s.cfg.Gateway.OpenAIWS.SchedulerTrendEnabled
+ if s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope > 0 {
+ maxSlope = s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope
+ }
+ return enabled, maxSlope
+}
+
func clamp01(value float64) float64 {
switch {
case value < 0:
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index 7f6f1b66a..6cc3ba8f0 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -1,6 +1,7 @@
package service
import (
+ "container/heap"
"context"
"fmt"
"math"
@@ -12,6 +13,14 @@ import (
"github.com/stretchr/testify/require"
)
+// mustDefaultOpenAIAccountScheduler constructs the default account scheduler
+// and fails the test immediately if the concrete *defaultOpenAIAccountScheduler
+// type assertion does not hold.
+func mustDefaultOpenAIAccountScheduler(t *testing.T, svc *OpenAIGatewayService, stats *openAIAccountRuntimeStats) *defaultOpenAIAccountScheduler {
+ t.Helper()
+ schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+ scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+ return scheduler
+}
+
func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) {
ctx := context.Background()
groupID := int64(9)
@@ -447,7 +456,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
- svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
+ svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120), "", 0)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
@@ -465,19 +474,120 @@ func intPtrForTest(v int) *int {
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
stats := newOpenAIAccountRuntimeStats()
- stats.report(1001, true, nil)
+ stats.report(1001, true, nil, "", 0) // error: fast 0→0, slow 0→0
firstTTFT := 100
- stats.report(1001, false, &firstTTFT)
+ stats.report(1001, false, &firstTTFT, "", 0) // error: fast 0→0.5, slow 0→0.1; ttft: NaN→100 (both)
secondTTFT := 200
- stats.report(1001, false, &secondTTFT)
+ stats.report(1001, false, &secondTTFT, "", 0) // error: fast 0.5→0.75, slow 0.1→0.19; ttft: fast 100→150, slow 100→110
errorRate, ttft, hasTTFT := stats.snapshot(1001)
require.True(t, hasTTFT)
- require.InDelta(t, 0.36, errorRate, 1e-9)
- require.InDelta(t, 120.0, ttft, 1e-9)
+ // errorRate = max(fast=0.75, slow=0.19) = 0.75
+ require.InDelta(t, 0.75, errorRate, 1e-9)
+ // ttft = max(fast=150, slow=110) = 150
+ require.InDelta(t, 150.0, ttft, 1e-9)
require.Equal(t, 1, stats.size())
}
+// TestDualEWMA_UpdateAndValue pins the exact fast (alpha 0.5) and slow
+// (alpha 0.1) decay arithmetic for the first two samples of a zero-valued
+// dualEWMA, including the max-of-both envelope returned by value().
+func TestDualEWMA_UpdateAndValue(t *testing.T) {
+ var d dualEWMA
+
+ // Initial state: both channels are 0.
+ require.Equal(t, 0.0, d.fastValue())
+ require.Equal(t, 0.0, d.slowValue())
+ require.Equal(t, 0.0, d.value())
+
+ // First sample = 1.0
+ d.update(1.0)
+ // fast: 0.5*1 + 0.5*0 = 0.5
+ require.InDelta(t, 0.5, d.fastValue(), 1e-12)
+ // slow: 0.1*1 + 0.9*0 = 0.1
+ require.InDelta(t, 0.1, d.slowValue(), 1e-12)
+ // value = max(0.5, 0.1) = 0.5
+ require.InDelta(t, 0.5, d.value(), 1e-12)
+
+ // Second sample = 0.0 (recovery)
+ d.update(0.0)
+ // fast: 0.5*0 + 0.5*0.5 = 0.25
+ require.InDelta(t, 0.25, d.fastValue(), 1e-12)
+ // slow: 0.1*0 + 0.9*0.1 = 0.09
+ require.InDelta(t, 0.09, d.slowValue(), 1e-12)
+ // value = max(0.25, 0.09) = 0.25
+ require.InDelta(t, 0.25, d.value(), 1e-12)
+}
+
+// TestDualEWMA_SlowDominatesAfterRecovery verifies that after a failure spike
+// followed by a recovery streak, the slow channel decays more slowly than the
+// fast one and therefore drives the pessimistic value() envelope.
+func TestDualEWMA_SlowDominatesAfterRecovery(t *testing.T) {
+ var d dualEWMA
+
+ // Spike: several failures.
+ for i := 0; i < 10; i++ {
+ d.update(1.0)
+ }
+ // Now fast is close to 1, slow is also rising.
+
+ // Recovery: many successes.
+ for i := 0; i < 20; i++ {
+ d.update(0.0)
+ }
+ // Fast should have dropped close to 0, slow should still be > fast.
+ require.Greater(t, d.slowValue(), d.fastValue(),
+ "after recovery, slow channel should dominate the pessimistic envelope")
+ require.Equal(t, d.slowValue(), d.value())
+}
+
+// TestDualEWMATTFT_NaNInitAndFirstSample verifies the NaN-initialised TTFT
+// EWMA reports no data before any sample, that the first sample seeds both
+// channels, and that the second sample decays with the 0.5/0.1 alphas.
+func TestDualEWMATTFT_NaNInitAndFirstSample(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+
+ // Before any sample, value should report no data.
+ v, ok := d.value()
+ require.False(t, ok)
+ require.Equal(t, 0.0, v)
+
+ // First sample seeds both channels.
+ d.update(100.0)
+ require.InDelta(t, 100.0, d.fastValue(), 1e-12)
+ require.InDelta(t, 100.0, d.slowValue(), 1e-12)
+ v, ok = d.value()
+ require.True(t, ok)
+ require.InDelta(t, 100.0, v, 1e-12)
+
+ // Second sample.
+ d.update(200.0)
+ // fast: 0.5*200 + 0.5*100 = 150
+ require.InDelta(t, 150.0, d.fastValue(), 1e-12)
+ // slow: 0.1*200 + 0.9*100 = 110
+ require.InDelta(t, 110.0, d.slowValue(), 1e-12)
+ v, ok = d.value()
+ require.True(t, ok)
+ require.InDelta(t, 150.0, v, 1e-12)
+}
+
+// TestDualEWMATTFT_SlowDominatesWhenLatencyDrops verifies that when latency
+// improves sharply, the slow channel stays above the fast one and continues
+// to drive the pessimistic TTFT reported by value().
+func TestDualEWMATTFT_SlowDominatesWhenLatencyDrops(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+
+ // Warm up with high latency.
+ for i := 0; i < 20; i++ {
+ d.update(500.0)
+ }
+ // Now push many low-latency samples.
+ for i := 0; i < 20; i++ {
+ d.update(100.0)
+ }
+ // Fast should have adapted down quickly; slow should still be higher.
+ require.Greater(t, d.slowValue(), d.fastValue(),
+ "after latency improvement, slow channel should dominate the pessimistic TTFT")
+ v, ok := d.value()
+ require.True(t, ok)
+ require.InDelta(t, d.slowValue(), v, 1e-12)
+}
+
+// TestDualEWMAConstants pins the dual-EWMA decay constants so the expected
+// values hard-coded in the tests above stay in sync with the implementation.
+func TestDualEWMAConstants(t *testing.T) {
+ require.Equal(t, 0.5, dualEWMAAlphaFast)
+ require.Equal(t, 0.1, dualEWMAAlphaSlow)
+}
+
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
stats := newOpenAIAccountRuntimeStats()
@@ -496,7 +606,7 @@ func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
accountID := int64(i%accountCount + 1)
success := (i+worker)%3 != 0
ttft := 80 + (i+worker)%40
- stats.report(accountID, success, &ttft)
+ stats.report(accountID, success, &ttft, "", 0)
}
}()
}
@@ -751,7 +861,7 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
require.True(t, ok)
ttft := 100
- scheduler.ReportResult(1001, true, &ttft)
+ scheduler.ReportResult(1001, true, &ttft, "", 0)
scheduler.ReportSwitch()
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
Layer: openAIAccountScheduleLayerLoadBalance,
@@ -780,11 +890,11 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
svc := &OpenAIGatewayService{}
ttft := 120
- svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft, "", 0)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
- require.Equal(t, 7, svc.openAIWSLBTopK())
+ require.Equal(t, 999, svc.openAIWSLBTopK())
require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL())
defaultWeights := svc.openAIWSSchedulerWeights()
@@ -836,6 +946,3268 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
}
+func TestLoadFactorCapacityAwareness(t *testing.T) {
+ // Test that accounts with higher absolute capacity get better scores
+ // when percentage load is equal.
+ //
+ // Setup:
+ // Account A: Concurrency=100, LoadRate=50 (50 free slots)
+ // Account B: Concurrency=10, LoadRate=50 (5 free slots)
+ // Both at 50% load, but A should score higher due to more headroom.
+
+ ctx := context.Background()
+ groupID := int64(20)
+ accounts := []Account{
+ {
+ ID: 6001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 100,
+ Priority: 0,
+ },
+ {
+ ID: 6002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 10,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ // Use only Load weight to isolate the capacity-aware loadFactor effect.
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 6001: {AccountID: 6001, LoadRate: 50, WaitingCount: 0},
+ 6002: {AccountID: 6002, LoadRate: 50, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+	// Verify account A (high capacity) tends to win on score. Because
+	// weighted selection has randomness, we run multiple iterations and
+	// verify A is selected more often than B.
+ countA := 0
+ countB := 0
+ iterations := 100
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("cap_aware_test_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ if selection.Account.ID == 6001 {
+ countA++
+ } else {
+ countB++
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // Account A (100 concurrency) should be selected significantly more often
+ // than Account B (10 concurrency) because A has 50 free slots vs 5 free slots.
+ require.Greater(t, countA, countB,
+ "high-capacity account (50 free slots) should be selected more often than low-capacity (5 free slots) at equal load percentage; got A=%d B=%d", countA, countB)
+
+ // -----------------------------------------------------------------------
+ // Verify score math directly via the capacity-aware loadFactor formula.
+ // -----------------------------------------------------------------------
+ // maxConcurrency = 100 (from account A)
+ //
+ // Account A (Concurrency=100, LoadRate=50):
+ // base loadFactor = 1 - 50/100 = 0.5
+ // remainingSlots = 100 * 0.5 = 50
+ // capacityBonus = 50 / 100 = 0.5
+ // loadFactor = 0.7*0.5 + 0.3*0.5 = 0.5
+ //
+ // Account B (Concurrency=10, LoadRate=50):
+ // base loadFactor = 1 - 50/100 = 0.5
+ // remainingSlots = 10 * 0.5 = 5
+ // capacityBonus = 5 / 100 = 0.05
+ // loadFactor = 0.7*0.5 + 0.3*0.05 = 0.365
+ //
+ // With Load weight = 1.0 and all others 0.0, score = loadFactor.
+ expectedScoreA := 0.7*0.5 + 0.3*0.5 // 0.5
+ expectedScoreB := 0.7*0.5 + 0.3*(5.0/100.0) // 0.365
+ require.Greater(t, expectedScoreA, expectedScoreB, "score sanity check")
+ require.InDelta(t, 0.5, expectedScoreA, 1e-9)
+ require.InDelta(t, 0.365, expectedScoreB, 1e-9)
+}
+
+func TestQueueFactorCapacityAwareness(t *testing.T) {
+ // Test that the capacity-aware queue factor penalises accounts
+ // whose queue depth is high relative to their own concurrency.
+ //
+ // Account A: Concurrency=100, WaitingCount=10 (10% of capacity)
+ // Account B: Concurrency=10, WaitingCount=10 (100% of capacity)
+ // Both have same absolute waiting count, but B should score lower.
+
+ ctx := context.Background()
+ groupID := int64(21)
+ accounts := []Account{
+ {
+ ID: 7001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 100,
+ Priority: 0,
+ },
+ {
+ ID: 7002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 10,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ // Use only Queue weight to isolate the capacity-aware queueFactor effect.
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 7001: {AccountID: 7001, LoadRate: 30, WaitingCount: 10},
+ 7002: {AccountID: 7002, LoadRate: 30, WaitingCount: 10},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ countA := 0
+ countB := 0
+ iterations := 100
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("queue_aware_test_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ if selection.Account.ID == 7001 {
+ countA++
+ } else {
+ countB++
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ require.Greater(t, countA, countB,
+ "account with lower relative queue depth should be selected more often; got A=%d B=%d", countA, countB)
+
+ // -----------------------------------------------------------------------
+ // Verify score math for the capacity-aware queueFactor.
+ // -----------------------------------------------------------------------
+ // maxWaiting = 10 (both accounts have WaitingCount=10)
+ //
+ // Account A (Concurrency=100, WaitingCount=10):
+ // base queueFactor = 1 - 10/10 = 0.0
+ // relativeQueue = 10/100 = 0.1
+ // queueFactor = 0.6*0.0 + 0.4*(1-0.1) = 0.36
+ //
+ // Account B (Concurrency=10, WaitingCount=10):
+ // base queueFactor = 1 - 10/10 = 0.0
+ // relativeQueue = clamp01(10/10) = 1.0
+ // queueFactor = 0.6*0.0 + 0.4*(1-1.0) = 0.0
+ expectedQueueA := 0.6*0.0 + 0.4*(1-0.1)
+ expectedQueueB := 0.6*0.0 + 0.4*(1-1.0)
+ require.Greater(t, expectedQueueA, expectedQueueB)
+ require.InDelta(t, 0.36, expectedQueueA, 1e-9)
+ require.InDelta(t, 0.0, expectedQueueB, 1e-9)
+}
+
+func TestLoadFactorCapacityAwareness_ZeroConcurrencyFallback(t *testing.T) {
+	// When Concurrency is 0, the capacity-aware blending should be skipped
+	// and loadFactor should fall back to the simple 1 - loadRate/100 formula.
+
+ ctx := context.Background()
+ groupID := int64(22)
+ accounts := []Account{
+ {
+ ID: 8001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 0, // unset / zero
+ Priority: 0,
+ },
+ {
+ ID: 8002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 0,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 8001: {AccountID: 8001, LoadRate: 30, WaitingCount: 0},
+ 8002: {AccountID: 8002, LoadRate: 70, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ // With Concurrency=0, maxConcurrency=0, so the capacity-aware path is skipped.
+ // Account 8001 (LoadRate=30) should have higher loadFactor than 8002 (LoadRate=70).
+ countLow := 0
+ iterations := 60
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("zero_conc_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ if selection.Account.ID == 8001 {
+ countLow++
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // 8001 (lower load) should be picked more often.
+ require.Greater(t, countLow, iterations/2,
+ "account with lower load should be selected more often when concurrency is 0; got %d/%d", countLow, iterations)
+}
+
func int64PtrForTest(v int64) *int64 {
return &v
}
+
+// ---------------------------------------------------------------------------
+// Circuit Breaker Tests
+// ---------------------------------------------------------------------------
+
+func TestAccountCircuitBreaker_ClosedToOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 30 * time.Second
+ halfOpenMax := 2
+
+ // Initially CLOSED — should allow.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.False(t, cb.isOpen())
+
+ // Record 4 failures — should still be CLOSED (threshold is 5).
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+
+ // 5th failure trips the breaker to OPEN.
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ require.Equal(t, "OPEN", cb.stateString())
+ require.True(t, cb.isOpen())
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+}
+
+func TestAccountCircuitBreaker_OpenToHalfOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 2
+
+ // Trip the breaker.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+
+ // Wait for cooldown to elapse.
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Next allow() should transition to HALF_OPEN and admit the request.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+}
+
+func TestAccountCircuitBreaker_HalfOpenToClose(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 2
+
+ // Trip the breaker.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+
+ // Wait for cooldown.
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Allow first probe — transitions to HALF_OPEN.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Allow second probe.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+
+ // Third probe should be rejected (halfOpenMax=2).
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+
+ // Both probes succeed — should close the circuit.
+ cb.recordSuccess()
+ // After first success, still HALF_OPEN (need both to succeed).
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+ cb.recordSuccess()
+ // Both probes succeeded — circuit should be CLOSED now.
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.False(t, cb.isOpen())
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+}
+
+func TestAccountCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 10 * time.Millisecond
+ halfOpenMax := 2
+
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+
+ time.Sleep(cooldown + 5*time.Millisecond)
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+ require.Equal(t, int32(1), cb.halfOpenInFlight.Load())
+
+ cb.releaseHalfOpenPermit()
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+
+ // Idempotent release should not underflow.
+ cb.releaseHalfOpenPermit()
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+func TestAccountCircuitBreaker_HalfOpenToOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 2
+
+ // Trip the breaker.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+
+ // Wait for cooldown.
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Allow a probe — transitions to HALF_OPEN.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Failure in HALF_OPEN should trip back to OPEN.
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ require.Equal(t, "OPEN", cb.stateString())
+ require.True(t, cb.isOpen())
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+}
+
+func TestAccountCircuitBreaker_ResetOnSuccess(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+
+ // 4 failures followed by a success should reset the counter.
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, int32(4), cb.consecutiveFails.Load())
+
+ cb.recordSuccess()
+ require.Equal(t, int32(0), cb.consecutiveFails.Load())
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ // 4 more failures — still not tripped because counter was reset.
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ // 5th consecutive failure trips it.
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ require.Equal(t, "OPEN", cb.stateString())
+}
+
+func TestAccountCircuitBreaker_IntegrationWithScheduler(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(30)
+ accounts := []Account{
+ {
+ ID: 9001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ {
+ ID: 9002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ // Enable circuit breaker with low threshold for testing.
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9001: {AccountID: 9001, LoadRate: 10, WaitingCount: 0},
+ 9002: {AccountID: 9002, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler := svc.getOpenAIAccountScheduler()
+
+ // Report 3 consecutive failures for account 9001 — trips the circuit breaker.
+ for i := 0; i < 3; i++ {
+ scheduler.ReportResult(9001, false, nil, "", 0)
+ }
+
+ // Now all selections should avoid account 9001 and pick 9002.
+ for i := 0; i < 20; i++ {
+ sessionHash := fmt.Sprintf("cb_integration_%d", i)
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(9002), selection.Account.ID,
+ "circuit-open account 9001 should be skipped, got %d on iteration %d", selection.Account.ID, i)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // Verify metrics tracked the trip.
+ snapshot := scheduler.SnapshotMetrics()
+ require.GreaterOrEqual(t, snapshot.CircuitBreakerOpenTotal, int64(1))
+}
+
+func TestAccountCircuitBreaker_AllOpenFallback(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(31)
+ accounts := []Account{
+ {
+ ID: 9101,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9101: {AccountID: 9101, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler := svc.getOpenAIAccountScheduler()
+
+ // Trip the only account.
+ for i := 0; i < 3; i++ {
+ scheduler.ReportResult(9101, false, nil, "", 0)
+ }
+
+ // Even though the only account is circuit-open, the scheduler should
+ // still return it (graceful degradation — never return "no accounts").
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "cb_fallback_test",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(9101), selection.Account.ID)
+}
+
+func TestAccountCircuitBreaker_SelectReleasesUnselectedHalfOpenPermit(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(311)
+ accounts := []Account{
+ {
+ ID: 9111,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ {
+ ID: 9112,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9111: {AccountID: 9111, LoadRate: 10, WaitingCount: 0},
+ 9112: {AccountID: 9112, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler, ok := svc.getOpenAIAccountScheduler().(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ // Trip both accounts to OPEN so next select will transition both to HALF_OPEN.
+ scheduler.ReportResult(9111, false, nil, "", 0)
+ scheduler.ReportResult(9112, false, nil, "", 0)
+ scheduler.stats.getCircuitBreaker(9111).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano())
+ scheduler.stats.getCircuitBreaker(9112).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano())
+
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "cb_release_unselected",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+
+ selectedID := selection.Account.ID
+ otherID := int64(9111)
+ if selectedID == otherID {
+ otherID = 9112
+ }
+ otherCB := scheduler.stats.getCircuitBreaker(otherID)
+ require.Equal(t, int32(0), otherCB.halfOpenInFlight.Load(),
+ "unselected HALF_OPEN candidate should release probe permit")
+
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestAccountCircuitBreaker_DisabledByConfig(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(32)
+ accounts := []Account{
+ {
+ ID: 9201,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ {
+ ID: 9202,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ // Circuit breaker explicitly DISABLED.
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = false
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9201: {AccountID: 9201, LoadRate: 10, WaitingCount: 0},
+ 9202: {AccountID: 9202, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler := svc.getOpenAIAccountScheduler()
+ internalScheduler, ok := scheduler.(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ // Report many failures — should NOT affect scheduling when disabled.
+ for i := 0; i < 10; i++ {
+ scheduler.ReportResult(9201, false, nil, "", 0)
+ }
+
+ // Both accounts should still be eligible.
+ selected := map[int64]int{}
+ for i := 0; i < 40; i++ {
+ sessionHash := fmt.Sprintf("cb_disabled_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ selected[selection.Account.ID]++
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+ // When disabled, 9201 should still appear as a candidate.
+ require.Greater(t, selected[int64(9201)]+selected[int64(9202)], 0)
+ require.Len(t, selected, 2, "both accounts should be selectable when CB is disabled")
+ cb := internalScheduler.stats.getCircuitBreaker(9201)
+ require.False(t, cb.isOpen(), "circuit breaker should not transition to OPEN when feature is disabled")
+ require.Equal(t, int64(0), internalScheduler.metrics.circuitBreakerOpenTotal.Load())
+}
+
+func TestAccountCircuitBreaker_RecoveryMetrics(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ svc := &OpenAIGatewayService{}
+ schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+ scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ // Manually enable CB by setting config on the service.
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 0 // immediate cooldown for test
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1
+ scheduler.service.cfg = cfg
+
+ // Trip the breaker: 3 consecutive failures.
+ for i := 0; i < 3; i++ {
+ scheduler.ReportResult(5001, false, nil, "", 0)
+ }
+ require.Equal(t, int64(1), scheduler.metrics.circuitBreakerOpenTotal.Load())
+
+ // Let the cooldown expire (0 seconds) and call allow to trigger HALF_OPEN.
+ cb := stats.getCircuitBreaker(5001)
+ require.Equal(t, "OPEN", cb.stateString())
+ allowed := cb.allow(0, 1)
+ require.True(t, allowed)
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Report success — should transition HALF_OPEN → CLOSED.
+ scheduler.ReportResult(5001, true, nil, "", 0)
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.Equal(t, int64(1), scheduler.metrics.circuitBreakerRecoverTotal.Load())
+}
+
+func TestAccountCircuitBreaker_StateString(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ cb.state.Store(circuitBreakerStateOpen)
+ require.Equal(t, "OPEN", cb.stateString())
+
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ cb.state.Store(99)
+ require.Equal(t, "UNKNOWN", cb.stateString())
+}
+
+func TestAccountCircuitBreaker_GetAndIsCircuitOpen(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+
+ // isCircuitOpen on a non-existent account should return false.
+ require.False(t, stats.isCircuitOpen(1234))
+
+ // getCircuitBreaker should create on first access.
+ cb := stats.getCircuitBreaker(1234)
+ require.NotNil(t, cb)
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.False(t, stats.isCircuitOpen(1234))
+
+ // Trip it and verify isCircuitOpen returns true.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.True(t, stats.isCircuitOpen(1234))
+
+ // Second call to getCircuitBreaker should return same instance.
+ cb2 := stats.getCircuitBreaker(1234)
+ require.True(t, cb == cb2, "should return same pointer")
+}
+
+func TestAccountCircuitBreaker_ConcurrentAllowAndRecord(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 4
+
+ var wg sync.WaitGroup
+ const workers = 16
+ const iterations = 200
+
+ wg.Add(workers)
+ for w := 0; w < workers; w++ {
+ w := w
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ _ = cb.allow(cooldown, halfOpenMax)
+ if (i+w)%3 == 0 {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ } else {
+ cb.recordSuccess()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+
+ // Just verify it doesn't panic or deadlock, and state is valid.
+ state := cb.state.Load()
+ require.True(t, state == circuitBreakerStateClosed ||
+ state == circuitBreakerStateOpen ||
+ state == circuitBreakerStateHalfOpen,
+ "unexpected state: %d", state)
+}
+
+// ---------------------------------------------------------------------------
+// P2C (Power-of-Two-Choices) Tests
+// ---------------------------------------------------------------------------
+
+func TestSelectP2COpenAICandidates_BasicSelection(t *testing.T) {
+ // P2C should return all candidates in some order, and higher-scored
+ // candidates should tend to appear earlier in the selection order.
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 1}, score: 0.9},
+ {account: &Account{ID: 2, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 2}, score: 0.5},
+ {account: &Account{ID: 3, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 3}, score: 0.1},
+ {account: &Account{ID: 4, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 4}, score: 0.7},
+ {account: &Account{ID: 5, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 5}, score: 0.3},
+ }
+
+ req := OpenAIAccountScheduleRequest{
+ SessionHash: "p2c_basic_test",
+ }
+
+ result := selectP2COpenAICandidates(candidates, req)
+
+ // All candidates must be present exactly once.
+ require.Len(t, result, len(candidates))
+ seen := map[int64]bool{}
+ for _, c := range result {
+ require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID)
+ seen[c.account.ID] = true
+ }
+ for _, c := range candidates {
+ require.True(t, seen[c.account.ID], "missing account ID %d", c.account.ID)
+ }
+
+ // Statistical check: over many runs the highest-scored candidate (ID=1,
+ // score=0.9) should appear in position 0 more often than the lowest-scored
+ // candidate (ID=3, score=0.1).
+ topCount := map[int64]int{}
+ iterations := 500
+ for i := 0; i < iterations; i++ {
+ iterReq := OpenAIAccountScheduleRequest{
+ SessionHash: fmt.Sprintf("p2c_stat_%d", i),
+ }
+ order := selectP2COpenAICandidates(candidates, iterReq)
+ topCount[order[0].account.ID]++
+ }
+ require.Greater(t, topCount[int64(1)], topCount[int64(3)],
+ "highest-scored candidate should appear first more often than lowest-scored; got best=%d worst=%d",
+ topCount[int64(1)], topCount[int64(3)])
+}
+
+func TestSelectP2COpenAICandidates_SingleCandidate(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 42, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 42}, score: 1.0},
+ }
+ req := OpenAIAccountScheduleRequest{SessionHash: "single"}
+
+ result := selectP2COpenAICandidates(candidates, req)
+ require.Len(t, result, 1)
+ require.Equal(t, int64(42), result[0].account.ID)
+}
+
+func TestSelectP2COpenAICandidates_DeterministicWithSameSeed(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 10, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 10}, score: 0.8},
+ {account: &Account{ID: 20, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 20}, score: 0.6},
+ {account: &Account{ID: 30, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 30}, score: 0.4},
+ {account: &Account{ID: 40, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 40}, score: 0.2},
+ }
+ // Use a session hash to ensure the seed is deterministic (no time entropy).
+ req := OpenAIAccountScheduleRequest{
+ SessionHash: "deterministic_p2c_seed",
+ }
+
+ first := selectP2COpenAICandidates(candidates, req)
+ for i := 0; i < 10; i++ {
+ again := selectP2COpenAICandidates(candidates, req)
+ require.Len(t, again, len(first))
+ for j := range first {
+ require.Equal(t, first[j].account.ID, again[j].account.ID,
+ "iteration %d position %d mismatch", i, j)
+ }
+ }
+}
+
+// TestP2CLoadBalanceIntegration drives SelectAccountWithScheduler with P2C
+// enabled via config and checks both the selection distribution and the
+// TopK=0 decision marker that identifies P2C mode.
+func TestP2CLoadBalanceIntegration(t *testing.T) {
+	// End-to-end test: enable P2C via config, verify it distributes across
+	// accounts and that decision.TopK == 0 (P2C mode indicator).
+	ctx := context.Background()
+	groupID := int64(50)
+	accounts := []Account{
+		{
+			ID: 5001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+			Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0,
+		},
+		{
+			ID: 5002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+			Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0,
+		},
+		{
+			ID: 5003, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+			Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0,
+		},
+	}
+
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5
+
+	// Distinct load rates give the P2C tournament a clear favorite (5001).
+	concurrencyCache := stubConcurrencyCache{
+		loadMap: map[int64]*AccountLoadInfo{
+			5001: {AccountID: 5001, LoadRate: 20, WaitingCount: 0},
+			5002: {AccountID: 5002, LoadRate: 30, WaitingCount: 0},
+			5003: {AccountID: 5003, LoadRate: 40, WaitingCount: 0},
+		},
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
+		cache:              &stubGatewayCache{sessionBindings: map[string]int64{}},
+		cfg:                cfg,
+		concurrencyService: NewConcurrencyService(concurrencyCache),
+	}
+
+	selected := map[int64]int{}
+	iterations := 100
+	for i := 0; i < iterations; i++ {
+		// A unique session hash per iteration avoids sticky-session reuse.
+		sessionHash := fmt.Sprintf("p2c_integration_%d", i)
+		selection, decision, err := svc.SelectAccountWithScheduler(
+			ctx,
+			&groupID,
+			"",
+			sessionHash,
+			"gpt-5.1",
+			nil,
+			OpenAIUpstreamTransportAny,
+		)
+		require.NoError(t, err)
+		require.NotNil(t, selection)
+		require.NotNil(t, selection.Account)
+		require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+		// P2C mode: TopK should be 0.
+		require.Equal(t, 0, decision.TopK, "P2C mode should set TopK=0")
+		selected[selection.Account.ID]++
+		if selection.ReleaseFunc != nil {
+			selection.ReleaseFunc()
+		}
+	}
+
+	// P2C with 3 candidates: the two better-scored accounts (5001, 5002)
+	// should be selected, while the weakest (5003) may rarely or never win
+	// a P2C tournament. Verify at least 2 distinct accounts are picked and
+	// the lowest-loaded account dominates.
+	require.GreaterOrEqual(t, len(selected), 2,
+		"P2C should distribute across at least 2 accounts; got %v", selected)
+	require.Greater(t, selected[int64(5001)], 0,
+		"lowest-loaded account 5001 should be selected at least once")
+	require.Greater(t, selected[int64(5001)], selected[int64(5003)],
+		"lowest-loaded account should be favored over highest-loaded; got 5001=%d 5003=%d",
+		selected[int64(5001)], selected[int64(5003)])
+}
+
+// TestP2CFallbackToTopK verifies that with P2C disabled the scheduler takes
+// the Top-K load-balance path (decision.TopK > 0) and that the P2C helper
+// reports the feature as disabled.
+func TestP2CFallbackToTopK(t *testing.T) {
+	// When P2C is disabled (default), the Top-K path should be used.
+	// Verify topK > 0 in decision.
+	ctx := context.Background()
+	groupID := int64(51)
+	accounts := []Account{
+		{
+			ID: 5101, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+			Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
+		},
+		{
+			ID: 5102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+			Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
+		},
+	}
+
+	cfg := &config.Config{}
+	// Explicitly disable P2C (or leave at default false).
+	cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false
+	cfg.Gateway.OpenAIWS.LBTopK = 2
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8
+	cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5
+
+	concurrencyCache := stubConcurrencyCache{
+		loadMap: map[int64]*AccountLoadInfo{
+			5101: {AccountID: 5101, LoadRate: 10, WaitingCount: 0},
+			5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 0},
+		},
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
+		cache:              &stubGatewayCache{sessionBindings: map[string]int64{}},
+		cfg:                cfg,
+		concurrencyService: NewConcurrencyService(concurrencyCache),
+	}
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx,
+		&groupID,
+		"",
+		"topk_fallback_test",
+		"gpt-5.1",
+		nil,
+		OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+	// Top-K mode: TopK should be > 0.
+	require.Greater(t, decision.TopK, 0, "Top-K mode should set TopK > 0")
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+
+	// Also verify P2C helper returns false when disabled.
+	require.False(t, svc.openAIWSSchedulerP2CEnabled())
+}
+
+// ---------------------------------------------------------------------------
+// Conditional Sticky Session Release Tests
+// ---------------------------------------------------------------------------
+
+// buildConditionalStickyTestService creates a minimal OpenAIGatewayService and
+// scheduler with injectable runtime stats for conditional sticky tests.
+//
+// The gateway cache is pre-seeded with stickyKey -> stickyAccountID so the
+// sticky-session layer has an existing binding to evaluate.
+// stickyReleaseEnabled toggles conditional sticky release (error threshold
+// left at the 0.3 default); cbEnabled toggles the scheduler circuit breaker
+// with a fail threshold of 3.
+func buildConditionalStickyTestService(
+	accounts []Account,
+	stickyKey string,
+	stickyAccountID int64,
+	stickyReleaseEnabled bool,
+	cbEnabled bool,
+) (*OpenAIGatewayService, *defaultOpenAIAccountScheduler) {
+	cache := &stubGatewayCache{
+		sessionBindings: map[string]int64{
+			stickyKey: stickyAccountID,
+		},
+	}
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.StickyReleaseEnabled = stickyReleaseEnabled
+	// Leave StickyReleaseErrorThreshold at 0 to use the default (0.3).
+	cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = cbEnabled
+	cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+	cfg.Gateway.Scheduling.StickySessionMaxWaiting = 5
+	cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
+		cache:              cache,
+		cfg:                cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+	}
+
+	stats := newOpenAIAccountRuntimeStats()
+	scheduler := &defaultOpenAIAccountScheduler{
+		service: svc,
+		stats:   stats,
+	}
+	// Wire the scheduler into the service so that SelectAccountWithScheduler
+	// uses it via getOpenAIAccountScheduler.
+	svc.openaiScheduler = scheduler
+	svc.openaiAccountStats = stats
+	return svc, scheduler
+}
+
+// TestConditionalSticky_ReleaseOnHighErrorRate verifies that a sticky
+// binding is released when the bound account's EWMA error rate exceeds
+// the default 0.3 threshold, forcing a fall-through to load balance.
+func TestConditionalSticky_ReleaseOnHighErrorRate(t *testing.T) {
+	ctx := context.Background()
+	groupID := int64(30001)
+	stickyAccount := Account{
+		ID:          5001,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+	fallbackAccount := Account{
+		ID:          5002,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+
+	svc, scheduler := buildConditionalStickyTestService(
+		[]Account{stickyAccount, fallbackAccount},
+		fmt.Sprintf("openai:sticky_err_%d", groupID),
+		stickyAccount.ID,
+		true,  // stickyReleaseEnabled
+		false, // cbEnabled (not needed for error rate test)
+	)
+
+	// Pump the error rate above the 0.3 default threshold.
+	// With alpha=0.5 (fast EWMA), after ~5 consecutive failures the rate
+	// converges well above 0.3.
+	for i := 0; i < 10; i++ {
+		scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+	}
+	errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+	require.Greater(t, errRate, 0.3, "error rate should exceed threshold before test")
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "",
+		fmt.Sprintf("sticky_err_%d", groupID),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	// The sticky account should have been released; the scheduler should
+	// have fallen through to load balance and selected one of the accounts.
+	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer,
+		"should fall through to load balance after sticky release")
+	require.False(t, decision.StickySessionHit, "sticky hit should be false")
+}
+
+// TestConditionalSticky_ReleaseOnCircuitOpen verifies that a sticky binding
+// is released when the bound account's circuit breaker is OPEN, forcing a
+// fall-through to load balance.
+func TestConditionalSticky_ReleaseOnCircuitOpen(t *testing.T) {
+	ctx := context.Background()
+	groupID := int64(30002)
+	stickyAccount := Account{
+		ID:          5011,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+	fallbackAccount := Account{
+		ID:          5012,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+
+	svc, scheduler := buildConditionalStickyTestService(
+		[]Account{stickyAccount, fallbackAccount},
+		fmt.Sprintf("openai:sticky_cb_%d", groupID),
+		stickyAccount.ID,
+		true, // stickyReleaseEnabled
+		true, // cbEnabled
+	)
+
+	// Trip the circuit breaker by reporting consecutive failures beyond
+	// the configured threshold (3).
+	for i := 0; i < 5; i++ {
+		scheduler.ReportResult(stickyAccount.ID, false, nil, "", 0)
+	}
+	require.True(t, scheduler.stats.isCircuitOpen(stickyAccount.ID),
+		"circuit breaker should be OPEN before test")
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "",
+		fmt.Sprintf("sticky_cb_%d", groupID),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer,
+		"should fall through to load balance after sticky release due to CB open")
+	require.False(t, decision.StickySessionHit)
+}
+
+// TestConditionalSticky_KeepsHealthySticky verifies that a sticky binding
+// to a healthy account (error rate below threshold, circuit closed) is
+// honored: the same account is returned via the sticky layer.
+func TestConditionalSticky_KeepsHealthySticky(t *testing.T) {
+	ctx := context.Background()
+	groupID := int64(30003)
+	stickyAccount := Account{
+		ID:          5021,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+
+	svc, scheduler := buildConditionalStickyTestService(
+		[]Account{stickyAccount},
+		fmt.Sprintf("openai:sticky_ok_%d", groupID),
+		stickyAccount.ID,
+		true,  // stickyReleaseEnabled
+		false, // cbEnabled
+	)
+
+	// Report some successes so error rate stays at 0.
+	for i := 0; i < 5; i++ {
+		scheduler.stats.report(stickyAccount.ID, true, nil, "", 0)
+	}
+	errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+	require.Less(t, errRate, 0.3, "error rate should be below threshold")
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "",
+		fmt.Sprintf("sticky_ok_%d", groupID),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, stickyAccount.ID, selection.Account.ID,
+		"healthy sticky account should be kept")
+	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+	require.True(t, decision.StickySessionHit)
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+}
+
+// TestConditionalSticky_DisabledByConfig verifies that when sticky release
+// is disabled, the binding is honored even with a very high error rate on
+// the bound account.
+func TestConditionalSticky_DisabledByConfig(t *testing.T) {
+	ctx := context.Background()
+	groupID := int64(30004)
+	stickyAccount := Account{
+		ID:          5031,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+
+	svc, scheduler := buildConditionalStickyTestService(
+		[]Account{stickyAccount},
+		fmt.Sprintf("openai:sticky_off_%d", groupID),
+		stickyAccount.ID,
+		false, // stickyReleaseEnabled = OFF
+		false, // cbEnabled
+	)
+
+	// Pump error rate very high, but since sticky release is disabled,
+	// the sticky binding should still hold.
+	for i := 0; i < 10; i++ {
+		scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+	}
+	errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+	require.Greater(t, errRate, 0.3, "error rate should exceed threshold")
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "",
+		fmt.Sprintf("sticky_off_%d", groupID),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, stickyAccount.ID, selection.Account.ID,
+		"sticky should be kept when feature is disabled")
+	require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+	require.True(t, decision.StickySessionHit)
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+}
+
+// TestConditionalSticky_Metrics verifies that sticky releases increment the
+// scheduler metrics: first that either release counter moves after an
+// error-rate/CB release, then (with a clean setup) that the circuit-open
+// counter specifically is incremented.
+func TestConditionalSticky_Metrics(t *testing.T) {
+	groupID := int64(30005)
+	ctx := context.Background()
+
+	// --- Error rate release metric ---
+	stickyAccount := Account{
+		ID:          5041,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+	fallbackAccount := Account{
+		ID:          5042,
+		Platform:    PlatformOpenAI,
+		Type:        AccountTypeAPIKey,
+		Status:      StatusActive,
+		Schedulable: true,
+		Concurrency: 5,
+	}
+
+	svc, scheduler := buildConditionalStickyTestService(
+		[]Account{stickyAccount, fallbackAccount},
+		fmt.Sprintf("openai:sticky_m1_%d", groupID),
+		stickyAccount.ID,
+		true, // stickyReleaseEnabled
+		true, // cbEnabled
+	)
+
+	// Trigger error-rate release. With CB also enabled and threshold=3,
+	// the CB will be OPEN after 3 failures via stats.report (which uses
+	// the default CB threshold of 5). Send enough to ensure both are
+	// triggered.
+	for i := 0; i < 10; i++ {
+		scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+	}
+
+	_, _, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "",
+		fmt.Sprintf("sticky_m1_%d", groupID),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+
+	snap := scheduler.SnapshotMetrics()
+	// At least one of the two release metrics should have been incremented.
+	totalReleases := snap.StickyReleaseErrorTotal + snap.StickyReleaseCircuitOpenTotal
+	require.Greater(t, totalReleases, int64(0),
+		"at least one sticky release metric should be incremented")
+
+	// --- Circuit breaker release metric (clean setup) ---
+	groupID2 := int64(30006)
+	svc2, scheduler2 := buildConditionalStickyTestService(
+		[]Account{stickyAccount, fallbackAccount},
+		fmt.Sprintf("openai:sticky_m2_%d", groupID2),
+		stickyAccount.ID,
+		true, // stickyReleaseEnabled
+		true, // cbEnabled
+	)
+
+	// Trip CB via ReportResult (which checks the configured threshold=3).
+	for i := 0; i < 5; i++ {
+		scheduler2.ReportResult(stickyAccount.ID, false, nil, "", 0)
+	}
+	require.True(t, scheduler2.stats.isCircuitOpen(stickyAccount.ID))
+
+	_, _, err = svc2.SelectAccountWithScheduler(
+		ctx, &groupID2, "",
+		fmt.Sprintf("sticky_m2_%d", groupID2),
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+
+	snap2 := scheduler2.SnapshotMetrics()
+	require.Greater(t, snap2.StickyReleaseCircuitOpenTotal, int64(0),
+		"circuit-open sticky release metric should be incremented")
+}
+
+// ---------------------------------------------------------------------------
+// Softmax Temperature Sampling Tests
+// ---------------------------------------------------------------------------
+
+// makeSoftmaxCandidates builds one candidate per score with sequential
+// account IDs starting at 1.
+func makeSoftmaxCandidates(scores ...float64) []openAIAccountCandidateScore {
+	candidates := make([]openAIAccountCandidateScore, 0, len(scores))
+	for idx, score := range scores {
+		id := int64(idx + 1)
+		candidates = append(candidates, openAIAccountCandidateScore{
+			account:  &Account{ID: id, Priority: 0},
+			loadInfo: &AccountLoadInfo{AccountID: id},
+			score:    score,
+		})
+	}
+	return candidates
+}
+
+func TestSoftmax_LowTemperatureApproximatesArgmax(t *testing.T) {
+ // With a very low temperature the highest-scored candidate should win
+ // almost every time.
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5)
+
+ winCount := 0
+ trials := 100
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 0.01, &rng)
+ require.Len(t, result, len(candidates))
+ if result[0].account.ID == 1 { // ID 1 has score 5.0 (the highest)
+ winCount++
+ }
+ }
+
+ require.Greater(t, winCount, 90,
+ "with temperature=0.01 the highest-scored candidate should win >90%% of trials; got %d/%d", winCount, trials)
+}
+
+func TestSoftmax_HighTemperatureApproximatesUniform(t *testing.T) {
+ // With a very high temperature, all candidates should get roughly equal
+ // selection frequency.
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 100.0, &rng)
+ require.Len(t, result, len(candidates))
+ counts[result[0].account.ID]++
+ }
+
+ expected := float64(trials) / float64(len(candidates)) // 250
+ for id, count := range counts {
+ require.InDelta(t, expected, float64(count), float64(trials)*0.10,
+ "candidate ID=%d expected ~%.0f selections, got %d", id, expected, count)
+ }
+}
+
+func TestSoftmax_DefaultTemperature(t *testing.T) {
+ // With the default temperature (0.3), higher-scored candidates should be
+ // picked more often than lower-scored ones.
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, defaultSoftmaxTemperature, &rng)
+ counts[result[0].account.ID]++
+ }
+
+ // The candidate with the highest score (ID=1, score=5.0) should be
+ // selected more often than the candidate with the lowest (ID=4, score=0.5).
+ require.Greater(t, counts[int64(1)], counts[int64(4)],
+ "highest-scored candidate should be picked more often; best=%d worst=%d",
+ counts[int64(1)], counts[int64(4)])
+
+ // Also check that the top-scored candidate beats the second-highest.
+ require.Greater(t, counts[int64(1)], counts[int64(2)],
+ "score=5.0 should beat score=3.0; got %d vs %d",
+ counts[int64(1)], counts[int64(2)])
+}
+
+// TestSoftmax_SingleCandidate verifies the degenerate one-candidate case:
+// the sole candidate is returned with its score intact.
+func TestSoftmax_SingleCandidate(t *testing.T) {
+	pool := makeSoftmaxCandidates(7.5)
+	rng := newOpenAISelectionRNG(42)
+
+	got := selectSoftmaxOpenAICandidates(pool, 0.3, &rng)
+
+	require.Len(t, got, 1)
+	require.Equal(t, int64(1), got[0].account.ID)
+	require.Equal(t, 7.5, got[0].score)
+}
+
+func TestSoftmax_TwoCandidates(t *testing.T) {
+ // Use a moderate score gap (1.0 vs 0.5) with temperature=1.0 so both
+ // candidates have meaningful selection probability.
+ candidates := makeSoftmaxCandidates(1.0, 0.5)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 1.0, &rng)
+ require.Len(t, result, 2)
+ counts[result[0].account.ID]++
+ }
+
+ // Both candidates should be selected at least once (proving no
+ // single-candidate monopoly), and the higher-scored one should dominate.
+ require.Greater(t, counts[int64(1)], 0, "high-scored candidate must be selected at least once")
+ require.Greater(t, counts[int64(2)], 0, "low-scored candidate must be selected at least once")
+ require.Greater(t, counts[int64(1)], counts[int64(2)],
+ "higher-scored candidate should be picked more often; got %d vs %d",
+ counts[int64(1)], counts[int64(2)])
+}
+
+func TestSoftmax_EqualScores(t *testing.T) {
+ // When all scores are equal, selection should be approximately uniform.
+ candidates := makeSoftmaxCandidates(3.0, 3.0, 3.0, 3.0)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng)
+ counts[result[0].account.ID]++
+ }
+
+ expected := float64(trials) / float64(len(candidates)) // 250
+ for id, count := range counts {
+ require.InDelta(t, expected, float64(count), float64(trials)*0.10,
+ "equal scores should yield ~uniform distribution; ID=%d expected ~%.0f got %d",
+ id, expected, count)
+ }
+}
+
+// TestSoftmax_NumericalStability verifies that extreme score gaps neither
+// produce NaN/Inf nor drop/duplicate candidates, and that the max scorer
+// effectively always wins at low temperature.
+func TestSoftmax_NumericalStability(t *testing.T) {
+	// Large score differences should not cause overflow or NaN.
+	candidates := makeSoftmaxCandidates(100.0, -100.0, 50.0, -50.0)
+
+	rng := newOpenAISelectionRNG(12345)
+	result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng)
+
+	require.Len(t, result, len(candidates))
+	// Verify all scores are finite in the output (no NaN or Inf propagation).
+	for _, c := range result {
+		require.False(t, math.IsNaN(c.score), "score should not be NaN")
+		require.False(t, math.IsInf(c.score, 0), "score should not be Inf")
+	}
+	// All candidates must appear exactly once.
+	seen := map[int64]bool{}
+	for _, c := range result {
+		require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID)
+		seen[c.account.ID] = true
+	}
+	require.Len(t, seen, len(candidates))
+
+	// With such extreme differences at temperature=0.3, the highest scorer (100.0)
+	// should always win because exp((100 - 100)/0.3) = 1 while
+	// exp((-100 - 100)/0.3) ~= 0 (numerically stable via maxScore subtraction).
+	winCount := 0
+	for i := 0; i < 100; i++ {
+		rng2 := newOpenAISelectionRNG(uint64(i + 1))
+		r := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng2)
+		if r[0].account.ID == 1 { // score 100.0
+			winCount++
+		}
+	}
+	require.Greater(t, winCount, 95,
+		"with extreme score gap, the highest scorer should win nearly always; got %d/100", winCount)
+}
+
+// TestSoftmax_DisabledFallsThrough verifies that with softmax (and P2C)
+// disabled, the scheduler uses the Top-K path, observable via TopK > 0.
+func TestSoftmax_DisabledFallsThrough(t *testing.T) {
+	// When softmax is disabled, the scheduler should fall through to P2C or Top-K.
+	ctx := context.Background()
+	groupID := int64(40001)
+	accounts := make([]Account, 5)
+	for i := range accounts {
+		accounts[i] = Account{
+			ID:          int64(6001 + i),
+			Platform:    PlatformOpenAI,
+			Type:        AccountTypeAPIKey,
+			Status:      StatusActive,
+			Schedulable: true,
+			Concurrency: 5,
+		}
+	}
+
+	cache := &stubGatewayCache{}
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	// Softmax explicitly disabled (default).
+	cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = false
+	// P2C also disabled.
+	cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
+		cache:              cache,
+		cfg:                cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+	}
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "", "softmax_disabled_test",
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+	// TopK should be > 0, confirming Top-K path was taken.
+	require.Greater(t, decision.TopK, 0, "should fall through to Top-K when softmax is disabled")
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+}
+
+// TestSoftmax_FewCandidatesFallsThrough verifies the softmax candidate-count
+// guard: with softmax enabled but only 3 candidates (guard requires >3) the
+// scheduler falls through to Top-K.
+func TestSoftmax_FewCandidatesFallsThrough(t *testing.T) {
+	// When softmax is enabled but there are <= 3 candidates, it should fall
+	// through to the next strategy (P2C or Top-K).
+	ctx := context.Background()
+	groupID := int64(40002)
+	// Only 3 accounts — softmax guard requires >3.
+	accounts := make([]Account, 3)
+	for i := range accounts {
+		accounts[i] = Account{
+			ID:          int64(7001 + i),
+			Platform:    PlatformOpenAI,
+			Type:        AccountTypeAPIKey,
+			Status:      StatusActive,
+			Schedulable: true,
+			Concurrency: 5,
+		}
+	}
+
+	cache := &stubGatewayCache{}
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	// Softmax enabled but should not activate with only 3 candidates.
+	cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.5
+	// P2C disabled, so it should fall through to Top-K.
+	cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false
+
+	svc := &OpenAIGatewayService{
+		accountRepo:        stubOpenAIAccountRepo{accounts: accounts},
+		cache:              cache,
+		cfg:                cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+	}
+
+	selection, decision, err := svc.SelectAccountWithScheduler(
+		ctx, &groupID, "", "softmax_few_candidates_test",
+		"gpt-5.1", nil, OpenAIUpstreamTransportAny,
+	)
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	require.NotNil(t, selection.Account)
+	require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+	// TopK should be > 0, confirming Top-K path was taken instead of softmax.
+	require.Greater(t, decision.TopK, 0, "should fall through to Top-K when candidate count <= 3")
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+}
+
+// TestSoftmax_ConfigDefaults verifies softmaxConfigRead's defaulting rules:
+// a nil service yields a zero config; zero or negative temperature falls
+// back to the 0.3 default; an explicit positive value is used as-is.
+func TestSoftmax_ConfigDefaults(t *testing.T) {
+	// When config values are zero/unset, defaults should be applied.
+
+	// Case 1: nil service — returns empty config.
+	nilScheduler := &defaultOpenAIAccountScheduler{}
+	cfg0 := nilScheduler.softmaxConfigRead()
+	require.False(t, cfg0.enabled)
+	require.Equal(t, 0.0, cfg0.temperature) // no default when service is nil
+
+	// Case 2: zero temperature (unset) — should default to 0.3.
+	svc := &OpenAIGatewayService{
+		cfg: &config.Config{},
+	}
+	svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true
+	svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0 // unset
+	scheduler := &defaultOpenAIAccountScheduler{
+		service: svc,
+		stats:   newOpenAIAccountRuntimeStats(),
+	}
+	cfg1 := scheduler.softmaxConfigRead()
+	require.True(t, cfg1.enabled)
+	require.Equal(t, 0.3, cfg1.temperature, "default temperature should be 0.3")
+
+	// Case 3: explicit temperature — should use the configured value.
+	svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.7
+	cfg2 := scheduler.softmaxConfigRead()
+	require.True(t, cfg2.enabled)
+	require.Equal(t, 0.7, cfg2.temperature, "should use explicitly configured temperature")
+
+	// Case 4: negative temperature — should fall back to default 0.3.
+	svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = -1.0
+	cfg3 := scheduler.softmaxConfigRead()
+	require.Equal(t, 0.3, cfg3.temperature, "negative temperature should fall back to default 0.3")
+}
+
+// ---------------------------------------------------------------------------
+// Per-Model TTFT Tests
+// ---------------------------------------------------------------------------
+
+// TestPerModelTTFT_IndependentTracking verifies that TTFT is tracked
+// per (account, model) pair: two models with disjoint sample ranges
+// must end up with clearly different snapshot values.
+func TestPerModelTTFT_IndependentTracking(t *testing.T) {
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(8001)
+
+	// Report different TTFT values for two different models on the same account.
+	stats.report(accountID, true, nil, "gpt-4o", 100)
+	stats.report(accountID, true, nil, "gpt-4o", 120)
+	stats.report(accountID, true, nil, "o3-pro", 500)
+	stats.report(accountID, true, nil, "o3-pro", 600)
+
+	// Snapshot for model gpt-4o.
+	_, ttftGPT4o, hasTTFT := stats.snapshot(accountID, "gpt-4o")
+	require.True(t, hasTTFT, "gpt-4o should have TTFT data")
+
+	// Snapshot for model o3-pro.
+	_, ttftO3Pro, hasO3Pro := stats.snapshot(accountID, "o3-pro")
+	require.True(t, hasO3Pro, "o3-pro should have TTFT data")
+
+	// The two models should have different TTFT values because their
+	// sample inputs are very different (100-120 vs 500-600).
+	require.Greater(t, math.Abs(ttftGPT4o-ttftO3Pro), 50.0,
+		"different models should track independent TTFT values")
+
+	// gpt-4o TTFT should be much lower than o3-pro.
+	require.Less(t, ttftGPT4o, ttftO3Pro,
+		"gpt-4o should have lower TTFT than o3-pro")
+}
+
+// TestPerModelTTFT_FallbackToGlobal verifies that requesting TTFT for a
+// model with no samples falls back to the account-wide (global) value.
+func TestPerModelTTFT_FallbackToGlobal(t *testing.T) {
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(8002)
+
+	// One sample for a known model; this also feeds the global EWMA.
+	stats.report(accountID, true, nil, "gpt-4o", 200)
+
+	// An unknown model has no per-model entry and must fall back.
+	_, unknownTTFT, ok := stats.snapshot(accountID, "unknown-model")
+	require.True(t, ok, "should fall back to global TTFT")
+
+	_, globalTTFT, okGlobal := stats.snapshot(accountID)
+	require.True(t, okGlobal, "global TTFT should exist")
+
+	// The fallback value must match the global one exactly.
+	require.InDelta(t, globalTTFT, unknownTTFT, 1e-9,
+		"unknown model should return global TTFT as fallback")
+}
+
+// TestPerModelTTFT_GlobalAlsoUpdated verifies that a per-model TTFT report
+// also updates the account-wide (global) TTFT.
+func TestPerModelTTFT_GlobalAlsoUpdated(t *testing.T) {
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(8003)
+
+	// Before any report there must be no global TTFT.
+	if _, _, has := stats.snapshot(accountID); has {
+		t.Fatal("no global TTFT expected initially")
+	}
+
+	// A model-tagged report must feed the global EWMA as well.
+	stats.report(accountID, true, nil, "gpt-4o", 300)
+
+	_, globalTTFT, has := stats.snapshot(accountID)
+	require.True(t, has, "global TTFT should exist after model report")
+	require.InDelta(t, 300.0, globalTTFT, 1e-9,
+		"global TTFT should be updated by model report")
+}
+
+// TestPerModelTTFT_TTLCleanup verifies that cleanupStaleTTFT removes model
+// entries whose last update is older than the TTL while keeping fresh ones.
+func TestPerModelTTFT_TTLCleanup(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	stat.ttft.initNaN()
+
+	// Manually insert a model entry with an old timestamp.
+	d := &dualEWMATTFT{}
+	d.initNaN()
+	d.update(100)
+	stat.modelTTFT.Store("old-model", d)
+	stat.modelTTFTLastUpdate.Store("old-model", time.Now().Add(-time.Hour).UnixNano())
+
+	// Insert a recent model entry.
+	d2 := &dualEWMATTFT{}
+	d2.initNaN()
+	d2.update(200)
+	stat.modelTTFT.Store("new-model", d2)
+	stat.modelTTFTLastUpdate.Store("new-model", time.Now().UnixNano())
+
+	// Cleanup with 30-minute TTL — old-model should be removed.
+	stat.cleanupStaleTTFT(30*time.Minute, 100)
+
+	_, hasOld := stat.modelTTFTValue("old-model")
+	require.False(t, hasOld, "old-model should be cleaned up")
+
+	_, hasNew := stat.modelTTFTValue("new-model")
+	require.True(t, hasNew, "new-model should survive cleanup")
+}
+
+// TestPerModelTTFT_MaxModelLimit verifies that cleanupStaleTTFT enforces the
+// max-models cap by evicting the entries with the oldest update timestamps.
+func TestPerModelTTFT_MaxModelLimit(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	stat.ttft.initNaN()
+
+	now := time.Now()
+	// Insert 10 models with sequential timestamps.
+	for i := 0; i < 10; i++ {
+		model := fmt.Sprintf("model-%d", i)
+		d := &dualEWMATTFT{}
+		d.initNaN()
+		d.update(float64(100 + i*10))
+		stat.modelTTFT.Store(model, d)
+		stat.modelTTFTLastUpdate.Store(model, now.Add(time.Duration(i)*time.Second).UnixNano())
+	}
+
+	// Enforce limit of 5 models — the 5 oldest should be evicted.
+	stat.cleanupStaleTTFT(time.Hour, 5)
+
+	// Count remaining models.
+	remaining := 0
+	stat.modelTTFT.Range(func(_, _ any) bool {
+		remaining++
+		return true
+	})
+	require.Equal(t, 5, remaining, "should have exactly 5 models after cleanup")
+
+	// The newest 5 (model-5 through model-9) should survive.
+	for i := 5; i < 10; i++ {
+		model := fmt.Sprintf("model-%d", i)
+		_, has := stat.modelTTFTValue(model)
+		require.True(t, has, "%s should survive", model)
+	}
+	// The oldest 5 (model-0 through model-4) should be evicted.
+	for i := 0; i < 5; i++ {
+		model := fmt.Sprintf("model-%d", i)
+		_, has := stat.modelTTFTValue(model)
+		require.False(t, has, "%s should be evicted", model)
+	}
+}
+
+// TestPerModelTTFT_SnapshotUsesModelData verifies that a model-qualified
+// snapshot returns that model's own TTFT while the global value blends
+// the samples of all models.
+func TestPerModelTTFT_SnapshotUsesModelData(t *testing.T) {
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(8006)
+
+	// Report two models with very different TTFT.
+	for i := 0; i < 5; i++ {
+		stats.report(accountID, true, nil, "fast-model", 50)
+		stats.report(accountID, true, nil, "slow-model", 500)
+	}
+
+	// Snapshot with specific model should return that model's TTFT.
+	_, ttftFast, hasFast := stats.snapshot(accountID, "fast-model")
+	require.True(t, hasFast)
+
+	_, ttftSlow, hasSlow := stats.snapshot(accountID, "slow-model")
+	require.True(t, hasSlow)
+
+	// Fast model should have much lower TTFT.
+	require.Less(t, ttftFast, 100.0, "fast-model TTFT should be close to 50")
+	require.Greater(t, ttftSlow, 400.0, "slow-model TTFT should be close to 500")
+
+	// Global TTFT should be a blend of both.
+	_, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+	require.True(t, hasGlobal)
+	require.Greater(t, ttftGlobal, ttftFast, "global TTFT should be higher than fast-model")
+	require.Less(t, ttftGlobal, ttftSlow, "global TTFT should be lower than slow-model")
+}
+
+// TestPerModelTTFT_ConcurrentAccess hammers report and snapshot from
+// several goroutines to exercise the stats' concurrency safety (meant to
+// be run under -race) and then checks every model accumulated TTFT data.
+func TestPerModelTTFT_ConcurrentAccess(t *testing.T) {
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(8007)
+
+	const workers = 8
+	const iterations = 200
+	models := []string{"model-a", "model-b", "model-c", "model-d"}
+
+	var wg sync.WaitGroup
+	wg.Add(workers)
+	for w := 0; w < workers; w++ {
+		w := w
+		go func() {
+			defer wg.Done()
+			for i := 0; i < iterations; i++ {
+				model := models[(w+i)%len(models)]
+				ttft := float64(100 + (w*10+i)%200)
+				stats.report(accountID, true, nil, model, ttft)
+
+				// Also read concurrently.
+				stats.snapshot(accountID, model)
+				stats.snapshot(accountID)
+			}
+		}()
+	}
+	wg.Wait()
+
+	// All models should have TTFT data.
+	for _, model := range models {
+		_, ttft, has := stats.snapshot(accountID, model)
+		require.True(t, has, "%s should have TTFT", model)
+		require.Greater(t, ttft, 0.0, "%s TTFT should be positive", model)
+	}
+
+	// Global should also have data.
+	_, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+	require.True(t, hasGlobal)
+	require.Greater(t, ttftGlobal, 0.0)
+}
+
+func TestPerModelTTFT_EmptyModel(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8008)
+
+ // Report with empty model — should only update global TTFT.
+ ttft := 150
+ stats.report(accountID, true, &ttft, "", 0)
+
+ // Global should have data.
+ _, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+ require.True(t, hasGlobal, "global TTFT should exist from firstTokenMs")
+ require.InDelta(t, 150.0, ttftGlobal, 1e-9)
+
+ // Snapshot with empty model returns global.
+ _, ttftEmpty, hasEmpty := stats.snapshot(accountID, "")
+ require.True(t, hasEmpty)
+ require.InDelta(t, ttftGlobal, ttftEmpty, 1e-9,
+ "empty model snapshot should return global TTFT")
+
+ // No per-model entries should exist.
+ stat := stats.loadOrCreate(accountID)
+ count := 0
+ stat.modelTTFT.Range(func(_, _ any) bool {
+ count++
+ return true
+ })
+ require.Equal(t, 0, count, "no per-model entries should exist for empty model")
+}
+
+// ---------------------------------------------------------------------------
+// Load Trend Prediction Tests
+// ---------------------------------------------------------------------------
+
+// TestLoadTrend_RisingLoad: load increasing 10 per second should give slope ~+10/s.
+func TestLoadTrend_RisingLoad(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	for i := 0; i < 10; i++ {
+		tracker.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second))
+	}
+	slope := tracker.slope()
+	require.Greater(t, slope, 0.0, "rising load should produce positive slope")
+	require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second")
+}
+
+// TestLoadTrend_FallingLoad: load decreasing 10 per second should give slope ~-10/s.
+func TestLoadTrend_FallingLoad(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	for i := 0; i < 10; i++ {
+		tracker.recordAt(float64(100-i*10), base+int64(i)*int64(time.Second))
+	}
+	slope := tracker.slope()
+	require.Less(t, slope, 0.0, "falling load should produce negative slope")
+	require.InDelta(t, -10.0, slope, 0.01, "slope should be ~-10 per second")
+}
+
+// TestLoadTrend_StableLoad: constant load should give a zero slope.
+func TestLoadTrend_StableLoad(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	for i := 0; i < 10; i++ {
+		tracker.recordAt(50.0, base+int64(i)*int64(time.Second))
+	}
+	slope := tracker.slope()
+	require.InDelta(t, 0.0, slope, 1e-9, "constant load should produce zero slope")
+}
+
+// TestLoadTrend_RingBufferFull overfills the ring so the oldest (flat) samples
+// are evicted; the slope must reflect only the newest rising samples.
+func TestLoadTrend_RingBufferFull(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	for i := 0; i < 5; i++ {
+		tracker.recordAt(100.0, base+int64(i)*int64(time.Second))
+	}
+	for i := 0; i < 10; i++ {
+		tracker.recordAt(float64((i+1)*10), base+int64(5+i)*int64(time.Second))
+	}
+	slope := tracker.slope()
+	require.Greater(t, slope, 0.0, "should reflect rising trend from last 10 samples")
+	require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second after ring wraps")
+}
+
+// TestLoadTrend_InsufficientSamples: a tracker with no samples reports slope 0.
+func TestLoadTrend_InsufficientSamples(t *testing.T) {
+	var tracker loadTrendTracker
+	slope := tracker.slope()
+	require.Equal(t, 0.0, slope, "zero samples should return slope 0")
+}
+
+// TestLoadTrend_SingleSample: one sample is not enough to fit a line.
+func TestLoadTrend_SingleSample(t *testing.T) {
+	var tracker loadTrendTracker
+	tracker.recordAt(42.0, time.Now().UnixNano())
+	slope := tracker.slope()
+	require.Equal(t, 0.0, slope, "single sample should return slope 0")
+}
+
+// TestLoadTrend_TwoSamples: with exactly two samples the slope is delta/time.
+func TestLoadTrend_TwoSamples(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	tracker.recordAt(10.0, base)
+	tracker.recordAt(30.0, base+int64(2*time.Second))
+	slope := tracker.slope()
+	require.InDelta(t, 10.0, slope, 0.01, "two-sample slope should be exact delta/time")
+}
+
+// TestLoadTrend_AllSameTimestamp: zero time spread is degenerate (would divide
+// by zero in a naive regression) and must yield slope 0.
+func TestLoadTrend_AllSameTimestamp(t *testing.T) {
+	var tracker loadTrendTracker
+	ts := time.Now().UnixNano()
+	for i := 0; i < 5; i++ {
+		tracker.recordAt(float64(i*10), ts)
+	}
+	slope := tracker.slope()
+	require.Equal(t, 0.0, slope, "all-same-timestamp should return slope 0 (degenerate)")
+}
+
+// TestLoadTrend_NegativeSlope: load dropping 5 per second should give slope ~-5/s.
+func TestLoadTrend_NegativeSlope(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	for i := 0; i < 5; i++ {
+		tracker.recordAt(50.0-float64(i)*5.0, base+int64(i)*int64(time.Second))
+	}
+	slope := tracker.slope()
+	require.InDelta(t, -5.0, slope, 0.01)
+}
+
+// TestLoadTrend_ScoringIntegration runs the real load-balance selection with
+// trend scoring enabled and checks that a rising-trend account produces a
+// higher slope (and therefore a lower trend adjustment) than a stable one.
+func TestLoadTrend_ScoringIntegration(t *testing.T) {
+	accounts := []Account{
+		{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+		{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+	}
+
+	// Both accounts carry the same instantaneous load; only the trend differs.
+	loadMap := map[int64]*AccountLoadInfo{
+		1: {AccountID: 1, LoadRate: 50},
+		2: {AccountID: 2, LoadRate: 50},
+	}
+
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+	cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0
+
+	svc := &OpenAIGatewayService{
+		accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+		cfg:         cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{
+			loadMap:         loadMap,
+			skipDefaultLoad: true,
+		}),
+	}
+
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+
+	// Seed samples over the last ~10 seconds: account 1 rising, account 2 flat.
+	base := time.Now().UnixNano() - int64(10*time.Second)
+	stat1 := stats.loadOrCreate(1)
+	for i := 0; i < 9; i++ {
+		stat1.loadTrend.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second))
+	}
+	stat2 := stats.loadOrCreate(2)
+	for i := 0; i < 9; i++ {
+		stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second))
+	}
+
+	ctx := context.Background()
+	selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+		RequiredTransport: OpenAIUpstreamTransportAny,
+	})
+	require.NoError(t, err)
+	require.NotNil(t, selection)
+	if selection.ReleaseFunc != nil {
+		selection.ReleaseFunc()
+	}
+
+	slope1 := stat1.loadTrend.slope()
+	slope2 := stat2.loadTrend.slope()
+	require.Greater(t, slope1, slope2,
+		"rising-trend account should have higher slope than stable account; slope1=%f slope2=%f", slope1, slope2)
+	require.Greater(t, slope1, 0.0, "rising-trend account slope should be positive")
+	require.InDelta(t, 0.0, slope2, 1.0,
+		"stable-trend account slope should be near zero; got %f", slope2)
+
+	// Mirror the scheduler's adjustment formula: higher slope => smaller adj.
+	trendAdj1 := 1.0 - clamp01(slope1/5.0)
+	trendAdj2 := 1.0 - clamp01(slope2/5.0)
+	require.Less(t, trendAdj1, trendAdj2,
+		"rising-trend trendAdj should be less than stable trendAdj; adj1=%f adj2=%f", trendAdj1, trendAdj2)
+}
+
+// TestLoadTrend_DisabledByDefault verifies that, with trend scoring left at its
+// default (disabled), seeded trend data does not starve either account: both
+// still receive a healthy share of selections.
+func TestLoadTrend_DisabledByDefault(t *testing.T) {
+	accounts := []Account{
+		{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+		{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+	}
+
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+	// SchedulerTrendEnabled deliberately NOT set — exercising the default path.
+
+	svc := &OpenAIGatewayService{
+		accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+		cfg:         cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+	}
+
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+
+	// Account 1 has a strong rising trend, account 2 is flat; neither should
+	// matter since trend scoring is disabled.
+	base := time.Now().UnixNano()
+	stat1 := stats.loadOrCreate(1)
+	for i := 0; i < 10; i++ {
+		stat1.loadTrend.recordAt(float64(i*10), base+int64(i)*int64(time.Second))
+	}
+	stat2 := stats.loadOrCreate(2)
+	for i := 0; i < 10; i++ {
+		stat2.loadTrend.recordAt(30.0, base+int64(i)*int64(time.Second))
+	}
+
+	ctx := context.Background()
+	selectedCounts := map[int64]int{}
+	const rounds = 100
+	for r := 0; r < rounds; r++ {
+		selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+			RequiredTransport: OpenAIUpstreamTransportAny,
+		})
+		require.NoError(t, err)
+		require.NotNil(t, selection)
+		require.NotNil(t, selection.Account)
+		selectedCounts[selection.Account.ID]++
+		if selection.ReleaseFunc != nil {
+			selection.ReleaseFunc()
+		}
+	}
+
+	require.Greater(t, selectedCounts[int64(1)], 10,
+		"with trend disabled, account 1 should still get selections; got %d", selectedCounts[int64(1)])
+	require.Greater(t, selectedCounts[int64(2)], 10,
+		"with trend disabled, account 2 should still get selections; got %d", selectedCounts[int64(2)])
+}
+
+// TestLoadTrend_ConcurrentAccess interleaves record() writers with a slope()
+// reader to catch data races (-race); the final slope just has to be finite.
+func TestLoadTrend_ConcurrentAccess(t *testing.T) {
+	var tracker loadTrendTracker
+	var wg sync.WaitGroup
+	const goroutines = 10
+	const recordsPerGoroutine = 100
+
+	for g := 0; g < goroutines; g++ {
+		wg.Add(1)
+		go func(id int) {
+			defer wg.Done()
+			for i := 0; i < recordsPerGoroutine; i++ {
+				tracker.record(float64(id*100 + i))
+			}
+		}(g)
+	}
+
+	// Concurrent reader.
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for i := 0; i < goroutines*recordsPerGoroutine; i++ {
+			_ = tracker.slope()
+		}
+	}()
+
+	wg.Wait()
+
+	slope := tracker.slope()
+	require.False(t, math.IsNaN(slope), "slope should be finite after concurrent access")
+	require.False(t, math.IsInf(slope, 0), "slope should be finite after concurrent access")
+}
+
+// TestLoadTrend_TrendConfigDefaults: a service without config reports trend
+// disabled and the default max slope.
+func TestLoadTrend_TrendConfigDefaults(t *testing.T) {
+	svc := &OpenAIGatewayService{}
+	enabled, maxSlope := svc.openAIWSSchedulerTrendConfig()
+	require.False(t, enabled, "trend should be disabled by default")
+	require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope, "maxSlope should use default")
+}
+
+// TestLoadTrend_TrendConfigCustom: explicit config values are passed through.
+func TestLoadTrend_TrendConfigCustom(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 8.0
+	svc := &OpenAIGatewayService{cfg: cfg}
+	enabled, maxSlope := svc.openAIWSSchedulerTrendConfig()
+	require.True(t, enabled)
+	require.Equal(t, 8.0, maxSlope)
+}
+
+// TestLoadTrend_TrendConfigZeroMaxSlope: a zero max slope falls back to the
+// default rather than producing a divide-by-zero adjustment.
+func TestLoadTrend_TrendConfigZeroMaxSlope(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 0
+	svc := &OpenAIGatewayService{cfg: cfg}
+	enabled, maxSlope := svc.openAIWSSchedulerTrendConfig()
+	require.True(t, enabled)
+	require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope)
+}
+
+// TestLoadTrend_RingBufferCountTracking: count grows to the ring size and then
+// stays capped when older samples are overwritten.
+func TestLoadTrend_RingBufferCountTracking(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+
+	for i := 0; i < loadTrendRingSize; i++ {
+		tracker.recordAt(float64(i), base+int64(i)*int64(time.Second))
+	}
+	require.Equal(t, loadTrendRingSize, tracker.count, "count should equal ring size after filling")
+
+	tracker.recordAt(99.0, base+int64(loadTrendRingSize)*int64(time.Second))
+	require.Equal(t, loadTrendRingSize, tracker.count, "count should remain capped at ring size")
+}
+
+// TestLoadTrend_GentleRise: a very small rise (0.1/s) must still be detected.
+func TestLoadTrend_GentleRise(t *testing.T) {
+	var tracker loadTrendTracker
+	base := time.Now().UnixNano()
+	for i := 0; i < 10; i++ {
+		tracker.recordAt(50.0+float64(i)*0.1, base+int64(i)*int64(time.Second))
+	}
+	slope := tracker.slope()
+	require.Greater(t, slope, 0.0, "gentle rise should produce positive slope")
+	require.InDelta(t, 0.1, slope, 0.01)
+}
+
+// TestLoadTrend_RecordUpdatesRuntimeStat verifies that the scoring loop itself
+// feeds load samples into the per-account trend tracker when trend is enabled.
+func TestLoadTrend_RecordUpdatesRuntimeStat(t *testing.T) {
+	accounts := []Account{
+		{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+	}
+
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+	cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+
+	svc := &OpenAIGatewayService{
+		accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+		cfg:         cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+	}
+
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+
+	// Each selection round should record one load sample for the account.
+	ctx := context.Background()
+	for i := 0; i < 5; i++ {
+		selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+			RequiredTransport: OpenAIUpstreamTransportAny,
+		})
+		require.NoError(t, err)
+		if selection != nil && selection.ReleaseFunc != nil {
+			selection.ReleaseFunc()
+		}
+	}
+
+	stat := stats.loadOrCreate(1)
+	require.GreaterOrEqual(t, stat.loadTrend.count, 5,
+		"trend tracker should have received samples from scoring loop")
+}
+
+// TestLoadTrend_FallingTrendBoostsScore seeds account 1 with a falling load
+// trend and account 2 with a flat one, runs many selection rounds with trend
+// scoring enabled, and verifies the boosted (falling-trend) account actually
+// receives selections. The previous final assertion (sum > 0) was vacuous:
+// every round increments a count after require.NotNil, so it could never fail.
+func TestLoadTrend_FallingTrendBoostsScore(t *testing.T) {
+	accounts := []Account{
+		{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+		{ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+	}
+
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.Enabled = true
+	cfg.Gateway.OpenAIWS.OAuthEnabled = true
+	cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+	cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+	cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+	cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0
+
+	// Equal instantaneous load so only the trend differentiates the accounts.
+	loadMap := map[int64]*AccountLoadInfo{
+		1: {AccountID: 1, LoadRate: 50},
+		2: {AccountID: 2, LoadRate: 50},
+	}
+
+	svc := &OpenAIGatewayService{
+		accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+		cfg:         cfg,
+		concurrencyService: NewConcurrencyService(stubConcurrencyCache{
+			loadMap:         loadMap,
+			skipDefaultLoad: true,
+		}),
+	}
+
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+
+	// Account 1: falling 5/s; account 2: flat.
+	base := time.Now().UnixNano()
+	stat1 := stats.loadOrCreate(1)
+	for i := 0; i < 10; i++ {
+		stat1.loadTrend.recordAt(80.0-float64(i)*5.0, base+int64(i)*int64(time.Second))
+	}
+	stat2 := stats.loadOrCreate(2)
+	for i := 0; i < 10; i++ {
+		stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second))
+	}
+
+	ctx := context.Background()
+	selectedCounts := map[int64]int{}
+	const rounds = 50
+	for r := 0; r < rounds; r++ {
+		selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+			RequiredTransport: OpenAIUpstreamTransportAny,
+		})
+		require.NoError(t, err)
+		require.NotNil(t, selection)
+		require.NotNil(t, selection.Account)
+		selectedCounts[selection.Account.ID]++
+		if selection.ReleaseFunc != nil {
+			selection.ReleaseFunc()
+		}
+	}
+
+	// Every round must have picked one of the two accounts.
+	require.Equal(t, rounds, selectedCounts[int64(1)]+selectedCounts[int64(2)],
+		"all selections should go to the two configured accounts")
+	// The falling-trend account is score-boosted and must win at least once
+	// over 50 rounds (distributional ordering is not asserted to avoid flakes).
+	require.Greater(t, selectedCounts[int64(1)], 0,
+		"falling-trend account should receive selections; counts=%v", selectedCounts)
+}
+
+// ---------------------------------------------------------------------------
+// Circuit Breaker Coverage Tests
+// ---------------------------------------------------------------------------
+
+// TestCircuitBreaker_AllowClosed: the zero-value breaker is CLOSED and allows.
+func TestCircuitBreaker_AllowClosed(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	require.True(t, cb.allow(30*time.Second, 2), "CLOSED state should allow requests")
+}
+
+// TestCircuitBreaker_AllowOpenWithinCooldown: OPEN denies while cooling down.
+func TestCircuitBreaker_AllowOpenWithinCooldown(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateOpen)
+	cb.lastFailureNano.Store(time.Now().UnixNano())
+	require.False(t, cb.allow(30*time.Second, 2), "OPEN within cooldown should deny")
+}
+
+// TestCircuitBreaker_AllowOpenAfterCooldown: after cooldown elapses, allow()
+// transitions OPEN -> HALF_OPEN and admits the probe.
+func TestCircuitBreaker_AllowOpenAfterCooldown(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateOpen)
+	cb.lastFailureNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
+	require.True(t, cb.allow(30*time.Second, 2), "OPEN after cooldown should transition to HALF_OPEN and allow")
+	require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load())
+}
+
+// TestCircuitBreaker_AllowHalfOpenLimited: HALF_OPEN admits at most max probes.
+func TestCircuitBreaker_AllowHalfOpenLimited(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	require.True(t, cb.allowHalfOpen(2))
+	require.True(t, cb.allowHalfOpen(2))
+	require.False(t, cb.allowHalfOpen(2), "should deny when in-flight reaches max")
+}
+
+// TestCircuitBreaker_AllowHalfOpenViaAllow: the generic allow() honors the
+// HALF_OPEN in-flight limit too.
+func TestCircuitBreaker_AllowHalfOpenViaAllow(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	require.True(t, cb.allow(30*time.Second, 1))
+	require.False(t, cb.allow(30*time.Second, 1), "HALF_OPEN with max=1 should deny second")
+}
+
+// TestCircuitBreaker_AllowDefaultState: an unrecognized state value fails open.
+func TestCircuitBreaker_AllowDefaultState(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(99) // unknown state
+	require.True(t, cb.allow(30*time.Second, 2), "unknown state should default to allow")
+}
+
+// TestCircuitBreaker_RecordSuccessClosed: success resets the failure streak.
+func TestCircuitBreaker_RecordSuccessClosed(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.consecutiveFails.Store(3)
+	cb.recordSuccess()
+	require.Equal(t, int32(0), cb.consecutiveFails.Load())
+	require.Equal(t, circuitBreakerStateClosed, cb.state.Load())
+}
+
+// TestCircuitBreaker_RecordSuccessHalfOpenToClose: when every admitted probe
+// has succeeded, the breaker closes and releases its in-flight permit.
+func TestCircuitBreaker_RecordSuccessHalfOpenToClose(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	cb.halfOpenInFlight.Store(1)
+	cb.halfOpenAdmitted.Store(1)
+	cb.halfOpenSuccess.Store(0)
+	cb.recordSuccess()
+	require.Equal(t, circuitBreakerStateClosed, cb.state.Load(), "all probes succeeded should close")
+	require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+// TestCircuitBreaker_RecordSuccessHalfOpenPartial: with probes still pending,
+// a single success keeps the breaker in HALF_OPEN.
+func TestCircuitBreaker_RecordSuccessHalfOpenPartial(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	cb.halfOpenInFlight.Store(3)
+	cb.halfOpenAdmitted.Store(3)
+	cb.halfOpenSuccess.Store(0)
+	cb.recordSuccess()
+	require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load(), "not all probes succeeded yet")
+}
+
+// TestCircuitBreaker_RecordFailureTripsOpen: the breaker stays CLOSED below
+// the threshold and trips to OPEN exactly on the threshold-th failure.
+func TestCircuitBreaker_RecordFailureTripsOpen(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	for i := 0; i < 4; i++ {
+		cb.recordFailure(5)
+	}
+	require.Equal(t, circuitBreakerStateClosed, cb.state.Load())
+	cb.recordFailure(5)
+	require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "5th failure should trip to OPEN")
+}
+
+// TestCircuitBreaker_RecordFailureHalfOpenReverts: any failed probe sends
+// HALF_OPEN straight back to OPEN and clears in-flight permits.
+func TestCircuitBreaker_RecordFailureHalfOpenReverts(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	cb.halfOpenInFlight.Store(2)
+	cb.recordFailure(5)
+	require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "failure in HALF_OPEN should revert to OPEN")
+	require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+// TestCircuitBreaker_IsHalfOpen covers the nil receiver and both states.
+func TestCircuitBreaker_IsHalfOpen(t *testing.T) {
+	var cb *accountCircuitBreaker
+	require.False(t, cb.isHalfOpen(), "nil should return false")
+
+	cb = &accountCircuitBreaker{}
+	require.False(t, cb.isHalfOpen())
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	require.True(t, cb.isHalfOpen())
+}
+
+// TestCircuitBreaker_ReleaseHalfOpenPermit covers nil receiver, wrong state,
+// normal decrement, and the floor at zero.
+func TestCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) {
+	var cb *accountCircuitBreaker
+	cb.releaseHalfOpenPermit() // should not panic
+
+	cb = &accountCircuitBreaker{}
+	cb.releaseHalfOpenPermit() // not in HALF_OPEN, should be no-op
+
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	cb.halfOpenInFlight.Store(2)
+	cb.releaseHalfOpenPermit()
+	require.Equal(t, int32(1), cb.halfOpenInFlight.Load())
+
+	cb.halfOpenInFlight.Store(0)
+	cb.releaseHalfOpenPermit() // already at 0, should be no-op
+	require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+// TestCircuitBreaker_StateString maps every state value to its label.
+func TestCircuitBreaker_StateString(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	require.Equal(t, "CLOSED", cb.stateString())
+	cb.state.Store(circuitBreakerStateOpen)
+	require.Equal(t, "OPEN", cb.stateString())
+	cb.state.Store(circuitBreakerStateHalfOpen)
+	require.Equal(t, "HALF_OPEN", cb.stateString())
+	cb.state.Store(99)
+	require.Equal(t, "UNKNOWN", cb.stateString())
+}
+
+// TestCircuitBreaker_IsOpen: isOpen reflects only the OPEN state.
+func TestCircuitBreaker_IsOpen(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	require.False(t, cb.isOpen())
+	cb.state.Store(circuitBreakerStateOpen)
+	require.True(t, cb.isOpen())
+}
+
+// TestCircuitBreaker_FullLifecycle walks CLOSED -> OPEN -> HALF_OPEN -> CLOSED
+// end to end, using a short real cooldown.
+func TestCircuitBreaker_FullLifecycle(t *testing.T) {
+	cb := &accountCircuitBreaker{}
+	threshold := 3
+	cooldown := 50 * time.Millisecond
+
+	// CLOSED: allow requests
+	require.True(t, cb.allow(cooldown, 2))
+	require.Equal(t, "CLOSED", cb.stateString())
+
+	// Trip to OPEN
+	for i := 0; i < threshold; i++ {
+		cb.recordFailure(threshold)
+	}
+	require.Equal(t, "OPEN", cb.stateString())
+	require.False(t, cb.allow(cooldown, 2), "should deny in OPEN within cooldown")
+
+	// Wait for cooldown
+	time.Sleep(cooldown + 10*time.Millisecond)
+
+	// Should transition to HALF_OPEN
+	require.True(t, cb.allow(cooldown, 2))
+	require.Equal(t, "HALF_OPEN", cb.stateString())
+
+	// Success should close
+	cb.recordSuccess()
+	require.Equal(t, "CLOSED", cb.stateString())
+}
+
+// ---------------------------------------------------------------------------
+// dualEWMATTFT Coverage Tests
+// ---------------------------------------------------------------------------
+
+// TestDualEWMATTFT_InitNaN: initNaN marks both EWMAs "no data yet".
+func TestDualEWMATTFT_InitNaN(t *testing.T) {
+	var d dualEWMATTFT
+	d.initNaN()
+	require.True(t, math.IsNaN(d.fastValue()))
+	require.True(t, math.IsNaN(d.slowValue()))
+	_, hasTTFT := d.value()
+	require.False(t, hasTTFT, "NaN-initialized should return hasTTFT=false")
+}
+
+// TestDualEWMATTFT_UpdateFromNaN: the first sample seeds both EWMAs directly.
+func TestDualEWMATTFT_UpdateFromNaN(t *testing.T) {
+	var d dualEWMATTFT
+	d.initNaN()
+	d.update(100.0)
+	v, ok := d.value()
+	require.True(t, ok)
+	require.InDelta(t, 100.0, v, 0.01, "first update should set sample directly")
+}
+
+// TestDualEWMATTFT_UpdateMultiple: repeated identical samples converge to them.
+func TestDualEWMATTFT_UpdateMultiple(t *testing.T) {
+	var d dualEWMATTFT
+	d.initNaN()
+	for i := 0; i < 20; i++ {
+		d.update(200.0)
+	}
+	v, ok := d.value()
+	require.True(t, ok)
+	require.InDelta(t, 200.0, v, 1.0, "after many updates of same value, should converge")
+}
+
+// TestDualEWMATTFT_ValueFastOnly: a NaN slow EWMA is ignored by value().
+func TestDualEWMATTFT_ValueFastOnly(t *testing.T) {
+	var d dualEWMATTFT
+	d.initNaN()
+	// Set fast only, slow stays NaN
+	d.fastBits.Store(math.Float64bits(42.0))
+	v, ok := d.value()
+	require.True(t, ok)
+	require.Equal(t, 42.0, v)
+}
+
+// TestDualEWMATTFT_ValueSlowOnly: a NaN fast EWMA is ignored by value().
+func TestDualEWMATTFT_ValueSlowOnly(t *testing.T) {
+	var d dualEWMATTFT
+	d.initNaN()
+	// Set slow only, fast stays NaN
+	d.slowBits.Store(math.Float64bits(55.0))
+	v, ok := d.value()
+	require.True(t, ok)
+	require.Equal(t, 55.0, v)
+}
+
+// TestDualEWMATTFT_ValueSlowGreaterThanFast: value() is pessimistic (max).
+func TestDualEWMATTFT_ValueSlowGreaterThanFast(t *testing.T) {
+	var d dualEWMATTFT
+	d.fastBits.Store(math.Float64bits(30.0))
+	d.slowBits.Store(math.Float64bits(50.0))
+	v, ok := d.value()
+	require.True(t, ok)
+	require.Equal(t, 50.0, v, "pessimistic value should return max(fast, slow)")
+}
+
+// TestDualEWMATTFT_ValueFastGreaterThanSlow: max() works in either direction.
+func TestDualEWMATTFT_ValueFastGreaterThanSlow(t *testing.T) {
+	var d dualEWMATTFT
+	d.fastBits.Store(math.Float64bits(80.0))
+	d.slowBits.Store(math.Float64bits(50.0))
+	v, ok := d.value()
+	require.True(t, ok)
+	require.Equal(t, 80.0, v, "pessimistic value should return max(fast, slow)")
+}
+
+// ---------------------------------------------------------------------------
+// Softmax Additional Coverage Tests
+// ---------------------------------------------------------------------------
+
+// TestSoftmax_EmptyCandidates: nil input yields nil output, no panic.
+func TestSoftmax_EmptyCandidates(t *testing.T) {
+	rng := newOpenAISelectionRNG(42)
+	result := selectSoftmaxOpenAICandidates(nil, 0.3, &rng)
+	require.Nil(t, result)
+}
+
+// TestSoftmax_ZeroTemperatureFallsToDefault: temperature 0 must not divide by
+// zero; all candidates are still returned.
+func TestSoftmax_ZeroTemperatureFallsToDefault(t *testing.T) {
+	candidates := []openAIAccountCandidateScore{
+		{account: &Account{ID: 1}, score: 0.9},
+		{account: &Account{ID: 2}, score: 0.1},
+		{account: &Account{ID: 3}, score: 0.5},
+		{account: &Account{ID: 4}, score: 0.3},
+	}
+	rng := newOpenAISelectionRNG(42)
+	result := selectSoftmaxOpenAICandidates(candidates, 0, &rng)
+	require.Len(t, result, 4)
+}
+
+// TestSoftmax_NaNScoresUniform: when exp() underflows to 0 for every score,
+// the selector must fall back to a uniform distribution instead of NaN weights.
+func TestSoftmax_NaNScoresUniform(t *testing.T) {
+	// Extreme negative scores that cause exp() to return 0 → uniform fallback
+	candidates := []openAIAccountCandidateScore{
+		{account: &Account{ID: 1}, score: -1e308},
+		{account: &Account{ID: 2}, score: -1e308},
+		{account: &Account{ID: 3}, score: -1e308},
+		{account: &Account{ID: 4}, score: -1e308},
+	}
+	rng := newOpenAISelectionRNG(42)
+	result := selectSoftmaxOpenAICandidates(candidates, 0.001, &rng)
+	require.Len(t, result, 4)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot / Stats Coverage Tests
+// ---------------------------------------------------------------------------
+
+// TestSnapshot_NilStats: a nil receiver returns zero values, no panic.
+func TestSnapshot_NilStats(t *testing.T) {
+	var s *openAIAccountRuntimeStats
+	errorRate, ttft, hasTTFT := s.snapshot(1)
+	require.Equal(t, 0.0, errorRate)
+	require.Equal(t, 0.0, ttft)
+	require.False(t, hasTTFT)
+}
+
+// TestSnapshot_InvalidAccountID: account ID 0 is treated as no data.
+func TestSnapshot_InvalidAccountID(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	errorRate, ttft, hasTTFT := s.snapshot(0)
+	require.Equal(t, 0.0, errorRate)
+	require.Equal(t, 0.0, ttft)
+	require.False(t, hasTTFT)
+}
+
+// TestSnapshot_UnknownAccount: an account never reported returns zero values.
+func TestSnapshot_UnknownAccount(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	errorRate, ttft, hasTTFT := s.snapshot(999)
+	require.Equal(t, 0.0, errorRate)
+	require.Equal(t, 0.0, ttft)
+	require.False(t, hasTTFT)
+}
+
+// TestSnapshot_WithModelFallback: a model with no per-model data falls back to
+// the account's global TTFT.
+func TestSnapshot_WithModelFallback(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	stat := s.loadOrCreate(1)
+	// Set global TTFT
+	stat.ttft.update(100.0)
+	// Snapshot with unknown model should fallback to global
+	_, ttft, hasTTFT := s.snapshot(1, "unknown-model")
+	require.True(t, hasTTFT)
+	require.InDelta(t, 100.0, ttft, 0.01)
+}
+
+// TestSnapshot_WithModelSpecific: per-model TTFT wins over the global value.
+func TestSnapshot_WithModelSpecific(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	stat := s.loadOrCreate(1)
+	stat.reportModelTTFT("gpt-4", 200.0)
+	stat.ttft.update(50.0) // global is different
+	_, ttft, hasTTFT := s.snapshot(1, "gpt-4")
+	require.True(t, hasTTFT)
+	require.InDelta(t, 200.0, ttft, 0.01, "should use per-model TTFT")
+}
+
+// TestStatsSize_Nil: size() on a nil receiver is 0, no panic.
+func TestStatsSize_Nil(t *testing.T) {
+	var s *openAIAccountRuntimeStats
+	require.Equal(t, 0, s.size())
+}
+
+// TestStatsSize_Empty: a fresh stats map has size 0.
+func TestStatsSize_Empty(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	require.Equal(t, 0, s.size())
+}
+
+// TestStatsSize_WithAccounts: size() counts distinct accounts created.
+func TestStatsSize_WithAccounts(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	s.loadOrCreate(1)
+	s.loadOrCreate(2)
+	require.Equal(t, 2, s.size())
+}
+
+// TestLoadOrCreate_ConcurrentSameID: concurrent creation of the same account
+// must converge on a single shared stat instance.
+func TestLoadOrCreate_ConcurrentSameID(t *testing.T) {
+	s := newOpenAIAccountRuntimeStats()
+	var wg sync.WaitGroup
+	results := make([]*openAIAccountRuntimeStat, 10)
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			results[idx] = s.loadOrCreate(1)
+		}(i)
+	}
+	wg.Wait()
+	// All should return the same pointer
+	for i := 1; i < 10; i++ {
+		require.Same(t, results[0], results[i], "concurrent loadOrCreate should return same stat")
+	}
+}
+
+// ---------------------------------------------------------------------------
+// modelTTFTValue / reportModelTTFT Coverage Tests
+// ---------------------------------------------------------------------------
+
+// TestModelTTFTValue_EmptyModel: empty model name has no per-model entry.
+func TestModelTTFTValue_EmptyModel(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	v, ok := stat.modelTTFTValue("")
+	require.False(t, ok)
+	require.Equal(t, 0.0, v)
+}
+
+// TestModelTTFTValue_UnknownModel: an unreported model yields (0, false).
+func TestModelTTFTValue_UnknownModel(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	v, ok := stat.modelTTFTValue("nonexistent")
+	require.False(t, ok)
+	require.Equal(t, 0.0, v)
+}
+
+// TestReportModelTTFT_EmptyModel: an empty model name is rejected entirely —
+// not even the global tracker is touched.
+func TestReportModelTTFT_EmptyModel(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	stat.ttft.initNaN()
+	stat.reportModelTTFT("", 100.0)
+	// Should be no-op: global TTFT not updated
+	_, hasTTFT := stat.ttft.value()
+	require.False(t, hasTTFT)
+}
+
+// TestReportModelTTFT_ZeroSample: non-positive samples are discarded.
+func TestReportModelTTFT_ZeroSample(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	stat.ttft.initNaN()
+	stat.reportModelTTFT("gpt-4", 0)
+	_, hasTTFT := stat.ttft.value()
+	require.False(t, hasTTFT)
+}
+
+// TestReportModelTTFT_NegativeSample: negative samples are discarded too.
+func TestReportModelTTFT_NegativeSample(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	stat.ttft.initNaN()
+	stat.reportModelTTFT("gpt-4", -10.0)
+	_, hasTTFT := stat.ttft.value()
+	require.False(t, hasTTFT)
+}
+
+// TestGetOrCreateModelTTFT_ConcurrentSameModel: concurrent creation for one
+// model must return the same tracker instance for every caller.
+func TestGetOrCreateModelTTFT_ConcurrentSameModel(t *testing.T) {
+	stat := &openAIAccountRuntimeStat{}
+	var wg sync.WaitGroup
+	results := make([]*dualEWMATTFT, 10)
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func(idx int) {
+			defer wg.Done()
+			results[idx] = stat.getOrCreateModelTTFT("gpt-4")
+		}(i)
+	}
+	wg.Wait()
+	for i := 1; i < 10; i++ {
+		require.Same(t, results[0], results[i])
+	}
+}
+
+// ---------------------------------------------------------------------------
+// schedulerCircuitBreakerConfig Coverage Tests
+// ---------------------------------------------------------------------------
+
+// TestSchedulerCircuitBreakerConfig_Defaults: without config the breaker is
+// disabled and all parameters fall back to their defaults.
+func TestSchedulerCircuitBreakerConfig_Defaults(t *testing.T) {
+	svc := &OpenAIGatewayService{}
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+	enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig()
+	require.False(t, enabled)
+	require.Equal(t, defaultCircuitBreakerFailThreshold, threshold)
+	require.Equal(t, time.Duration(defaultCircuitBreakerCooldownSec)*time.Second, cooldown)
+	require.Equal(t, defaultCircuitBreakerHalfOpenMax, halfOpenMax)
+}
+
+// TestSchedulerCircuitBreakerConfig_Custom: explicit values pass through.
+func TestSchedulerCircuitBreakerConfig_Custom(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 10
+	cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60
+	cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 5
+	svc := &OpenAIGatewayService{cfg: cfg}
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+	enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig()
+	require.True(t, enabled)
+	require.Equal(t, 10, threshold)
+	require.Equal(t, 60*time.Second, cooldown)
+	require.Equal(t, 5, halfOpenMax)
+}
+
+// TestSchedulerPerModelTTFTConfig_Defaults: per-model TTFT defaults to off.
+func TestSchedulerPerModelTTFTConfig_Defaults(t *testing.T) {
+	svc := &OpenAIGatewayService{}
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+	enabled, maxModels := scheduler.schedulerPerModelTTFTConfig()
+	require.False(t, enabled)
+	require.Equal(t, defaultPerModelTTFTMaxModels, maxModels)
+}
+
+// TestSchedulerPerModelTTFTConfig_Custom: explicit values pass through.
+func TestSchedulerPerModelTTFTConfig_Custom(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true
+	cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 64
+	svc := &OpenAIGatewayService{cfg: cfg}
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+	enabled, maxModels := scheduler.schedulerPerModelTTFTConfig()
+	require.True(t, enabled)
+	require.Equal(t, 64, maxModels)
+}
+
+// TestReportResult_PerModelTTFTDisabled_NoPerModelTrackerCreated verifies that
+// with per-model TTFT disabled, ReportResult updates only the global TTFT and
+// creates no per-model trackers.
+func TestReportResult_PerModelTTFTDisabled_NoPerModelTrackerCreated(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = false
+	svc := &OpenAIGatewayService{cfg: cfg}
+	stats := newOpenAIAccountRuntimeStats()
+	schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+	scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+	require.True(t, ok)
+	ttft := 120
+
+	scheduler.ReportResult(7001, true, &ttft, "gpt-5.1", 120)
+
+	stat := scheduler.stats.loadExisting(7001)
+	require.NotNil(t, stat)
+	count := 0
+	stat.modelTTFT.Range(func(_, _ any) bool {
+		count++
+		return true
+	})
+	require.Equal(t, 0, count, "per-model ttft should remain disabled")
+	_, globalTTFT, hasTTFT := scheduler.stats.snapshot(7001)
+	require.True(t, hasTTFT)
+	require.InDelta(t, 120.0, globalTTFT, 0.01)
+}
+
+func TestReportResult_PerModelTTFTMaxModels_UsesConfigLimit(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 2
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ ttft := 100
+
+ models := []string{"gpt-5.1", "gpt-4o", "o3", "o4-mini"}
+ for i := 0; i < 200; i++ {
+ model := models[i%len(models)]
+ scheduler.ReportResult(7002, true, &ttft, model, float64(ttft+i))
+ }
+
+ stat := scheduler.stats.loadExisting(7002)
+ require.NotNil(t, stat)
+ count := 0
+ stat.modelTTFT.Range(func(_, _ any) bool {
+ count++
+ return true
+ })
+ require.LessOrEqual(t, count, 2, "model tracker count should honor scheduler_per_model_ttft_max_models")
+}
+
+// ---------------------------------------------------------------------------
+// P2C Edge Case Coverage
+// ---------------------------------------------------------------------------
+
+func TestSelectP2C_EmptyCandidates(t *testing.T) {
+ result := selectP2COpenAICandidates(nil, OpenAIAccountScheduleRequest{})
+ require.Nil(t, result)
+}
+
+func TestSelectP2C_SingleCandidate(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.5},
+ }
+ result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{})
+ require.Len(t, result, 1)
+ require.Equal(t, int64(1), result[0].account.ID)
+}
+
+func TestSelectP2C_TwoCandidatesPicksBetter(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.2},
+ {account: &Account{ID: 2}, score: 0.8},
+ }
+ result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{})
+ require.Len(t, result, 2)
+ require.Equal(t, int64(2), result[0].account.ID, "first should be higher-scored")
+}
+
+// ---------------------------------------------------------------------------
+// TopK Selection & Heap Coverage
+// ---------------------------------------------------------------------------
+
+func TestSelectTopK_Empty(t *testing.T) {
+ result := selectTopKOpenAICandidates(nil, 3)
+ require.Nil(t, result)
+}
+
+func TestSelectTopK_TopKZero(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}},
+ }
+ result := selectTopKOpenAICandidates(candidates, 0)
+ require.Len(t, result, 1, "topK=0 should default to 1")
+}
+
+func TestSelectTopK_TopKExceedsCandidates(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}},
+ }
+ result := selectTopKOpenAICandidates(candidates, 10)
+ require.Len(t, result, 2)
+ require.Equal(t, int64(2), result[0].account.ID, "highest score first")
+}
+
+func TestSelectTopK_ProperFiltering(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1, Priority: 1}, score: 0.1, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 3, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 4, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 5, Priority: 1}, score: 0.7, loadInfo: &AccountLoadInfo{}},
+ }
+ result := selectTopKOpenAICandidates(candidates, 3)
+ require.Len(t, result, 3)
+ require.Equal(t, int64(2), result[0].account.ID)
+ require.Equal(t, int64(5), result[1].account.ID)
+ require.Equal(t, int64(3), result[2].account.ID)
+}
+
+func TestIsOpenAIAccountCandidateBetter_AllTiebreakers(t *testing.T) {
+ // Equal scores, different priority
+ a := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ b := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 2}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ require.True(t, isOpenAIAccountCandidateBetter(a, b), "lower priority number = better")
+ require.False(t, isOpenAIAccountCandidateBetter(b, a))
+
+ // Equal scores and priority, different load rate
+ c := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 5}}
+ d := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 60, WaitingCount: 5}}
+ require.True(t, isOpenAIAccountCandidateBetter(c, d), "lower load rate = better")
+
+ // Equal everything except waiting count
+ e := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}}
+ f := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 8}}
+ require.True(t, isOpenAIAccountCandidateBetter(e, f), "lower waiting count = better")
+
+ // Equal everything except ID
+ g := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ h := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ require.True(t, isOpenAIAccountCandidateBetter(g, h), "lower ID = better")
+}
+
+// ---------------------------------------------------------------------------
+// shouldReleaseStickySession Coverage
+// ---------------------------------------------------------------------------
+
+func TestShouldReleaseStickySession_NilScheduler(t *testing.T) {
+ var s *defaultOpenAIAccountScheduler
+ require.False(t, s.shouldReleaseStickySession(1))
+}
+
+func TestShouldReleaseStickySession_Disabled(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = false
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ require.False(t, scheduler.shouldReleaseStickySession(1))
+}
+
+func TestShouldReleaseStickySession_CircuitOpen(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ // Trip the circuit breaker
+ cb := stats.getCircuitBreaker(1)
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ require.True(t, scheduler.shouldReleaseStickySession(1), "should release when circuit is open")
+}
+
+func TestShouldReleaseStickySession_HighErrorRate(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ // Report many failures to push error rate above threshold
+ for i := 0; i < 20; i++ {
+ stats.report(1, false, nil, "", 0)
+ }
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ require.True(t, scheduler.shouldReleaseStickySession(1), "should release when error rate is high")
+}
+
+func TestShouldReleaseStickySession_Healthy(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ // Report successes
+ for i := 0; i < 20; i++ {
+ stats.report(1, true, nil, "", 0)
+ }
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ require.False(t, scheduler.shouldReleaseStickySession(1), "should not release when healthy")
+}
+
+// ---------------------------------------------------------------------------
+// stickyReleaseConfigRead Coverage
+// ---------------------------------------------------------------------------
+
+func TestStickyReleaseConfigRead_NilConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ cfg := scheduler.stickyReleaseConfigRead()
+ require.False(t, cfg.enabled)
+ require.Equal(t, 0.0, cfg.errorThreshold, "nil config returns zero-value struct")
+}
+
+func TestStickyReleaseConfigRead_Defaults(t *testing.T) {
+ c := &config.Config{}
+ // StickyReleaseErrorThreshold defaults to 0 → code uses defaultStickyReleaseErrorThreshold
+ svc := &OpenAIGatewayService{cfg: c}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ cfg := scheduler.stickyReleaseConfigRead()
+ require.False(t, cfg.enabled)
+ require.Equal(t, defaultStickyReleaseErrorThreshold, cfg.errorThreshold)
+}
+
+func TestStickyReleaseConfigRead_Custom(t *testing.T) {
+ c := &config.Config{}
+ c.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ c.Gateway.OpenAIWS.StickyReleaseErrorThreshold = 0.5
+ svc := &OpenAIGatewayService{cfg: c}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ cfg := scheduler.stickyReleaseConfigRead()
+ require.True(t, cfg.enabled)
+ require.Equal(t, 0.5, cfg.errorThreshold)
+}
+
+// ---------------------------------------------------------------------------
+// RNG Coverage
+// ---------------------------------------------------------------------------
+
+func TestNewOpenAISelectionRNG_ZeroSeed(t *testing.T) {
+ rng := newOpenAISelectionRNG(0)
+ require.NotEqual(t, uint64(0), rng.state, "zero seed should be replaced with default")
+ v := rng.nextFloat64()
+ require.True(t, v >= 0 && v < 1.0)
+}
+
+// ---------------------------------------------------------------------------
+// isCircuitOpen Coverage
+// ---------------------------------------------------------------------------
+
+func TestIsCircuitOpen_UnknownAccount(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ require.False(t, stats.isCircuitOpen(999), "unknown account should not be circuit-open")
+}
+
+func TestIsCircuitOpen_OpenAccount(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ cb := stats.getCircuitBreaker(1)
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.True(t, stats.isCircuitOpen(1))
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSSchedulerP2CEnabled / openAIWSSchedulerWeights Coverage
+// ---------------------------------------------------------------------------
+
+func TestP2CEnabled_NilConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ require.False(t, svc.openAIWSSchedulerP2CEnabled())
+}
+
+func TestP2CEnabled_Enabled(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ require.True(t, svc.openAIWSSchedulerP2CEnabled())
+}
+
+func TestSchedulerWeights_NilConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ w := svc.openAIWSSchedulerWeights()
+ require.Equal(t, 1.0, w.Priority)
+ require.Equal(t, 1.0, w.Load)
+ require.Equal(t, 0.7, w.Queue)
+ require.Equal(t, 0.8, w.ErrorRate)
+ require.Equal(t, 0.5, w.TTFT)
+}
+
+func TestSchedulerWeights_Custom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 2.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 3.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1.5
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.8
+ svc := &OpenAIGatewayService{cfg: cfg}
+ w := svc.openAIWSSchedulerWeights()
+ require.Equal(t, 2.0, w.Priority)
+ require.Equal(t, 3.0, w.Load)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot edge cases
+// ---------------------------------------------------------------------------
+
+func TestSnapshot_WithEmptyModel(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ stat := s.loadOrCreate(1)
+ stat.ttft.update(100.0)
+ _, ttft, hasTTFT := s.snapshot(1, "")
+ require.True(t, hasTTFT)
+ require.InDelta(t, 100.0, ttft, 0.01, "empty model string should fall through to global")
+}
+
+func TestSnapshot_NoModelArg(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ stat := s.loadOrCreate(1)
+ stat.ttft.update(100.0)
+ _, ttft, hasTTFT := s.snapshot(1)
+ require.True(t, hasTTFT)
+ require.InDelta(t, 100.0, ttft, 0.01)
+}
+
+// ---------------------------------------------------------------------------
+// deriveOpenAISelectionSeed coverage
+// ---------------------------------------------------------------------------
+
+func TestDeriveOpenAISelectionSeed_WithSessionHash(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{SessionHash: "abc123"})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+func TestDeriveOpenAISelectionSeed_WithPreviousResponseID(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{PreviousResponseID: "resp_123"})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+func TestDeriveOpenAISelectionSeed_WithGroupID(t *testing.T) {
+ gid := int64(42)
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{GroupID: &gid})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+func TestDeriveOpenAISelectionSeed_Empty(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{})
+ require.NotEqual(t, uint64(0), seed, "empty request should use time entropy")
+}
+
+func TestDeriveOpenAISelectionSeed_WithModel(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{RequestedModel: "gpt-4"})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+// ---------------------------------------------------------------------------
+// SnapshotMetrics coverage
+// ---------------------------------------------------------------------------
+
+func TestSnapshotMetrics_NilScheduler(t *testing.T) {
+ var s *defaultOpenAIAccountScheduler
+ metrics := s.SnapshotMetrics()
+ require.Equal(t, int64(0), metrics.SelectTotal)
+}
+
+func TestSnapshotMetrics_Normal(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := mustDefaultOpenAIAccountScheduler(t, svc, stats)
+ metrics := scheduler.SnapshotMetrics()
+ require.Equal(t, int64(0), metrics.SelectTotal)
+}
+
+// ---------------------------------------------------------------------------
+// Heap Pop coverage
+// ---------------------------------------------------------------------------
+
+func TestCandidateHeap_Pop(t *testing.T) {
+ h := &openAIAccountCandidateHeap{}
+ heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}})
+ heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 2}, score: 0.9, loadInfo: &AccountLoadInfo{}})
+ heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 3}, score: 0.3, loadInfo: &AccountLoadInfo{}})
+ require.Equal(t, 3, h.Len())
+
+ // Pop returns the minimum (worst candidate in min-heap)
+ popped, ok := heap.Pop(h).(openAIAccountCandidateScore)
+ require.True(t, ok)
+ require.Equal(t, int64(3), popped.account.ID, "should pop the lowest-scored")
+ require.Equal(t, 2, h.Len())
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSSessionStickyTTL coverage
+// ---------------------------------------------------------------------------
+
+func TestOpenAIWSSessionStickyTTL_DefaultConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ ttl := svc.openAIWSSessionStickyTTL()
+ require.Equal(t, openaiStickySessionTTL, ttl, "nil config should return default TTL")
+}
+
+func TestOpenAIWSSessionStickyTTL_Custom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ svc := &OpenAIGatewayService{cfg: cfg}
+ ttl := svc.openAIWSSessionStickyTTL()
+ require.Equal(t, 1800*time.Second, ttl)
+}
diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go
index c9cf32462..5ed5ff69a 100644
--- a/backend/internal/service/openai_client_transport.go
+++ b/backend/internal/service/openai_client_transport.go
@@ -64,8 +64,25 @@ func resolveOpenAIWSDecisionByClientTransport(
decision OpenAIWSProtocolDecision,
clientTransport OpenAIClientTransport,
) OpenAIWSProtocolDecision {
- if clientTransport == OpenAIClientTransportHTTP {
+ // WSv2 upstream is only allowed for explicit WebSocket ingress.
+ // Unknown/missing transport is treated as HTTP to avoid accidental WS pool usage.
+ if clientTransport != OpenAIClientTransportWS {
return openAIWSHTTPDecision("client_protocol_http")
}
return decision
}
+
+func shouldWarnOpenAIWSUnknownTransportFallback(
+ decision OpenAIWSProtocolDecision,
+ clientTransport OpenAIClientTransport,
+) bool {
+ if clientTransport != OpenAIClientTransportUnknown {
+ return false
+ }
+ switch decision.Transport {
+ case OpenAIUpstreamTransportResponsesWebsocketV2, OpenAIUpstreamTransportResponsesWebsocket:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go
index ef90e6145..479534c79 100644
--- a/backend/internal/service/openai_client_transport_test.go
+++ b/backend/internal/service/openai_client_transport_test.go
@@ -103,5 +103,40 @@ func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
require.Equal(t, base, wsDecision)
unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
- require.Equal(t, base, unknownDecision)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, unknownDecision.Transport)
+ require.Equal(t, "client_protocol_http", unknownDecision.Reason)
+}
+
+func TestShouldWarnOpenAIWSUnknownTransportFallback(t *testing.T) {
+ require.True(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_enabled",
+ },
+ OpenAIClientTransportUnknown,
+ ))
+
+ require.True(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocket,
+ Reason: "ws_v1_enabled",
+ },
+ OpenAIClientTransportUnknown,
+ ))
+
+ require.False(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportHTTPSSE,
+ Reason: "http_only",
+ },
+ OpenAIClientTransportUnknown,
+ ))
+
+ require.False(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_enabled",
+ },
+ OpenAIClientTransportHTTP,
+ ))
}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
new file mode 100644
index 000000000..9cb453ad0
--- /dev/null
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -0,0 +1,827 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type openAIRecordUsageLogRepoStub struct {
+ UsageLogRepository
+
+ inserted bool
+ err error
+ calls int
+ lastLog *UsageLog
+ nextID int64
+
+ billingEntry *UsageBillingEntry
+ billingEntryErr error
+ upsertCalls int
+ getCalls int
+ markAppliedCalls int
+ markRetryCalls int
+ lastRetryAt time.Time
+ lastRetryErrMessage string
+ txCalls int
+}
+
+func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
+ s.calls++
+ if log != nil {
+ if log.ID == 0 {
+ if s.nextID == 0 {
+ s.nextID = 1000
+ }
+ log.ID = s.nextID
+ s.nextID++
+ }
+ }
+ s.lastLog = log
+ return s.inserted, s.err
+}
+
+func (s *openAIRecordUsageLogRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) {
+ s.getCalls++
+ if s.billingEntryErr != nil {
+ return nil, s.billingEntryErr
+ }
+ if s.billingEntry == nil {
+ return nil, ErrUsageBillingEntryNotFound
+ }
+ if s.billingEntry.UsageLogID != usageLogID {
+ return nil, ErrUsageBillingEntryNotFound
+ }
+ return s.billingEntry, nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) {
+ s.upsertCalls++
+ if s.billingEntryErr != nil {
+ return nil, false, s.billingEntryErr
+ }
+ if s.billingEntry != nil {
+ return s.billingEntry, false, nil
+ }
+ if entry == nil {
+ return nil, false, nil
+ }
+ copyEntry := *entry
+ copyEntry.ID = 9100 + int64(s.upsertCalls)
+ copyEntry.Status = UsageBillingEntryStatusPending
+	s.billingEntry = &copyEntry
+ return s.billingEntry, true, nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error {
+ s.markAppliedCalls++
+ if s.billingEntry != nil && s.billingEntry.ID == entryID {
+ s.billingEntry.Applied = true
+ s.billingEntry.Status = UsageBillingEntryStatusApplied
+ }
+ return nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error {
+ s.markRetryCalls++
+ s.lastRetryAt = nextRetryAt
+ s.lastRetryErrMessage = lastError
+ if s.billingEntry != nil && s.billingEntry.ID == entryID {
+ s.billingEntry.Applied = false
+ s.billingEntry.Status = UsageBillingEntryStatusPending
+ msg := lastError
+ s.billingEntry.LastError = &msg
+ s.billingEntry.NextRetryAt = nextRetryAt
+ }
+ return nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) {
+ return nil, nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ s.txCalls++
+ if fn == nil {
+ return nil
+ }
+ return fn(ctx)
+}
+
+type openAIRecordUsageUserRepoStub struct {
+ UserRepository
+
+ deductCalls int
+ deductErr error
+}
+
+func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
+ s.deductCalls++
+ return s.deductErr
+}
+
+type openAIRecordUsageSubRepoStub struct {
+ UserSubscriptionRepository
+
+ incrementCalls int
+ incrementErr error
+}
+
+func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
+ s.incrementCalls++
+ return s.incrementErr
+}
+
+type openAIRecordUsageBillingCacheStub struct {
+ BillingCache
+
+ deductCalls int
+ deductErr error
+}
+
+func (s *openAIRecordUsageBillingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
+ s.deductCalls++
+ return s.deductErr
+}
+
+func (s *openAIRecordUsageBillingCacheStub) GetUserBalance(context.Context, int64) (float64, error) {
+ return 0, errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) SetUserBalance(context.Context, int64, float64) error {
+ return errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) InvalidateUserBalance(context.Context, int64) error {
+ return nil
+}
+
+func (s *openAIRecordUsageBillingCacheStub) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
+ return errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
+ return errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) InvalidateSubscriptionCache(context.Context, int64, int64) error {
+ return nil
+}
+
+type openAIRecordUsageAPIKeyQuotaStub struct {
+ calls int
+ err error
+}
+
+func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
+ s.calls++
+ return s.err
+}
+
+func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
+ return s.err
+}
+
+func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *OpenAIGatewayService {
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ RateMultiplier: 1,
+ },
+ }
+ return &OpenAIGatewayService{
+ usageLogRepo: usageRepo,
+ userRepo: userRepo,
+ userSubRepo: subRepo,
+ cfg: cfg,
+ billingService: NewBillingService(cfg, nil),
+ billingCacheService: &BillingCacheService{},
+ deferredService: &DeferredService{},
+ }
+}
+
+func TestOpenAIGatewayServiceRecordUsage_NoBillingWhenCreateUsageLogFails(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ err: errors.New("write usage log failed"),
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_test_create_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ OutputTokens: 8,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1001,
+ },
+ User: &User{
+ ID: 2001,
+ },
+ Account: &Account{
+ ID: 3001,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "create usage log")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_BillingWhenUsageLogInserted(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_test_inserted",
+ Usage: OpenAIUsage{
+ InputTokens: 20,
+ OutputTokens: 10,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1002,
+ },
+ User: &User{
+ ID: 2002,
+ },
+ Account: &Account{
+ ID: 3002,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_PricingFailureReturnsError(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ svc.billingService = &BillingService{
+ cfg: &config.Config{},
+ fallbackPrices: map[string]*ModelPricing{},
+ }
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_pricing_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 1,
+ OutputTokens: 1,
+ },
+ Model: "model_pricing_not_found_for_test",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1102,
+ },
+ User: &User{
+ ID: 2102,
+ },
+ Account: &Account{
+ ID: 3102,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "calculate cost")
+ require.Equal(t, 0, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DeductBalanceFailureReturnsError(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{
+ deductErr: errors.New("db deduct failed"),
+ }
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_deduct_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 5,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1003,
+ },
+ User: &User{
+ ID: 2003,
+ },
+ Account: &Account{
+ ID: 3003,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "deduct balance")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DeductBalanceCacheFailureReturnsError(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ cache := &openAIRecordUsageBillingCacheStub{
+ deductErr: ErrInsufficientBalance,
+ }
+ svc.billingCacheService = &BillingCacheService{cache: cache}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_cache_deduct_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 5,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1004,
+ },
+ User: &User{
+ ID: 2004,
+ },
+ Account: &Account{
+ ID: 3004,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "deduct balance cache")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, cache.deductCalls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_SubscriptionIncrementFailureReturnsError(t *testing.T) {
+ groupID := int64(12)
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{
+ incrementErr: errors.New("subscription update failed"),
+ }
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ svc.billingCacheService = &BillingCacheService{cache: &openAIRecordUsageBillingCacheStub{}}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_sub_increment_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 5,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1005,
+ GroupID: &groupID,
+ Group: &Group{
+ ID: groupID,
+ SubscriptionType: SubscriptionTypeSubscription,
+ },
+ },
+ User: &User{
+ ID: 2005,
+ },
+ Account: &Account{
+ ID: 3005,
+ },
+ Subscription: &UserSubscription{
+ ID: 4005,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "increment subscription usage")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 1, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: false,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_duplicate",
+ Usage: OpenAIUsage{
+ InputTokens: 8,
+ OutputTokens: 4,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1006,
+ },
+ User: &User{
+ ID: 2006,
+ },
+ Account: &Account{
+ ID: 3006,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBilling(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ svc.cfg.RunMode = config.RunModeSimple
+ quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_simple_mode",
+ Usage: OpenAIUsage{
+ InputTokens: 5,
+ OutputTokens: 2,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1007,
+ Quota: 100,
+ },
+ User: &User{
+ ID: 2007,
+ },
+ Account: &Account{
+ ID: 3007,
+ },
+ APIKeyService: quotaSvc,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+ require.Equal(t, 0, quotaSvc.calls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSuccess(t *testing.T) {
+ groupID := int64(13)
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_sub_success",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 8,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1008,
+ GroupID: &groupID,
+ Group: &Group{
+ ID: groupID,
+ SubscriptionType: SubscriptionTypeSubscription,
+ },
+ },
+ User: &User{
+ ID: 2008,
+ },
+ Account: &Account{
+ ID: 3008,
+ },
+ Subscription: &UserSubscription{
+ ID: 4008,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 1, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_quota_update",
+ Usage: OpenAIUsage{
+ InputTokens: 1,
+ OutputTokens: 1,
+ CacheReadInputTokens: 3,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1009,
+ Quota: 100,
+ },
+ User: &User{
+ ID: 2009,
+ },
+ Account: &Account{
+ ID: 3009,
+ },
+ APIKeyService: quotaSvc,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, 0, usageRepo.lastLog.InputTokens, "input_tokens 小于 cache_read_tokens 时应被钳制为 0")
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 1, quotaSvc.calls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DuplicateWithPendingBillingEntryStillBills(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: false,
+ billingEntry: &UsageBillingEntry{
+ ID: 9201,
+ UsageLogID: 1000,
+ Applied: false,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 1.25,
+ },
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_duplicate_pending_entry",
+ Usage: OpenAIUsage{
+ InputTokens: 8,
+ OutputTokens: 2,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{ID: 1011},
+ User: &User{ID: 2011},
+ Account: &Account{
+ ID: 3011,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, usageRepo.getCalls)
+ require.Equal(t, 1, usageRepo.markAppliedCalls)
+ require.Equal(t, 1, userRepo.deductCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_BillingFailureMarksRetry(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{
+ deductErr: errors.New("deduct failed"),
+ }
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_mark_retry",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 6,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1012,
+ },
+ User: &User{
+ ID: 2012,
+ },
+ Account: &Account{
+ ID: 3012,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "deduct balance")
+ require.Equal(t, 1, usageRepo.markRetryCalls)
+ require.NotZero(t, usageRepo.lastRetryAt)
+ require.NotEmpty(t, usageRepo.lastRetryErrMessage)
+ require.Equal(t, 0, usageRepo.markAppliedCalls)
+}
+
+func TestResolveOpenAIUsageRequestID_FallbackDeterministic(t *testing.T) {
+ reasoning := "medium"
+ input := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_fallback_seed",
+ APIKey: &APIKey{ID: 11001},
+ Account: &Account{ID: 21001},
+ Result: &OpenAIForwardResult{
+ RequestID: "",
+ Model: "gpt-5.1",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ OutputTokens: 8,
+ CacheCreationInputTokens: 2,
+ CacheReadInputTokens: 1,
+ },
+ Duration: 2300 * time.Millisecond,
+ ReasoningEffort: &reasoning,
+ Stream: true,
+ OpenAIWSMode: true,
+ },
+ }
+
+ got1 := resolveOpenAIUsageRequestID(input)
+ got2 := resolveOpenAIUsageRequestID(input)
+
+ require.NotEmpty(t, got1)
+ require.Equal(t, got1, got2, "fallback request id should be deterministic")
+ require.Contains(t, got1, "wsf_")
+}
+
+func TestResolveOpenAIUsageRequestID_FallbackChangesWhenUsageChanges(t *testing.T) {
+ base := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_fallback_seed",
+ APIKey: &APIKey{ID: 11002},
+ Account: &Account{ID: 21002},
+ Result: &OpenAIForwardResult{
+ Model: "gpt-5.1",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 4,
+ },
+ Duration: 2 * time.Second,
+ },
+ }
+ changed := &OpenAIRecordUsageInput{
+ FallbackRequestID: base.FallbackRequestID,
+ APIKey: base.APIKey,
+ Account: base.Account,
+ Result: &OpenAIForwardResult{
+ Model: "gpt-5.1",
+ Usage: OpenAIUsage{
+ InputTokens: 11,
+ OutputTokens: 4,
+ },
+ Duration: 2 * time.Second,
+ },
+ }
+
+ baseID := resolveOpenAIUsageRequestID(base)
+ changedID := resolveOpenAIUsageRequestID(changed)
+
+ require.NotEqual(t, baseID, changedID, "fallback request id should change when usage fingerprint changes")
+}
+
+func TestResolveOpenAIUsageRequestID_FallbackChangesWhenWSIngressModeChanges(t *testing.T) {
+ base := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_fallback_seed",
+ APIKey: &APIKey{ID: 11003},
+ Account: &Account{ID: 21003},
+ Result: &OpenAIForwardResult{
+ Model: "gpt-5.3-codex",
+ Stream: true,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModeCtxPool,
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 4,
+ },
+ Duration: 2 * time.Second,
+ },
+ }
+ changed := &OpenAIRecordUsageInput{
+ FallbackRequestID: base.FallbackRequestID,
+ APIKey: base.APIKey,
+ Account: base.Account,
+ Result: &OpenAIForwardResult{
+ Model: base.Result.Model,
+ Stream: base.Result.Stream,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModePassthrough,
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 4,
+ },
+ Duration: 2 * time.Second,
+ },
+ }
+
+ baseID := resolveOpenAIUsageRequestID(base)
+ changedID := resolveOpenAIUsageRequestID(changed)
+
+ require.NotEqual(t, baseID, changedID, "fallback request id should change when ws ingress mode changes")
+}
+
+func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDWhenMissing(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ input := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_from_handler",
+ Result: &OpenAIForwardResult{
+ RequestID: "",
+ Usage: OpenAIUsage{
+ InputTokens: 9,
+ OutputTokens: 3,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1013,
+ },
+ User: &User{
+ ID: 2013,
+ },
+ Account: &Account{
+ ID: 3013,
+ },
+ }
+
+ expectedRequestID := resolveOpenAIUsageRequestID(input)
+ require.NotEmpty(t, expectedRequestID)
+
+ err := svc.RecordUsage(context.Background(), input)
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, expectedRequestID, usageRepo.lastLog.RequestID)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 8606708f7..4a8f28a08 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"math/rand"
"net/http"
"sort"
@@ -42,6 +43,8 @@ const (
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
+ // OpenAIRequestMetaKey 缓存 handler 已提取的请求元数据,供 Service 层复用。
+ OpenAIRequestMetaKey = "openai_request_meta"
// OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。
// 与 Codex 客户端保持一致:失败后最多重连 5 次。
openAIWSReconnectRetryLimit = 5
@@ -49,6 +52,13 @@ const (
openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond
openAIWSRetryBackoffMaxDefault = 2 * time.Second
openAIWSRetryJitterRatioDefault = 0.2
+ openAICodexUsageUpdateConcurrency = 16
+)
+
// Package-level constants for the SSE hot path, hoisted here to avoid
// re-allocating the byte slices on every loop iteration.
var (
	sseDataDone              = []byte("[DONE]")
	sseResponseCompletedMark = []byte(`"response.completed"`)
)
// OpenAI allowed headers whitelist (for non-passthrough).
@@ -210,8 +220,16 @@ type OpenAIForwardResult struct {
ReasoningEffort *string
Stream bool
OpenAIWSMode bool
- Duration time.Duration
- FirstTokenMs *int
+ // WSIngressMode records which WS v2 ingress path produced this result.
+ // Expected values: ctx_pool / passthrough.
+ WSIngressMode string
+ Duration time.Duration
+ FirstTokenMs *int
+ // TerminalEventType records the terminal event that ended the WS turn.
+ TerminalEventType string
+ // PendingFunctionCallIDs 表示该 response 中未完成的 function_call call_id 集合。
+ // 仅在 WS ingress 连续对话场景用于续链自愈,不参与外部 API 返回。
+ PendingFunctionCallIDs []string
}
type OpenAIWSRetryMetricsSnapshot struct {
@@ -221,6 +239,17 @@ type OpenAIWSRetryMetricsSnapshot struct {
NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"`
}
// OpenAIWSTurnAbortMetricPoint is a single (reason, expected) abort counter
// as exported in a metrics snapshot.
type OpenAIWSTurnAbortMetricPoint struct {
	Reason   string `json:"reason"`
	Expected bool   `json:"expected"`
	Total    int64  `json:"total"`
}
+
// OpenAIWSAbortMetricsSnapshot aggregates WS turn-abort counters and the
// total number of recovered aborts for export.
type OpenAIWSAbortMetricsSnapshot struct {
	TurnAbortTotal          []OpenAIWSTurnAbortMetricPoint `json:"turn_abort_total"`
	TurnAbortRecoveredTotal int64                          `json:"turn_abort_recovered_total"`
}
+
type OpenAICompatibilityFallbackMetricsSnapshot struct {
SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"`
SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"`
@@ -243,6 +272,16 @@ type openAIWSRetryMetrics struct {
nonRetryableFastFallback atomic.Int64
}
// openAIWSTurnAbortMetricKey keys abort counters by normalized reason and
// whether the abort was expected.
type openAIWSTurnAbortMetricKey struct {
	reason   string
	expected bool
}
+
// openAIWSAbortMetrics holds concurrency-safe counters for WS turn aborts.
type openAIWSAbortMetrics struct {
	turnAbortTotal     sync.Map // key: openAIWSTurnAbortMetricKey, value: *atomic.Int64
	turnAbortRecovered atomic.Int64
}
+
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
@@ -263,17 +302,28 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
- openaiWSPoolOnce sync.Once
- openaiWSStateStoreOnce sync.Once
- openaiSchedulerOnce sync.Once
- openaiWSPool *openAIWSConnPool
- openaiWSStateStore OpenAIWSStateStore
- openaiScheduler OpenAIAccountScheduler
- openaiAccountStats *openAIAccountRuntimeStats
+ openaiWSIngressCtxOnce sync.Once
+ openaiWSStateStoreOnce sync.Once
+ openaiSchedulerOnce sync.Once
+ openaiWSPassthroughDialerOnce sync.Once
+ openaiWSIngressCtxPool *openAIWSIngressContextPool
+ openaiWSStateStore OpenAIWSStateStore
+ openaiScheduler OpenAIAccountScheduler
+ openaiWSPassthroughDialer openAIWSClientDialer
+ openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
+ openaiWSAbortMetrics openAIWSAbortMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
+
+ codexUsageUpdateOnce sync.Once
+ codexUsageUpdateSem chan struct{}
+
+ usageBillingCompensation *UsageBillingCompensationService
+
+ // test hook for deterministic tie-break in account selection.
+ accountTieBreakIntnFn func(n int) int
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -313,15 +363,25 @@ func NewOpenAIGatewayService(
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
+ svc.usageBillingCompensation = NewUsageBillingCompensationService(usageLogRepo, userRepo, userSubRepo, billingCacheService, cfg)
+ svc.usageBillingCompensation.Start()
svc.logOpenAIWSModeBootstrap()
return svc
}
-// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
+// CloseOpenAIWSCtxPool 关闭 OpenAI WebSocket ctx_pool 的后台 worker 与连接资源。
// 应在应用优雅关闭时调用。
-func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
- if s != nil && s.openaiWSPool != nil {
- s.openaiWSPool.Close()
+func (s *OpenAIGatewayService) CloseOpenAIWSCtxPool() {
+ if s != nil && s.openaiWSIngressCtxPool != nil {
+ s.openaiWSIngressCtxPool.Close()
+ }
+ if s != nil && s.openaiWSStateStore != nil {
+ if closer, ok := s.openaiWSStateStore.(interface{ Close() }); ok {
+ closer.Close()
+ }
+ }
+ if s != nil && s.usageBillingCompensation != nil {
+ s.usageBillingCompensation.Stop()
}
}
@@ -537,6 +597,34 @@ func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context
return true
}
// writeOpenAIWSV1UnsupportedResponse rejects a request that resolved to the
// (temporarily unsupported) WS v1 transport: it records the rejection for ops
// visibility, replies 400 to the client when a gin context is present, and
// returns the upstream-facing error for the caller to propagate.
func (s *OpenAIGatewayService) writeOpenAIWSV1UnsupportedResponse(c *gin.Context, account *Account) error {
	const (
		upstreamMessage = "openai ws v1 is temporarily unsupported; use ws v2"
		clientMessage   = "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2."
	)
	// Record ops-side context first, before any client write.
	setOpsUpstreamError(c, http.StatusBadRequest, upstreamMessage, "")
	if account != nil {
		appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
			Platform:           account.Platform,
			AccountID:          account.ID,
			AccountName:        account.Name,
			UpstreamStatusCode: http.StatusBadRequest,
			Kind:               "ws_error",
			Message:            upstreamMessage,
		})
	}
	// c may be nil in non-HTTP invocation paths; only then is the JSON reply skipped.
	if c != nil {
		c.JSON(http.StatusBadRequest, gin.H{
			"error": gin.H{
				"type":    "invalid_request_error",
				"message": clientMessage,
			},
		})
		c.Abort()
	}
	return errors.New(upstreamMessage)
}
+
func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration {
if attempt <= 0 {
return 0
@@ -610,6 +698,16 @@ func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration {
return 0
}
+func openAIWSRetryContextError(ctx context.Context) error {
+ if ctx == nil {
+ return nil
+ }
+ if err := ctx.Err(); err != nil {
+ return wrapOpenAIWSFallback("retry_context_canceled", err)
+ }
+ return nil
+}
+
func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) {
if s == nil {
return
@@ -646,6 +744,70 @@ func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetri
}
}
+func (s *OpenAIGatewayService) recordOpenAIWSTurnAbort(reason openAIWSIngressTurnAbortReason, expected bool) {
+ if s == nil {
+ return
+ }
+ normalizedReason := strings.TrimSpace(string(reason))
+ if normalizedReason == "" {
+ normalizedReason = string(openAIWSIngressTurnAbortReasonUnknown)
+ }
+ key := openAIWSTurnAbortMetricKey{
+ reason: normalizedReason,
+ expected: expected,
+ }
+ counterAny, _ := s.openaiWSAbortMetrics.turnAbortTotal.LoadOrStore(key, &atomic.Int64{})
+ counter, ok := counterAny.(*atomic.Int64)
+ if !ok || counter == nil {
+ return
+ }
+ counter.Add(1)
+}
+
+func (s *OpenAIGatewayService) recordOpenAIWSTurnAbortRecovered() {
+ if s == nil {
+ return
+ }
+ s.openaiWSAbortMetrics.turnAbortRecovered.Add(1)
+}
+
+func (s *OpenAIGatewayService) SnapshotOpenAIWSAbortMetrics() OpenAIWSAbortMetricsSnapshot {
+ if s == nil {
+ return OpenAIWSAbortMetricsSnapshot{}
+ }
+ points := make([]OpenAIWSTurnAbortMetricPoint, 0, 8)
+ s.openaiWSAbortMetrics.turnAbortTotal.Range(func(key, value any) bool {
+ label, ok := key.(openAIWSTurnAbortMetricKey)
+ if !ok {
+ return true
+ }
+ counter, ok := value.(*atomic.Int64)
+ if !ok || counter == nil {
+ return true
+ }
+ total := counter.Load()
+ if total <= 0 {
+ return true
+ }
+ points = append(points, OpenAIWSTurnAbortMetricPoint{
+ Reason: label.reason,
+ Expected: label.expected,
+ Total: total,
+ })
+ return true
+ })
+ sort.Slice(points, func(i, j int) bool {
+ if points[i].Reason == points[j].Reason {
+ return !points[i].Expected && points[j].Expected
+ }
+ return points[i].Reason < points[j].Reason
+ })
+ return OpenAIWSAbortMetricsSnapshot{
+ TurnAbortTotal: points,
+ TurnAbortRecoveredTotal: s.openaiWSAbortMetrics.turnAbortRecovered.Load(),
+ }
+}
+
func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot {
legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats()
isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats()
@@ -1041,6 +1203,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account
+ tieCount := 0
for i := range accounts {
acc := &accounts[i]
@@ -1067,17 +1230,41 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo
// Select highest priority and least recently used
if selected == nil {
selected = acc
+ tieCount = 1
continue
}
if s.isBetterAccount(acc, selected) {
selected = acc
+ tieCount = 1
+ continue
+ }
+ if !s.isBetterAccount(selected, acc) {
+ // 完全同分档时进行 reservoir tie-break,避免长期集中命中首个账号。
+ tieCount++
+ if s.accountTieBreakIntn(tieCount) == 0 {
+ selected = acc
+ }
}
}
return selected
}
+func (s *OpenAIGatewayService) accountTieBreakIntn(n int) int {
+ if n <= 1 {
+ return 0
+ }
+ if s != nil && s.accountTieBreakIntnFn != nil {
+ v := s.accountTieBreakIntnFn(n)
+ if v >= 0 && v < n {
+ return v
+ }
+ return 0
+ }
+ return rand.Intn(n)
+}
+
// isBetterAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
//
@@ -1381,6 +1568,9 @@ func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig
// GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
+ if account == nil {
+ return "", "", errors.New("account is nil")
+ }
switch account.Type {
case AccountTypeOAuth:
// 使用 TokenProvider 获取缓存的 token
@@ -1392,13 +1582,13 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return accessToken, "oauth", nil
}
// 降级:TokenProvider 未配置时直接从账号读取
- accessToken := account.GetOpenAIAccessToken()
+ accessToken := strings.TrimSpace(account.GetOpenAIAccessToken())
if accessToken == "" {
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeAPIKey:
- apiKey := account.GetOpenAIApiKey()
+ apiKey := strings.TrimSpace(account.GetOpenAIApiKey())
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
}
@@ -1417,8 +1607,19 @@ func (s *OpenAIGatewayService) shouldFailoverUpstreamError(statusCode int) bool
}
}
-func (s *OpenAIGatewayService) handleFailoverSideEffects(ctx context.Context, resp *http.Response, account *Account) {
- body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// handleFailoverSideEffects feeds an upstream failure response into the
// rate-limit service so it can apply cooldowns/side effects.
// respBody, when non-empty, is reused so the (possibly already consumed)
// response body is not read twice; otherwise the body is drained here,
// capped at 2 MiB. All-nil inputs are tolerated as no-ops.
func (s *OpenAIGatewayService) handleFailoverSideEffects(
	ctx context.Context,
	resp *http.Response,
	account *Account,
	respBody []byte,
) {
	if s == nil || s.rateLimitService == nil || resp == nil || account == nil {
		return
	}
	body := respBody
	if len(body) == 0 && resp.Body != nil {
		// Best-effort read; a partial/failed read still forwards what we have.
		body, _ = io.ReadAll(io.LimitReader(resp.Body, 2<<20))
	}
	s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
@@ -1440,12 +1641,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
originalBody := body
- reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
+ reqModel, reqStream, promptCacheKey := extractOpenAIRequestMeta(c, body)
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
clientTransport := GetOpenAIClientTransport(c)
+ if shouldWarnOpenAIWSUnknownTransportFallback(wsDecision, clientTransport) {
+ logOpenAIWSModeInfo(
+ "client_transport_unknown_fallback_http account_id=%d account_type=%s resolved_transport=%s resolved_reason=%s",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ normalizeOpenAIWSLogValue(wsDecision.Reason),
+ )
+ }
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport)
if c != nil {
@@ -1465,15 +1675,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。
if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
- if c != nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": gin.H{
- "type": "invalid_request_error",
- "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
- },
- })
- }
- return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2")
+ return nil, s.writeOpenAIWSV1UnsupportedResponse(c, account)
}
passthroughEnabled := account.IsOpenAIPassthroughEnabled()
if passthroughEnabled {
@@ -1762,6 +1964,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
retryStartedAt := time.Now()
wsRetryLoop:
for attempt := 1; attempt <= maxAttempts; attempt++ {
+ if cancelErr := openAIWSRetryContextError(ctx); cancelErr != nil {
+ wsErr = cancelErr
+ break
+ }
wsAttempts = attempt
wsResult, wsErr = s.forwardOpenAIWSV2(
ctx,
@@ -1942,8 +2148,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
Message: upstreamMsg,
Detail: upstreamDetail,
})
-
- s.handleFailoverSideEffects(ctx, resp, account)
+ s.handleFailoverSideEffects(ctx, resp, account, respBody)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
}
return s.handleErrorResponse(ctx, resp, c, account, body)
@@ -2036,7 +2241,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
logger.LegacyPrintf("service.openai_gateway",
- "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
+ "[DEBUG] [OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
account.Name,
account.Type,
@@ -2202,6 +2407,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if err != nil {
return nil, err
}
+ req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI))
// 透传客户端请求头(安全白名单)。
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
@@ -2582,6 +2788,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if err != nil {
return nil, err
}
+ req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI))
// Set authentication header
req.Header.Set("authorization", "Bearer "+token)
@@ -2934,13 +3141,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
- dataBytes := []byte(data)
-
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
- if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
- dataBytes = correctedData
- data = string(correctedData)
- line = "data: " + data
+ // 仅在 toolCorrector 存在时才转换为 []byte,避免热路径无谓分配
+ if s.toolCorrector != nil {
+ dataBytes := []byte(data)
+ if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
+ data = string(correctedData)
+ line = "data: " + data
+ }
}
// 写入客户端(客户端断开后继续 drain 上游)
@@ -2969,7 +3177,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
- s.parseSSEUsageBytes(dataBytes, usage)
+ // 使用 string 版本解析 usage,避免 string→[]byte 转换
+ s.parseSSEUsageString(data, usage)
return
}
@@ -3143,24 +3352,50 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
}
// parseSSEUsage extracts usage counters from a single SSE data payload.
// Thin wrapper over the string-based parser to avoid a []byte conversion.
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
	s.parseSSEUsageString(data, usage)
}
+
+// parseSSEUsageString 使用 gjson.Get(string 版本)解析 usage,避免 string→[]byte 转换
+func (s *OpenAIGatewayService) parseSSEUsageString(data string, usage *OpenAIUsage) {
+ if usage == nil || len(data) == 0 || data == "[DONE]" {
+ return
+ }
+ if len(data) < 80 || !strings.Contains(data, `"response.completed"`) {
+ return
+ }
+ if gjson.Get(data, "type").String() != "response.completed" {
+ return
+ }
+ usageFields := gjson.GetMany(data,
+ "response.usage.input_tokens",
+ "response.usage.output_tokens",
+ "response.usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(usageFields[0].Int())
+ usage.OutputTokens = int(usageFields[1].Int())
+ usage.CacheReadInputTokens = int(usageFields[2].Int())
}
func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) {
- if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
+ if usage == nil || len(data) == 0 || bytes.Equal(data, sseDataDone) {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
- if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
+ if len(data) < 80 || !bytes.Contains(data, sseResponseCompletedMark) {
return
}
if gjson.GetBytes(data, "type").String() != "response.completed" {
return
}
-
- usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
- usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
- usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
+ // 使用 GetManyBytes 一次提取 3 个 usage 字段
+ usageFields := gjson.GetManyBytes(data,
+ "response.usage.input_tokens",
+ "response.usage.output_tokens",
+ "response.usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(usageFields[0].Int())
+ usage.OutputTokens = int(usageFields[1].Int())
+ usage.CacheReadInputTokens = int(usageFields[2].Int())
}
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
@@ -3318,6 +3553,13 @@ func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel st
}
func (s *OpenAIGatewayService) validateUpstreamBaseURL(raw string) (string, error) {
+ if s == nil || s.cfg == nil {
+ normalized, err := urlvalidator.ValidateHTTPSURL(raw, urlvalidator.ValidationOptions{})
+ if err != nil {
+ return "", fmt.Errorf("invalid base_url: %w", err)
+ }
+ return normalized, nil
+ }
if s.cfg != nil && !s.cfg.Security.URLAllowlist.Enabled {
normalized, err := urlvalidator.ValidateURLFormat(raw, s.cfg.Security.URLAllowlist.AllowInsecureHTTP)
if err != nil {
@@ -3365,14 +3607,224 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
	Result            *OpenAIForwardResult
	APIKey            *APIKey
	User              *User
	Account           *Account
	Subscription      *UserSubscription
	UserAgent         string // User-Agent of the incoming request
	IPAddress         string // client IP address of the incoming request
	FallbackRequestID string // seed for a stable, idempotent request_id when the upstream request_id is missing
	APIKeyService     APIKeyQuotaUpdater
}
+
+func (s *OpenAIGatewayService) usageBillingEntryStore() UsageBillingEntryStore {
+ store, ok := s.usageLogRepo.(UsageBillingEntryStore)
+ if !ok {
+ return nil
+ }
+ return store
+}
+
+func (s *OpenAIGatewayService) usageBillingTxRunner() UsageBillingTxRunner {
+ runner, ok := s.usageLogRepo.(UsageBillingTxRunner)
+ if !ok {
+ return nil
+ }
+ return runner
+}
+
+func (s *OpenAIGatewayService) runUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ runner := s.usageBillingTxRunner()
+ if runner == nil {
+ return fn(ctx)
+ }
+ return runner.WithUsageBillingTx(ctx, fn)
+}
+
// prepareUsageBillingEntry resolves billing-entry bookkeeping for a usage log
// and decides whether inline billing should proceed.
//
// Returns (entry, shouldBill, error):
//   - deltaUSD <= 0: nothing billable, (nil, false, nil).
//   - repo lacks UsageBillingEntryStore support: legacy behavior — bill iff
//     the usage log was freshly inserted.
//   - duplicate usage log (inserted == false): bill only when a pending
//     (unapplied) entry exists; lookup failures skip billing to avoid any
//     chance of a double charge.
//   - fresh usage log: upsert a pending entry; an upsert failure degrades to
//     inline billing without an entry.
//
// NOTE(review): the error return is currently always nil — failures are
// logged and degraded rather than propagated; confirm callers rely on that.
func (s *OpenAIGatewayService) prepareUsageBillingEntry(
	ctx context.Context,
	usageLog *UsageLog,
	inserted bool,
	billingType int8,
	deltaUSD float64,
) (*UsageBillingEntry, bool, error) {
	if deltaUSD <= 0 {
		return nil, false, nil
	}

	store := s.usageBillingEntryStore()
	if store == nil {
		// Legacy path: no entry bookkeeping available.
		if inserted {
			return nil, true, nil
		}
		return nil, false, nil
	}

	if !inserted {
		// Duplicate usage log: a pending entry is the only proof the earlier
		// attempt never applied its charge.
		entry, err := store.GetUsageBillingEntryByUsageLogID(ctx, usageLog.ID)
		if err != nil {
			if errors.Is(err, ErrUsageBillingEntryNotFound) {
				logger.LegacyPrintf(
					"service.openai_gateway",
					"[BillingReconcile] missing billing entry for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s",
					usageLog.ID,
					usageLog.RequestID,
				)
				return nil, false, nil
			}
			logger.LegacyPrintf(
				"service.openai_gateway",
				"[BillingReconcile] load billing entry failed for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s err=%v",
				usageLog.ID,
				usageLog.RequestID,
				err,
			)
			return nil, false, nil
		}
		return entry, !entry.Applied, nil
	}

	entry, _, err := store.UpsertUsageBillingEntry(ctx, &UsageBillingEntry{
		UsageLogID:     usageLog.ID,
		UserID:         usageLog.UserID,
		APIKeyID:       usageLog.APIKeyID,
		SubscriptionID: usageLog.SubscriptionID,
		BillingType:    billingType,
		DeltaUSD:       deltaUSD,
		Status:         UsageBillingEntryStatusPending,
	})
	if err != nil {
		// Bookkeeping failed, but the usage log is fresh: bill inline anyway.
		logger.LegacyPrintf(
			"service.openai_gateway",
			"[BillingReconcile] upsert billing entry failed, fallback to inline billing: usage_log=%d request_id=%s err=%v",
			usageLog.ID,
			usageLog.RequestID,
			err,
		)
		return nil, true, nil
	}

	return entry, !entry.Applied, nil
}
+
+func (s *OpenAIGatewayService) markUsageBillingRetry(ctx context.Context, entry *UsageBillingEntry, cause error) {
+ if entry == nil || cause == nil {
+ return
+ }
+ store := s.usageBillingEntryStore()
+ if store == nil {
+ return
+ }
+ errMsg := strings.TrimSpace(cause.Error())
+ if len(errMsg) > 500 {
+ errMsg = errMsg[:500]
+ }
+ nextRetryAt := time.Now().Add(usageBillingRetryBackoff(entry.AttemptCount + 1))
+ if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil {
+ logger.LegacyPrintf("service.openai_gateway", "[BillingReconcile] mark retry failed: entry=%d err=%v", entry.ID, err)
+ }
+}
+
+func resolveOpenAIUsageRequestID(input *OpenAIRecordUsageInput) string {
+ if input == nil || input.Result == nil {
+ return ""
+ }
+ if requestID := strings.TrimSpace(input.Result.RequestID); requestID != "" {
+ return requestID
+ }
+ return buildOpenAIUsageFallbackRequestID(input)
+}
+
+func buildOpenAIUsageFallbackRequestID(input *OpenAIRecordUsageInput) string {
+ if input == nil || input.Result == nil {
+ return ""
+ }
+ result := input.Result
+ usage := result.Usage
+
+ seed := strings.Builder{}
+ seed.Grow(192)
+ _, _ = seed.WriteString(strings.TrimSpace(input.FallbackRequestID))
+ _ = seed.WriteByte('|')
+ if input.APIKey != nil {
+ _, _ = seed.WriteString(strconv.FormatInt(input.APIKey.ID, 10))
+ }
+ _ = seed.WriteByte('|')
+ if input.Account != nil {
+ _, _ = seed.WriteString(strconv.FormatInt(input.Account.ID, 10))
+ }
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strings.TrimSpace(result.Model))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strings.TrimSpace(result.TerminalEventType))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.FormatBool(result.Stream))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.FormatBool(result.OpenAIWSMode))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(normalizeOpenAIWSIngressMode(result.WSIngressMode))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.Itoa(usage.InputTokens))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.Itoa(usage.OutputTokens))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.Itoa(usage.CacheCreationInputTokens))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.Itoa(usage.CacheReadInputTokens))
+ _ = seed.WriteByte('|')
+ _, _ = seed.WriteString(strconv.FormatInt(result.Duration.Milliseconds(), 10))
+ _ = seed.WriteByte('|')
+ firstTokenMs := -1
+ if result.FirstTokenMs != nil {
+ firstTokenMs = *result.FirstTokenMs
+ }
+ _, _ = seed.WriteString(strconv.Itoa(firstTokenMs))
+ _ = seed.WriteByte('|')
+ if result.ReasoningEffort != nil {
+ _, _ = seed.WriteString(strings.TrimSpace(*result.ReasoningEffort))
+ }
+
+ sum := sha256.Sum256([]byte(seed.String()))
+ return "wsf_" + hex.EncodeToString(sum[:16])
+}
+
// logOpenAIWSUsageRecorded emits a structured ws-mode log line for a recorded
// usage entry. It is a no-op for non-WS results or missing inputs.
func logOpenAIWSUsageRecorded(
	result *OpenAIForwardResult,
	usageLog *UsageLog,
	inserted bool,
	shouldBill bool,
	billedAmount float64,
) {
	if result == nil || usageLog == nil || !result.OpenAIWSMode {
		return
	}
	wsIngressMode := normalizeOpenAIWSIngressMode(result.WSIngressMode)
	if wsIngressMode == "" {
		wsIngressMode = "-"
	}

	// NOTE(review): wsIngressMode is normalized twice (here and via
	// normalizeOpenAIWSLogValue below) — presumably the second pass guards
	// log formatting; confirm whether it is redundant.
	logOpenAIWSModeInfo(
		"ingress_ws_usage_recorded account_id=%d api_key_id=%d user_id=%d request_id=%s ws_ingress_mode=%s request_type=%s requests=1 model=%s input_tokens=%d output_tokens=%d cache_creation_tokens=%d cache_read_tokens=%d total_tokens=%d total_cost=%.6f actual_cost=%.6f billed_amount=%.6f should_bill=%v inserted=%v terminal_event=%s",
		usageLog.AccountID,
		usageLog.APIKeyID,
		usageLog.UserID,
		truncateOpenAIWSLogValue(usageLog.RequestID, openAIWSIDValueMaxLen),
		normalizeOpenAIWSLogValue(wsIngressMode),
		normalizeOpenAIWSLogValue(usageLog.EffectiveRequestType().String()),
		truncateOpenAIWSLogValue(usageLog.Model, openAIWSLogValueMaxLen),
		usageLog.InputTokens,
		usageLog.OutputTokens,
		usageLog.CacheCreationTokens,
		usageLog.CacheReadTokens,
		usageLog.TotalTokens(),
		usageLog.TotalCost,
		usageLog.ActualCost,
		billedAmount,
		shouldBill,
		inserted,
		truncateOpenAIWSLogValue(result.TerminalEventType, openAIWSLogValueMaxLen),
	)
}
// RecordUsage records usage and deducts balance
@@ -3406,7 +3858,25 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
if err != nil {
- cost = &CostBreakdown{ActualCost: 0}
+ if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[PricingWarn] calculate cost failed in simple mode, fallback to zero cost: model=%s request_id=%s err=%v",
+ result.Model,
+ result.RequestID,
+ err,
+ )
+ cost = &CostBreakdown{}
+ } else {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[PricingAlert] calculate cost failed, reject usage record: model=%s request_id=%s err=%v",
+ result.Model,
+ result.RequestID,
+ err,
+ )
+ return fmt.Errorf("calculate cost: %w", err)
+ }
}
// Determine billing type
@@ -3419,11 +3889,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
+ requestID := resolveOpenAIUsageRequestID(input)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
- RequestID: result.RequestID,
+ RequestID: requestID,
Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InputTokens: actualInputTokens,
@@ -3464,24 +3935,68 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
+ if err != nil {
+ return fmt.Errorf("create usage log: %w", err)
+ }
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
+ logOpenAIWSUsageRecorded(result, usageLog, inserted, false, 0)
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
- shouldBill := inserted || err != nil
-
- // Deduct based on billing type
+ billAmount := cost.ActualCost
if isSubscriptionBilling {
- if shouldBill && cost.TotalCost > 0 {
- _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
- s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
+ billAmount = cost.TotalCost
+ }
+ billingEntry, shouldBill, err := s.prepareUsageBillingEntry(ctx, usageLog, inserted, billingType, billAmount)
+ if err != nil {
+ return fmt.Errorf("prepare usage billing entry: %w", err)
+ }
+
+ if shouldBill {
+ cacheDeducted := false
+ if !isSubscriptionBilling && billAmount > 0 && s.billingCacheService != nil {
+ // 同步扣减缓存,避免并发场景下仅靠“先查后扣”产生透支窗口。
+ if err := s.billingCacheService.DeductBalanceCache(ctx, user.ID, billAmount); err != nil {
+ s.markUsageBillingRetry(ctx, billingEntry, err)
+ return fmt.Errorf("deduct balance cache: %w", err)
+ }
+ cacheDeducted = true
}
- } else {
- if shouldBill && cost.ActualCost > 0 {
- _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
- s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
+
+ applyErr := s.runUsageBillingTx(ctx, func(txCtx context.Context) error {
+ if isSubscriptionBilling {
+ if err := s.userSubRepo.IncrementUsage(txCtx, subscription.ID, cost.TotalCost); err != nil {
+ return fmt.Errorf("increment subscription usage: %w", err)
+ }
+ } else if billAmount > 0 {
+ if err := s.userRepo.DeductBalance(txCtx, user.ID, billAmount); err != nil {
+ return fmt.Errorf("deduct balance: %w", err)
+ }
+ }
+ if billingEntry == nil {
+ return nil
+ }
+ store := s.usageBillingEntryStore()
+ if store == nil {
+ return nil
+ }
+ if err := store.MarkUsageBillingEntryApplied(txCtx, billingEntry.ID); err != nil {
+ return fmt.Errorf("mark usage billing entry applied: %w", err)
+ }
+ return nil
+ })
+ if applyErr != nil {
+ if !isSubscriptionBilling && cacheDeducted && s.billingCacheService != nil {
+ _ = s.billingCacheService.InvalidateUserBalance(context.Background(), user.ID)
+ }
+ s.markUsageBillingRetry(ctx, billingEntry, applyErr)
+ return applyErr
+ }
+
+ if isSubscriptionBilling && s.billingCacheService != nil && apiKey.GroupID != nil && cost.TotalCost > 0 {
+ s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
}
@@ -3491,6 +4006,11 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
logger.LegacyPrintf("service.openai_gateway", "Update API key quota failed: %v", err)
}
}
+ billedAmount := 0.0
+ if shouldBill {
+ billedAmount = billAmount
+ }
+ logOpenAIWSUsageRecorded(result, usageLog, inserted, shouldBill, billedAmount)
// Update API Key rate limit usage
if shouldBill && cost.ActualCost > 0 && apiKey.HasRateLimits() && input.APIKeyService != nil {
@@ -3676,15 +4196,48 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if len(updates) == 0 {
return
}
+ if !s.tryAcquireCodexUsageUpdateSlot() {
+ slog.Warn("openai_gateway.codex_usage_update_dropped",
+ "account_id", accountID,
+ "reason", "concurrency_limit_reached",
+ )
+ return
+ }
// Update account's Extra field asynchronously
go func() {
+ defer s.releaseCodexUsageUpdateSlot()
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
+func (s *OpenAIGatewayService) tryAcquireCodexUsageUpdateSlot() bool {
+ if s == nil {
+ return false
+ }
+ s.codexUsageUpdateOnce.Do(func() {
+ s.codexUsageUpdateSem = make(chan struct{}, openAICodexUsageUpdateConcurrency)
+ })
+ select {
+ case s.codexUsageUpdateSem <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (s *OpenAIGatewayService) releaseCodexUsageUpdateSlot() {
+ if s == nil || s.codexUsageUpdateSem == nil {
+ return
+ }
+ select {
+ case <-s.codexUsageUpdateSem:
+ default:
+ }
+}
+
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
if reqBody == nil {
return "", false
@@ -3731,6 +4284,27 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
// OpenAIRequestMeta caches request metadata that the handler layer extracted
// once, so downstream code can avoid re-parsing the request body.
type OpenAIRequestMeta struct {
	Model          string // model name from the request body
	Stream         bool   // whether the request asked for a streaming response
	PromptCacheKey string // prompt_cache_key from the request body, if present
}
+
// extractOpenAIRequestMeta returns (model, stream, promptCacheKey) for a
// request, preferring the meta already cached on the gin context (read-only)
// and falling back to parsing the raw body. The handler layer has extracted
// every field (including prompt_cache_key) before caching, so this function
// never mutates the cached meta — avoiding concurrent-write races.
func extractOpenAIRequestMeta(c *gin.Context, body []byte) (model string, stream bool, promptCacheKey string) {
	if c != nil {
		if cached, ok := c.Get(OpenAIRequestMetaKey); ok {
			if meta, ok := cached.(*OpenAIRequestMeta); ok && meta != nil {
				return meta.Model, meta.Stream, meta.PromptCacheKey
			}
		}
	}
	// Fall back to parsing the raw body (WebSocket and other entry points).
	return extractOpenAIRequestMetaFromBody(body)
}
+
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
if len(body) == 0 {
return "", false, ""
diff --git a/backend/internal/service/openai_gateway_service_access_token_test.go b/backend/internal/service/openai_gateway_service_access_token_test.go
new file mode 100644
index 000000000..32b5c8e4e
--- /dev/null
+++ b/backend/internal/service/openai_gateway_service_access_token_test.go
@@ -0,0 +1,90 @@
+package service
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIGatewayServiceGetAccessToken(t *testing.T) {
+ t.Parallel()
+
+ svc := &OpenAIGatewayService{}
+
+ t.Run("nil account", func(t *testing.T) {
+ token, tokenType, err := svc.GetAccessToken(context.Background(), nil)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "account is nil")
+ require.Empty(t, token)
+ require.Empty(t, tokenType)
+ })
+
+ t.Run("oauth account", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "oauth-token",
+ },
+ }
+ token, tokenType, err := svc.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "oauth-token", token)
+ require.Equal(t, "oauth", tokenType)
+ })
+
+ t.Run("oauth account trims token", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": " oauth-token-trim ",
+ },
+ }
+ token, tokenType, err := svc.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "oauth-token-trim", token)
+ require.Equal(t, "oauth", tokenType)
+ })
+
+ t.Run("api key account", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": "sk-live-token",
+ },
+ }
+ token, tokenType, err := svc.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "sk-live-token", token)
+ require.Equal(t, "apikey", tokenType)
+ })
+
+ t.Run("api key account trims token", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "api_key": " sk-live-trim ",
+ },
+ }
+ token, tokenType, err := svc.GetAccessToken(context.Background(), account)
+ require.NoError(t, err)
+ require.Equal(t, "sk-live-trim", token)
+ require.Equal(t, "apikey", tokenType)
+ })
+
+ t.Run("unsupported account type", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: "unknown",
+ }
+ token, tokenType, err := svc.GetAccessToken(context.Background(), account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "unsupported account type")
+ require.Empty(t, token)
+ require.Empty(t, tokenType)
+ })
+}
diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go
index f73c06c5e..a82e6e2a3 100644
--- a/backend/internal/service/openai_gateway_service_hotpath_test.go
+++ b/backend/internal/service/openai_gateway_service_hotpath_test.go
@@ -139,3 +139,126 @@ func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
require.True(t, ok)
require.Equal(t, got, cachedMap)
}
+
+// --- extractOpenAIRequestMeta context 缓存测试 ---
+
// TestExtractOpenAIRequestMeta_CacheHit verifies that a meta cached on the
// gin context wins over whatever the raw body contains.
func TestExtractOpenAIRequestMeta_CacheHit(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	// Pre-seed the cache (the handler layer extracted all fields, including PromptCacheKey).
	c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{
		Model:          "gpt-5",
		Stream:         true,
		PromptCacheKey: "key-1",
	})

	body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"key-other"}`)
	model, stream, promptKey := extractOpenAIRequestMeta(c, body)

	// Cached values must be returned, not the ones parsed from the body.
	require.Equal(t, "gpt-5", model)
	require.True(t, stream)
	require.Equal(t, "key-1", promptKey)
}
+
// TestExtractOpenAIRequestMeta_CacheHit_PromptCacheKeyFromHandler verifies the
// handler-extracted PromptCacheKey is returned (read-only) on every call.
func TestExtractOpenAIRequestMeta_CacheHit_PromptCacheKeyFromHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	// The handler layer already extracted PromptCacheKey; once set, the meta is read-only.
	meta := &OpenAIRequestMeta{Model: "gpt-5", Stream: false, PromptCacheKey: "pk-abc"}
	c.Set(OpenAIRequestMetaKey, meta)

	body := []byte(`{"model":"gpt-4","prompt_cache_key":"pk-other"}`)

	// The cached (handler-extracted) value must win over the body value.
	_, _, promptKey1 := extractOpenAIRequestMeta(c, body)
	require.Equal(t, "pk-abc", promptKey1)

	// Repeated calls must return the same result.
	_, _, promptKey2 := extractOpenAIRequestMeta(c, body)
	require.Equal(t, "pk-abc", promptKey2)
}
+
// TestExtractOpenAIRequestMeta_CacheHit_EmptyBody verifies a cache hit with a
// nil body neither panics nor invents a prompt cache key.
func TestExtractOpenAIRequestMeta_CacheHit_EmptyBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{
		Model:  "gpt-5",
		Stream: true,
	})

	// An empty body must not panic; prompt_cache_key stays empty.
	model, stream, promptKey := extractOpenAIRequestMeta(c, nil)
	require.Equal(t, "gpt-5", model)
	require.True(t, stream)
	require.Equal(t, "", promptKey)
}
+
// TestExtractOpenAIRequestMeta_FallbackToBody verifies body parsing is used
// when no meta has been cached on the context.
func TestExtractOpenAIRequestMeta_FallbackToBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	// No cache set: extraction must fall back to parsing the body.
	body := []byte(`{"model":"gpt-4o","stream":true,"prompt_cache_key":"ses-2"}`)
	model, stream, promptKey := extractOpenAIRequestMeta(c, body)

	require.Equal(t, "gpt-4o", model)
	require.True(t, stream)
	require.Equal(t, "ses-2", promptKey)
}
+
+func TestExtractOpenAIRequestMeta_NilContext(t *testing.T) {
+ body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"k"}`)
+ model, stream, promptKey := extractOpenAIRequestMeta(nil, body)
+
+ require.Equal(t, "gpt-4", model)
+ require.False(t, stream)
+ require.Equal(t, "k", promptKey)
+}
+
// TestExtractOpenAIRequestMeta_InvalidCacheType verifies a wrongly-typed cache
// entry is ignored and parsing falls back to the body.
func TestExtractOpenAIRequestMeta_InvalidCacheType(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)

	// Wrong cached type: extraction must fall back to parsing the body.
	c.Set(OpenAIRequestMetaKey, "invalid-type")

	body := []byte(`{"model":"gpt-4o","stream":true}`)
	model, stream, _ := extractOpenAIRequestMeta(c, body)

	require.Equal(t, "gpt-4o", model)
	require.True(t, stream)
}
+
+func TestExtractOpenAIRequestMeta_NilCacheValue(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ c.Set(OpenAIRequestMetaKey, (*OpenAIRequestMeta)(nil))
+
+ body := []byte(`{"model":"gpt-5","stream":false}`)
+ model, stream, _ := extractOpenAIRequestMeta(c, body)
+
+ require.Equal(t, "gpt-5", model)
+ require.False(t, stream)
+}
+
+func TestOpenAIRequestMeta_Fields(t *testing.T) {
+ meta := &OpenAIRequestMeta{
+ Model: "gpt-5",
+ Stream: true,
+ PromptCacheKey: "pk",
+ }
+ require.Equal(t, "gpt-5", meta.Model)
+ require.True(t, meta.Stream)
+ require.Equal(t, "pk", meta.PromptCacheKey)
+}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 4f5f7f3c1..335495e9b 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -10,6 +10,8 @@ import (
"net/http"
"net/http/httptest"
"strings"
+ "sync"
+ "sync/atomic"
"testing"
"time"
@@ -61,6 +63,54 @@ func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatform(ctx context.Co
return r.ListSchedulableByPlatform(ctx, platform)
}
+func (r stubOpenAIAccountRepo) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
+ if len(platforms) == 0 {
+ return nil, nil
+ }
+ platformSet := make(map[string]struct{}, len(platforms))
+ for _, p := range platforms {
+ platformSet[p] = struct{}{}
+ }
+ result := make([]Account, 0, len(r.accounts))
+ for _, acc := range r.accounts {
+ if _, ok := platformSet[acc.Platform]; ok {
+ result = append(result, acc)
+ }
+ }
+ return result, nil
+}
+
// ListSchedulableUngroupedByPlatforms mirrors ListSchedulableByPlatforms for
// this stub; grouping is not modelled, so both return the same accounts.
func (r stubOpenAIAccountRepo) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
	return r.ListSchedulableByPlatforms(ctx, platforms)
}
+
// codexUsageUpdateAccountRepoStub records UpdateExtra invocations and lets a
// test hold the first update in-flight so concurrency limiting can be observed.
type codexUsageUpdateAccountRepoStub struct {
	stubOpenAIAccountRepo

	calls    atomic.Int32  // number of UpdateExtra invocations
	entered  chan struct{} // closed when the first UpdateExtra call starts (if non-nil)
	release  chan struct{} // UpdateExtra blocks until this is closed (if non-nil)
	enterSig sync.Once     // ensures entered is closed at most once
}
+
// UpdateExtra counts the call, signals (once) that execution has entered the
// stub, then blocks until release is closed or the context is cancelled.
// A nil release channel makes the call return immediately.
func (r *codexUsageUpdateAccountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
	r.calls.Add(1)
	// Close entered only on the first call so waiters are woken exactly once.
	r.enterSig.Do(func() {
		if r.entered != nil {
			close(r.entered)
		}
	})
	if r.release == nil {
		return nil
	}
	select {
	case <-r.release:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
+
type stubConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
@@ -355,6 +405,37 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
}
}
// TestOpenAIGatewayService_UpdateCodexUsageSnapshot_BoundedAsyncUpdates checks
// that updateCodexUsageSnapshot refuses to spawn a new async writer while the
// single semaphore slot is held, preventing unbounded goroutine growth.
func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_BoundedAsyncUpdates(t *testing.T) {
	repo := &codexUsageUpdateAccountRepoStub{
		entered: make(chan struct{}),
		release: make(chan struct{}),
	}
	svc := &OpenAIGatewayService{
		accountRepo:         repo,
		codexUsageUpdateSem: make(chan struct{}, 1),
	}
	// Burn the sync.Once so the service keeps the 1-slot semaphore above
	// instead of lazily replacing it with the default-sized one.
	svc.codexUsageUpdateOnce.Do(func() {})

	snapshot := &OpenAICodexUsageSnapshot{
		UpdatedAt: time.Now().Format(time.RFC3339),
	}

	svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot)

	select {
	case <-repo.entered:
	case <-time.After(time.Second):
		t.Fatal("first codex usage snapshot update did not start")
	}

	// first async update is still holding the single slot
	svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot)
	time.Sleep(80 * time.Millisecond)
	require.Equal(t, int32(1), repo.calls.Load(), "slot full 时应拒绝新异步写入,避免 goroutine 无界增长")

	close(repo.release)
}
+
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
sessionHash := "session-1"
repo := stubOpenAIAccountRepo{
@@ -832,6 +913,33 @@ func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.
}
}
+func TestOpenAISelectAccountForModelWithExclusions_EqualPriorityTieBreakRandomized(t *testing.T) {
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+
+ calls := 0
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ accountTieBreakIntnFn: func(n int) int {
+ calls++
+ require.Equal(t, 2, n)
+ return 0
+ },
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID, "tie-break should allow later equal-priority candidate to be selected")
+ require.Equal(t, 1, calls, "tie-break should be invoked exactly once for two equal candidates")
+}
+
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
groupID := int64(1)
lastUsed := time.Now().Add(-1 * time.Hour)
@@ -1189,6 +1297,53 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
}
}
// TestOpenAIBuildUpstreamRequestSetsHTTPUpstreamProfile asserts that requests
// built via buildUpstreamRequest carry the OpenAI HTTP upstream profile on
// their context.
func TestOpenAIBuildUpstreamRequestSetsHTTPUpstreamProfile(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`))

	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}

	req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{}`), "token", false, "", false)
	require.NoError(t, err)
	require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context()))
}
+
// TestOpenAIBuildUpstreamPassthroughRequestSetsHTTPUpstreamProfile asserts
// that the passthrough request builder also tags the OpenAI HTTP upstream
// profile on the outgoing request's context.
func TestOpenAIBuildUpstreamPassthroughRequestSetsHTTPUpstreamProfile(t *testing.T) {
	gin.SetMode(gin.TestMode)
	cfg := &config.Config{
		Security: config.SecurityConfig{
			URLAllowlist: config.URLAllowlistConfig{Enabled: false},
		},
	}
	svc := &OpenAIGatewayService{cfg: cfg}

	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`))
	c.Request.Header.Set("Content-Type", "application/json")

	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeAPIKey,
	}

	req, err := svc.buildUpstreamRequestOpenAIPassthrough(context.Background(), c, account, []byte(`{}`), "token")
	require.NoError(t, err)
	require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context()))
}
+
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
@@ -1248,6 +1403,21 @@ func TestOpenAIValidateUpstreamBaseURLEnabledEnforcesAllowlist(t *testing.T) {
}
}
+func TestOpenAIValidateUpstreamBaseURLNilConfigDefaultsToStrictHTTPS(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ if _, err := svc.validateUpstreamBaseURL("http://example.com"); err == nil {
+ t.Fatalf("expected http to be rejected when config is nil")
+ }
+ normalized, err := svc.validateUpstreamBaseURL("https://example.com")
+ if err != nil {
+ t.Fatalf("expected https to pass when config is nil, got %v", err)
+ }
+ if normalized != "https://example.com" {
+ t.Fatalf("expected normalized https url, got %q", normalized)
+ }
+}
+
// ==================== P1-08 修复:model 替换性能优化测试 ====================
func TestReplaceModelInSSELine(t *testing.T) {
diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go
index 72f4bbb09..1e085923e 100644
--- a/backend/internal/service/openai_oauth_service.go
+++ b/backend/internal/service/openai_oauth_service.go
@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
+ "net/url"
"regexp"
"sort"
"strconv"
@@ -14,7 +15,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
- "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -273,13 +273,7 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- client, err := httpclient.GetClient(httpclient.Options{
- ProxyURL: proxyURL,
- Timeout: 120 * time.Second,
- })
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
- }
+ client := newOpenAIOAuthHTTPClient(proxyURL)
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
@@ -471,7 +465,7 @@ func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *A
}
var proxyURL string
- if account.ProxyID != nil {
+ if account.ProxyID != nil && s.proxyRepo != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
@@ -536,6 +530,19 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
return proxy.URL(), nil
}
+func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
+ transport := &http.Transport{}
+ if strings.TrimSpace(proxyURL) != "" {
+ if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
+ transport.Proxy = http.ProxyURL(parsed)
+ }
+ }
+ return &http.Client{
+ Timeout: 120 * time.Second,
+ Transport: transport,
+ }
+}
+
func normalizeOpenAIOAuthPlatform(platform string) string {
switch strings.ToLower(strings.TrimSpace(platform)) {
case PlatformSora:
diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go
index 292523288..c03352c1f 100644
--- a/backend/internal/service/openai_oauth_service_state_test.go
+++ b/backend/internal/service/openai_oauth_service_state_test.go
@@ -34,6 +34,29 @@ func (s *openaiOAuthClientStateStub) RefreshTokenWithClientID(ctx context.Contex
return s.RefreshToken(ctx, refreshToken, proxyURL)
}
// openaiOAuthClientRefreshStub is a test double that succeeds on token
// refresh and records how it was invoked.
type openaiOAuthClientRefreshStub struct {
	refreshCalled int32  // incremented atomically per refresh call
	lastClientID  string // client ID from the most recent refresh (written without synchronization)
}
+
// ExchangeCode is unsupported by this stub and always returns an error.
func (s *openaiOAuthClientRefreshStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
	return nil, errors.New("not implemented")
}
+
// RefreshToken delegates to RefreshTokenWithClientID with an empty client ID.
func (s *openaiOAuthClientRefreshStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
	return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
}
+
// RefreshTokenWithClientID records the call and returns a fixed successful
// token response. lastClientID is written without synchronization; tests must
// only read it after the refresh call has returned.
func (s *openaiOAuthClientRefreshStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
	atomic.AddInt32(&s.refreshCalled, 1)
	s.lastClientID = clientID
	return &openai.TokenResponse{
		AccessToken:  "new-at",
		RefreshToken: "new-rt",
		ExpiresIn:    3600,
	}, nil
}
+
func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
@@ -55,6 +78,30 @@ func TestOpenAIOAuthService_ExchangeCode_StateRequired(t *testing.T) {
require.Equal(t, int32(0), atomic.LoadInt32(&client.exchangeCalled))
}
// TestOpenAIOAuthService_RefreshAccountToken_NilProxyRepoWithProxyID verifies
// that a token refresh still succeeds (skipping proxy resolution) when the
// account references a proxy but the service has no proxy repository.
func TestOpenAIOAuthService_RefreshAccountToken_NilProxyRepoWithProxyID(t *testing.T) {
	client := &openaiOAuthClientRefreshStub{}
	svc := NewOpenAIOAuthService(nil, client)
	defer svc.Stop()

	proxyID := int64(123)
	account := &Account{
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		ProxyID:  &proxyID,
		Credentials: map[string]any{
			"refresh_token": "rt",
			"client_id":     "cid",
		},
	}

	tokenInfo, err := svc.RefreshAccountToken(context.Background(), account)
	require.NoError(t, err)
	require.NotNil(t, tokenInfo)
	require.Equal(t, "new-at", tokenInfo.AccessToken)
	require.Equal(t, int32(1), atomic.LoadInt32(&client.refreshCalled))
	require.Equal(t, "cid", client.lastClientID)
}
+
func TestOpenAIOAuthService_ExchangeCode_StateMismatch(t *testing.T) {
client := &openaiOAuthClientStateStub{}
svc := NewOpenAIOAuthService(nil, client)
diff --git a/backend/internal/service/openai_sse_zero_alloc_test.go b/backend/internal/service/openai_sse_zero_alloc_test.go
new file mode 100644
index 000000000..fe853d59a
--- /dev/null
+++ b/backend/internal/service/openai_sse_zero_alloc_test.go
@@ -0,0 +1,276 @@
+package service
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// --- 包级常量验证 ---
+
+func TestSSEPackageLevelConstants(t *testing.T) {
+ require.Equal(t, []byte("[DONE]"), sseDataDone)
+ require.Equal(t, []byte(`"response.completed"`), sseResponseCompletedMark)
+}
+
+func TestSSEDataDone_UsedInBytesEqual(t *testing.T) {
+ require.True(t, bytes.Equal([]byte("[DONE]"), sseDataDone))
+ require.False(t, bytes.Equal([]byte("[done]"), sseDataDone))
+ require.False(t, bytes.Equal([]byte(""), sseDataDone))
+}
+
+func TestSSEResponseCompletedMark_UsedInBytesContains(t *testing.T) {
+ data := []byte(`{"type":"response.completed","response":{"usage":{}}}`)
+ require.True(t, bytes.Contains(data, sseResponseCompletedMark))
+
+ unrelated := []byte(`{"type":"response.in_progress"}`)
+ require.False(t, bytes.Contains(unrelated, sseResponseCompletedMark))
+}
+
+// --- parseSSEUsageString 测试 ---
+
// TestParseSSEUsageString_CompletedEvent checks that token counts are
// extracted from a response.completed SSE payload.
func TestParseSSEUsageString_CompletedEvent(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{}

	data := `{"type":"response.completed","response":{"usage":{"input_tokens":100,"output_tokens":50,"input_tokens_details":{"cached_tokens":20}}}}`
	svc.parseSSEUsageString(data, usage)

	require.Equal(t, 100, usage.InputTokens)
	require.Equal(t, 50, usage.OutputTokens)
	require.Equal(t, 20, usage.CacheReadInputTokens)
}
+
// TestParseSSEUsageString_NonCompletedEvent checks that events other than
// response.completed leave the accumulated usage untouched.
func TestParseSSEUsageString_NonCompletedEvent(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{InputTokens: 99, OutputTokens: 88}

	data := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}}}`
	svc.parseSSEUsageString(data, usage)

	// Non-completed events must not modify usage.
	require.Equal(t, 99, usage.InputTokens)
	require.Equal(t, 88, usage.OutputTokens)
}
+
// TestParseSSEUsageString_DoneEvent checks the [DONE] sentinel is ignored.
func TestParseSSEUsageString_DoneEvent(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{InputTokens: 10}

	svc.parseSSEUsageString("[DONE]", usage)
	require.Equal(t, 10, usage.InputTokens) // must remain unchanged
}
+
+func TestParseSSEUsageString_EmptyString(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 5}
+
+ svc.parseSSEUsageString("", usage)
+ require.Equal(t, 5, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageString_NilUsage(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ // 不应 panic
+ require.NotPanics(t, func() {
+ svc.parseSSEUsageString(`{"type":"response.completed"}`, nil)
+ })
+}
+
+func TestParseSSEUsageString_ShortData(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 7}
+
+ // 短于 80 字节的数据直接跳过
+ svc.parseSSEUsageString(`{"type":"response.completed"}`, usage)
+ require.Equal(t, 7, usage.InputTokens) // 不应修改
+}
+
// TestParseSSEUsageString_ContainsCompletedButWrongType ensures a payload that
// merely contains the "response.completed" substring is not parsed as a
// completion event when the type field differs.
func TestParseSSEUsageString_ContainsCompletedButWrongType(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{InputTokens: 42}

	// Contains the "response.completed" substring, but the type field does not match.
	data := `{"type":"response.in_progress","description":"not response.completed at all","padding":"aaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`
	svc.parseSSEUsageString(data, usage)
	require.Equal(t, 42, usage.InputTokens) // must remain unchanged
}
+
+func TestParseSSEUsageString_ZeroUsageValues(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 99}
+
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":0,"output_tokens":0,"input_tokens_details":{"cached_tokens":0}}}}`
+ svc.parseSSEUsageString(data, usage)
+
+ require.Equal(t, 0, usage.InputTokens)
+ require.Equal(t, 0, usage.OutputTokens)
+ require.Equal(t, 0, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsageString_MissingUsageFields(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ // response.usage 存在但缺少某些子字段
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":10},"padding":"aaaaaaaaaaaaaaaaaaa"}}`
+ svc.parseSSEUsageString(data, usage)
+
+ require.Equal(t, 10, usage.InputTokens)
+ require.Equal(t, 0, usage.OutputTokens)
+ require.Equal(t, 0, usage.CacheReadInputTokens)
+}
+
+// --- parseSSEUsageBytes 与包级常量集成测试 ---
+
+func TestParseSSEUsageBytes_CompletedEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":200,"output_tokens":80,"input_tokens_details":{"cached_tokens":30}}}}`)
+ svc.parseSSEUsageBytes(data, usage)
+
+ require.Equal(t, 200, usage.InputTokens)
+ require.Equal(t, 80, usage.OutputTokens)
+ require.Equal(t, 30, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsageBytes_DoneEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 10}
+
+ svc.parseSSEUsageBytes([]byte("[DONE]"), usage)
+ require.Equal(t, 10, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageBytes_EmptyData(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 5}
+
+ svc.parseSSEUsageBytes(nil, usage)
+ require.Equal(t, 5, usage.InputTokens)
+
+ svc.parseSSEUsageBytes([]byte{}, usage)
+ require.Equal(t, 5, usage.InputTokens)
+}
+
+func TestParseSSEUsageBytes_NilUsage(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ require.NotPanics(t, func() {
+ svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), nil)
+ })
+}
+
+func TestParseSSEUsageBytes_ShortData(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 7}
+
+ svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), usage)
+ require.Equal(t, 7, usage.InputTokens)
+}
+
+func TestParseSSEUsageBytes_NonCompletedEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 99}
+
+ data := []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxxx"}`)
+ svc.parseSSEUsageBytes(data, usage)
+
+ require.Equal(t, 99, usage.InputTokens)
+}
+
// TestParseSSEUsageBytes_GetManyBytesExtraction checks that the three usage
// fields are extracted correctly in a single multi-field pass.
func TestParseSSEUsageBytes_GetManyBytesExtraction(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{}

	// Verify that extracting 3 fields at once via GetManyBytes is correct.
	data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":111,"output_tokens":222,"input_tokens_details":{"cached_tokens":333}}}}`)
	svc.parseSSEUsageBytes(data, usage)

	require.Equal(t, 111, usage.InputTokens)
	require.Equal(t, 222, usage.OutputTokens)
	require.Equal(t, 333, usage.CacheReadInputTokens)
}
+
+// --- parseSSEUsage wrapper 测试 ---
+
+func TestParseSSEUsage_DelegatesToString(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ // 验证 parseSSEUsage 最终正确提取 usage
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":55,"output_tokens":66,"input_tokens_details":{"cached_tokens":77}}}}`
+ svc.parseSSEUsage(data, usage)
+
+ require.Equal(t, 55, usage.InputTokens)
+ require.Equal(t, 66, usage.OutputTokens)
+ require.Equal(t, 77, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsage_DoneNotParsed(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 123}
+
+ svc.parseSSEUsage("[DONE]", usage)
+ require.Equal(t, 123, usage.InputTokens)
+}
+
+func TestParseSSEUsage_EmptyNotParsed(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 456}
+
+ svc.parseSSEUsage("", usage)
+ require.Equal(t, 456, usage.InputTokens)
+}
+
+// --- string 和 bytes 一致性测试 ---
+
+func TestParseSSEUsage_StringAndBytesConsistency(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ completedData := `{"type":"response.completed","response":{"usage":{"input_tokens":300,"output_tokens":150,"input_tokens_details":{"cached_tokens":50}}}}`
+
+ usageStr := &OpenAIUsage{}
+ svc.parseSSEUsageString(completedData, usageStr)
+
+ usageBytes := &OpenAIUsage{}
+ svc.parseSSEUsageBytes([]byte(completedData), usageBytes)
+
+ require.Equal(t, usageStr.InputTokens, usageBytes.InputTokens)
+ require.Equal(t, usageStr.OutputTokens, usageBytes.OutputTokens)
+ require.Equal(t, usageStr.CacheReadInputTokens, usageBytes.CacheReadInputTokens)
+}
+
// TestParseSSEUsage_StringAndBytesConsistency_NonCompleted checks the string
// and byte parsers agree in leaving usage untouched for non-completed events.
func TestParseSSEUsage_StringAndBytesConsistency_NonCompleted(t *testing.T) {
	svc := &OpenAIGatewayService{}

	inProgressData := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxx"}`

	usageStr := &OpenAIUsage{InputTokens: 10}
	svc.parseSSEUsageString(inProgressData, usageStr)

	usageBytes := &OpenAIUsage{InputTokens: 10}
	svc.parseSSEUsageBytes([]byte(inProgressData), usageBytes)

	// Neither parser should modify the usage.
	require.Equal(t, 10, usageStr.InputTokens)
	require.Equal(t, 10, usageBytes.InputTokens)
}
+
+func TestParseSSEUsage_StringAndBytesConsistency_LargeTokenCounts(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":1000000,"output_tokens":500000,"input_tokens_details":{"cached_tokens":200000}}}}`
+
+ usageStr := &OpenAIUsage{}
+ svc.parseSSEUsageString(data, usageStr)
+
+ usageBytes := &OpenAIUsage{}
+ svc.parseSSEUsageBytes([]byte(data), usageBytes)
+
+ require.Equal(t, 1000000, usageStr.InputTokens)
+ require.Equal(t, usageStr, usageBytes)
+}
diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go
index e897debc2..0f576b664 100644
--- a/backend/internal/service/openai_sticky_compat.go
+++ b/backend/internal/service/openai_sticky_compat.go
@@ -35,12 +35,31 @@ func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash
return "", ""
}
- currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
- sum := sha256.Sum256([]byte(normalized))
- legacyHash = hex.EncodeToString(sum[:])
+ currentHash = deriveOpenAISessionHash(normalized)
+ legacyHash = deriveOpenAILegacySessionHash(normalized)
return currentHash, legacyHash
}
+// deriveOpenAISessionHash returns the fast xxhash-based session hash ("%016x"); input is trimmed first.
+func deriveOpenAISessionHash(sessionID string) string {
+	normalized := strings.TrimSpace(sessionID)
+	if normalized == "" {
+		return "" // blank / whitespace-only session IDs never get a hash
+	}
+	return fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
+}
+
+// deriveOpenAILegacySessionHash returns the hex-encoded SHA-256 legacy hash; input is trimmed first.
+// Only call this when legacy fallback or dual-write is enabled.
+func deriveOpenAILegacySessionHash(sessionID string) string {
+	normalized := strings.TrimSpace(sessionID)
+	if normalized == "" {
+		return "" // blank / whitespace-only session IDs never get a hash
+	}
+	sum := sha256.Sum256([]byte(normalized))
+	return hex.EncodeToString(sum[:])
+}
+
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
if ctx == nil {
return nil
diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go
index a8a6b96c5..d12a57638 100644
--- a/backend/internal/service/openai_token_provider.go
+++ b/backend/internal/service/openai_token_provider.go
@@ -151,9 +151,11 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
// 从数据库获取最新账户信息
- fresh, err := p.accountRepo.GetByID(ctx, account.ID)
- if err == nil && fresh != nil {
- account = fresh
+ if p.accountRepo != nil {
+ fresh, err := p.accountRepo.GetByID(ctx, account.ID)
+ if err == nil && fresh != nil {
+ account = fresh
+ }
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
@@ -181,8 +183,10 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
account.Credentials = newCredentials
- if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
- slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ if p.accountRepo != nil {
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+ slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ }
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
@@ -233,8 +237,10 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
account.Credentials = newCredentials
- if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
- slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ if p.accountRepo != nil {
+ if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
+ slog.Error("openai_token_provider_update_failed", "account_id", account.ID, "error", updateErr)
+ }
}
expiresAt = account.GetCredentialAsTime("expires_at")
}
diff --git a/backend/internal/service/openai_token_provider_nil_repo_test.go b/backend/internal/service/openai_token_provider_nil_repo_test.go
new file mode 100644
index 000000000..f2e7e91e8
--- /dev/null
+++ b/backend/internal/service/openai_token_provider_nil_repo_test.go
@@ -0,0 +1,101 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ "github.com/stretchr/testify/require"
+)
+
+type openAITokenCacheNilRepoStub struct { // in-memory TokenCache double that records call counts
+	token       string
+	lockCalls   atomic.Int32
+	setCalls    atomic.Int32
+	getCalls    atomic.Int32
+	lockEnabled bool // value AcquireRefreshLock reports back to the caller
+}
+
+func (s *openAITokenCacheNilRepoStub) GetAccessToken(context.Context, string) (string, error) {
+	s.getCalls.Add(1)
+	return s.token, nil // empty token simulates a cache miss
+}
+
+func (s *openAITokenCacheNilRepoStub) SetAccessToken(context.Context, string, string, time.Duration) error {
+	s.setCalls.Add(1)
+	return nil
+}
+
+func (s *openAITokenCacheNilRepoStub) DeleteAccessToken(context.Context, string) error {
+	return nil
+}
+
+func (s *openAITokenCacheNilRepoStub) AcquireRefreshLock(context.Context, string, time.Duration) (bool, error) {
+	s.lockCalls.Add(1)
+	return s.lockEnabled, nil
+}
+
+func (s *openAITokenCacheNilRepoStub) ReleaseRefreshLock(context.Context, string) error {
+	return nil
+}
+
+type openAIOAuthClientNilRepoStub struct{} // OAuth client double: only RefreshTokenWithClientID succeeds
+
+func (s *openAIOAuthClientNilRepoStub) ExchangeCode(
+	context.Context,
+	string,
+	string,
+	string,
+	string,
+	string,
+) (*openai.TokenResponse, error) {
+	return nil, errors.New("not implemented") // not exercised by this test
+}
+
+func (s *openAIOAuthClientNilRepoStub) RefreshToken(
+	context.Context,
+	string,
+	string,
+) (*openai.TokenResponse, error) {
+	return nil, errors.New("not implemented") // not exercised by this test
+}
+
+func (s *openAIOAuthClientNilRepoStub) RefreshTokenWithClientID(
+	context.Context,
+	string,
+	string,
+	string,
+) (*openai.TokenResponse, error) {
+	return &openai.TokenResponse{
+		AccessToken:  "fresh-token",
+		RefreshToken: "fresh-refresh-token",
+		ExpiresIn:    3600,
+	}, nil
+}
+
+func TestOpenAITokenProviderRefreshWithNilAccountRepo(t *testing.T) { // a nil accountRepo must not panic the refresh path
+	cache := &openAITokenCacheNilRepoStub{lockEnabled: true}
+	oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientNilRepoStub{})
+	provider := NewOpenAITokenProvider(nil, cache, oauthSvc) // nil accountRepo is the scenario under test
+
+	expiresAt := time.Now().Add(30 * time.Second).Format(time.RFC3339) // inside the refresh skew, forcing a refresh
+	account := &Account{
+		ID:       3001,
+		Platform: PlatformOpenAI,
+		Type:     AccountTypeOAuth,
+		Credentials: map[string]any{
+			"access_token":  "stale-token",
+			"refresh_token": "refresh-token",
+			"expires_at":    expiresAt,
+		},
+	}
+
+	token, err := provider.GetAccessToken(context.Background(), account)
+	require.NoError(t, err)
+	require.Equal(t, "fresh-token", token)
+	require.Equal(t, int32(1), cache.lockCalls.Load()) // refresh took the lock exactly once
+	require.Equal(t, int32(1), cache.setCalls.Load())  // refreshed token was cached exactly once
+}
diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go
index 9f3c47b7b..f46eda8a5 100644
--- a/backend/internal/service/openai_ws_client.go
+++ b/backend/internal/service/openai_ws_client.go
@@ -11,6 +11,7 @@ import (
"sync/atomic"
"time"
+ openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
coderws "github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
@@ -77,8 +78,9 @@ func (d *coderOpenAIWSClientDialer) Dial(
}
opts := &coderws.DialOptions{
- HTTPHeader: cloneHeader(headers),
- CompressionMode: coderws.CompressionContextTakeover,
+ HTTPHeader: cloneHeader(headers),
+		// For high-frequency long-lived connections, prioritize lower memory/CPU jitter and avoid the per-connection state accumulation that context takeover brings.
+ CompressionMode: coderws.CompressionNoContextTakeover,
}
if proxy := strings.TrimSpace(proxyURL); proxy != "" {
proxyClient, err := d.proxyHTTPClient(proxy)
@@ -230,6 +232,9 @@ func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransport
}
}
+// Compile-time assertion: coderOpenAIWSClientConn must implement the openai_ws_v2.FrameConn interface (the passthrough path relies on it).
+var _ openaiwsv2.FrameConn = (*coderOpenAIWSClientConn)(nil)
+
type coderOpenAIWSClientConn struct {
conn *coderws.Conn
}
@@ -264,6 +269,30 @@ func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, erro
}
}
+func (c *coderOpenAIWSClientConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	if c == nil || c.conn == nil { // defensive: treat a nil receiver/conn as an already-closed connection
+		return coderws.MessageText, nil, errOpenAIWSConnClosed
+	}
+	if ctx == nil { // tolerate nil contexts from callers
+		ctx = context.Background()
+	}
+	msgType, payload, err := c.conn.Read(ctx)
+	if err != nil {
+		return coderws.MessageText, nil, err // MessageText is just a zero-ish placeholder on error
+	}
+	return msgType, payload, nil
+}
+
+func (c *coderOpenAIWSClientConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+	if c == nil || c.conn == nil { // defensive: nil receiver/conn behaves as closed
+		return errOpenAIWSConnClosed
+	}
+	if ctx == nil { // tolerate nil contexts from callers
+		ctx = context.Background()
+	}
+	return c.conn.Write(ctx, msgType, payload)
+}
+
func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error {
if c == nil || c.conn == nil {
return errOpenAIWSConnClosed
diff --git a/backend/internal/service/openai_ws_client_preempt_test.go b/backend/internal/service/openai_ws_client_preempt_test.go
new file mode 100644
index 000000000..cbbc149e9
--- /dev/null
+++ b/backend/internal/service/openai_ws_client_preempt_test.go
@@ -0,0 +1,1073 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "testing"
+
+ coderws "github.com/coder/websocket"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Basic tests for the errOpenAIWSClientPreempted sentinel error
+// ---------------------------------------------------------------------------
+
+func TestErrOpenAIWSClientPreempted_NotNil(t *testing.T) {
+	t.Parallel()
+	require.NotNil(t, errOpenAIWSClientPreempted)
+	require.Contains(t, errOpenAIWSClientPreempted.Error(), "client preempted")
+}
+
+func TestErrOpenAIWSClientPreempted_ErrorsIs(t *testing.T) {
+	t.Parallel()
+
+	// direct match
+	require.True(t, errors.Is(errOpenAIWSClientPreempted, errOpenAIWSClientPreempted))
+
+	// still matches after wrapping
+	wrapped := fmt.Errorf("outer: %w", errOpenAIWSClientPreempted)
+	require.True(t, errors.Is(wrapped, errOpenAIWSClientPreempted))
+
+	// unrelated errors do not match
+	require.False(t, errors.Is(errors.New("other"), errOpenAIWSClientPreempted))
+}
+
+func TestErrOpenAIWSClientPreempted_WrapInTurnError(t *testing.T) {
+	t.Parallel()
+
+	// errors.Is must still detect the sentinel after wrapping via wrapOpenAIWSIngressTurnErrorWithPartial
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		false,
+		nil,
+	)
+	require.Error(t, turnErr)
+	require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted))
+}
+
+func TestErrOpenAIWSClientPreempted_WrapInTurnError_WithPartialResult(t *testing.T) {
+	t.Parallel()
+
+	partial := &OpenAIForwardResult{
+		RequestID: "resp_preempt_partial",
+		Usage: OpenAIUsage{
+			InputTokens:  100,
+			OutputTokens: 50,
+		},
+	}
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		true,
+		partial,
+	)
+	require.Error(t, turnErr)
+	require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted))
+
+	// the partial result must remain extractable
+	got, ok := OpenAIWSIngressTurnPartialResult(turnErr)
+	require.True(t, ok)
+	require.NotNil(t, got)
+	require.Equal(t, partial.RequestID, got.RequestID)
+	require.Equal(t, partial.Usage.InputTokens, got.Usage.InputTokens)
+}
+
+// ---------------------------------------------------------------------------
+// Tests for classifyOpenAIWSIngressTurnAbortReason recognizing client_preempted
+// ---------------------------------------------------------------------------
+
+func TestClassifyAbortReason_ClientPreempted_Direct(t *testing.T) {
+	t.Parallel()
+
+	// bare sentinel error
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError(t *testing.T) {
+	t.Parallel()
+
+	// wrapped inside a turnError
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		false,
+		nil,
+	)
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError_WroteDownstream(t *testing.T) {
+	t.Parallel()
+
+	// wrapped inside a turnError with wroteDownstream=true
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		true,
+		&OpenAIForwardResult{RequestID: "resp_partial"},
+	)
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_DoubleWrapped(t *testing.T) {
+	t.Parallel()
+
+	// multiple layers of fmt.Errorf wrapping
+	inner := fmt.Errorf("relay failed: %w", errOpenAIWSClientPreempted)
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(inner)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_NotConfusedWithOther(t *testing.T) {
+	t.Parallel()
+
+	// make sure other errors are never misclassified as client_preempted
+	others := []error{
+		errors.New("client preempted"), // same text, but not the same sentinel value
+		context.Canceled,               // context cancellation
+		io.EOF,                         // client disconnect
+		errors.New("random error"),     // arbitrary error
+	}
+
+	for _, err := range others {
+		reason, _ := classifyOpenAIWSIngressTurnAbortReason(err)
+		require.NotEqual(t, openAIWSIngressTurnAbortReasonClientPreempted, reason,
+			"error %q should not classify as client_preempted", err)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Disposition tests for ClientPreempted in openAIWSIngressTurnAbortDispositionForReason
+// ---------------------------------------------------------------------------
+
+func TestDisposition_ClientPreempted_IsContinueTurn(t *testing.T) {
+	t.Parallel()
+
+	disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+}
+
+func TestDisposition_ClientPreempted_SameAsPreviousResponse(t *testing.T) {
+	t.Parallel()
+
+	// client_preempted must share the disposition of previous_response_not_found
+	prevDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonPreviousResponse)
+	preemptDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted)
+	require.Equal(t, prevDisp, preemptDisp)
+}
+
+func TestDisposition_AllContinueTurnReasons(t *testing.T) {
+	t.Parallel()
+
+	// verify the list of reasons that must map to ContinueTurn is complete and correct
+	continueTurnReasons := []openAIWSIngressTurnAbortReason{
+		openAIWSIngressTurnAbortReasonPreviousResponse,
+		openAIWSIngressTurnAbortReasonToolOutput,
+		openAIWSIngressTurnAbortReasonUpstreamError,
+		openAIWSIngressTurnAbortReasonClientPreempted,
+	}
+
+	for _, reason := range continueTurnReasons {
+		disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+		require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition,
+			"reason %q should be ContinueTurn", reason)
+	}
+}
+
+func TestDisposition_ClientPreempted_NotCloseGracefully(t *testing.T) {
+	t.Parallel()
+
+	disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted)
+	require.NotEqual(t, openAIWSIngressTurnAbortDispositionCloseGracefully, disposition)
+	require.NotEqual(t, openAIWSIngressTurnAbortDispositionFailRequest, disposition)
+}
+
+// ---------------------------------------------------------------------------
+// End-to-end classify → disposition pipeline tests
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ClassifyToDisposition_EndToEnd(t *testing.T) {
+	t.Parallel()
+
+	// simulate the full path of sendAndRelay returning a client_preempted error
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		false,
+		nil,
+	)
+
+	// 1. classify
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+
+	// 2. disposition
+	disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+	// 3. wroteDownstream
+	require.False(t, openAIWSIngressTurnWroteDownstream(turnErr))
+}
+
+func TestClientPreempted_ClassifyToDisposition_WroteDownstream(t *testing.T) {
+	t.Parallel()
+
+	partial := &OpenAIForwardResult{
+		RequestID: "resp_half",
+		Usage: OpenAIUsage{
+			InputTokens:  200,
+			OutputTokens: 100,
+		},
+	}
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		true,
+		partial,
+	)
+
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+
+	disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+	require.True(t, openAIWSIngressTurnWroteDownstream(turnErr))
+
+	got, ok := OpenAIWSIngressTurnPartialResult(turnErr)
+	require.True(t, ok)
+	require.Equal(t, "resp_half", got.RequestID)
+}
+
+// ---------------------------------------------------------------------------
+// Verify client_preempted's special handling inside the ContinueTurn branch
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ShouldNotSendErrorEvent(t *testing.T) {
+	t.Parallel()
+
+	// Core semantics: on client_preempted the client has already issued a new request,
+	// so the old turn needs no error event; verify no error notification is produced.
+	abortReason := openAIWSIngressTurnAbortReasonClientPreempted
+
+	// mirror the decision logic of the ContinueTurn branch
+	shouldSendError := abortReason != openAIWSIngressTurnAbortReasonClientPreempted
+	require.False(t, shouldSendError, "client_preempted 不应发送 error 事件")
+}
+
+func TestClientPreempted_ShouldNotClearLastResponseID(t *testing.T) {
+	t.Parallel()
+
+	// Core semantics: a preempted turn never finished, so the previous turn's response_id
+	// stays valid for the new turn to chain from; clearSessionLastResponseID must not run.
+	abortReason := openAIWSIngressTurnAbortReasonClientPreempted
+
+	shouldClearLastResponseID := abortReason != openAIWSIngressTurnAbortReasonClientPreempted
+	require.False(t, shouldClearLastResponseID,
+		"client_preempted 不应清除 lastResponseID")
+}
+
+func TestNonPreempted_ContinueTurn_ShouldSendErrorAndClearID(t *testing.T) {
+	t.Parallel()
+
+	// Control test: non-preempted ContinueTurn reasons still send the error event and clear the ID
+	otherReasons := []openAIWSIngressTurnAbortReason{
+		openAIWSIngressTurnAbortReasonPreviousResponse,
+		openAIWSIngressTurnAbortReasonToolOutput,
+		openAIWSIngressTurnAbortReasonUpstreamError,
+	}
+
+	for _, reason := range otherReasons {
+		shouldSendError := reason != openAIWSIngressTurnAbortReasonClientPreempted
+		shouldClearID := reason != openAIWSIngressTurnAbortReasonClientPreempted
+		require.True(t, shouldSendError,
+			"reason %q (non-preempted) should send error event", reason)
+		require.True(t, shouldClearID,
+			"reason %q (non-preempted) should clear lastResponseID", reason)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Error-event format checks for client_preempted on the ContinueTurn abort path
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ErrorEventNotGenerated(t *testing.T) {
+	t.Parallel()
+
+	// In the real ContinueTurn branch the client_preempted case never builds an error event.
+	// This defensive test verifies the event format would still be valid JSON if that path were ever hit.
+	abortReason := openAIWSIngressTurnAbortReasonClientPreempted
+	abortMessage := "turn failed: " + string(abortReason)
+
+	errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+		string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+
+	var parsed map[string]any
+	err := json.Unmarshal(errorEvent, &parsed)
+	require.NoError(t, err, "hypothetical error event should be valid JSON")
+
+	errorObj, ok := parsed["error"].(map[string]any)
+	require.True(t, ok)
+	require.Equal(t, "client_preempted", errorObj["code"])
+	require.Contains(t, errorObj["message"], "client_preempted")
+}
+
+// ---------------------------------------------------------------------------
+// Constant-value check for openAIWSIngressTurnAbortReason
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ReasonStringValue(t *testing.T) {
+	t.Parallel()
+
+	require.Equal(t, openAIWSIngressTurnAbortReason("client_preempted"),
+		openAIWSIngressTurnAbortReasonClientPreempted)
+}
+
+// ---------------------------------------------------------------------------
+// Full table-driven tests for classifyOpenAIWSIngressTurnAbortReason (incl. client_preempted)
+// ---------------------------------------------------------------------------
+
+func TestClassifyAbortReason_AllReasons_IncludeClientPreempted(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name         string
+		err          error
+		wantReason   openAIWSIngressTurnAbortReason
+		wantExpected bool
+	}{
+		{
+			name:         "client_preempted_sentinel",
+			err:          errOpenAIWSClientPreempted,
+			wantReason:   openAIWSIngressTurnAbortReasonClientPreempted,
+			wantExpected: true,
+		},
+		{
+			name: "client_preempted_wrapped_in_turn_error",
+			err: wrapOpenAIWSIngressTurnErrorWithPartial(
+				"client_preempted",
+				errOpenAIWSClientPreempted,
+				false,
+				nil,
+			),
+			wantReason:   openAIWSIngressTurnAbortReasonClientPreempted,
+			wantExpected: true,
+		},
+		{
+			name: "client_preempted_wrapped_in_turn_error_wrote_downstream",
+			err: wrapOpenAIWSIngressTurnErrorWithPartial(
+				"client_preempted",
+				errOpenAIWSClientPreempted,
+				true,
+				&OpenAIForwardResult{RequestID: "resp_x"},
+			),
+			wantReason:   openAIWSIngressTurnAbortReasonClientPreempted,
+			wantExpected: true,
+		},
+		{
+			name:         "client_preempted_double_wrapped",
+			err:          fmt.Errorf("relay: %w", errOpenAIWSClientPreempted),
+			wantReason:   openAIWSIngressTurnAbortReasonClientPreempted,
+			wantExpected: true,
+		},
+		{
+			name: "previous_response_not_confused_with_preempt",
+			err: wrapOpenAIWSIngressTurnError(
+				openAIWSIngressStagePreviousResponseNotFound,
+				errors.New("not found"),
+				false,
+			),
+			wantReason:   openAIWSIngressTurnAbortReasonPreviousResponse,
+			wantExpected: true,
+		},
+		{
+			name: "tool_output_not_confused_with_preempt",
+			err: wrapOpenAIWSIngressTurnError(
+				openAIWSIngressStageToolOutputNotFound,
+				errors.New("tool output not found"),
+				false,
+			),
+			wantReason:   openAIWSIngressTurnAbortReasonToolOutput,
+			wantExpected: true,
+		},
+		{
+			name:         "context_canceled_not_preempted",
+			err:          context.Canceled,
+			wantReason:   openAIWSIngressTurnAbortReasonContextCanceled,
+			wantExpected: true,
+		},
+		{
+			name:         "eof_not_preempted",
+			err:          io.EOF,
+			wantReason:   openAIWSIngressTurnAbortReasonClientClosed,
+			wantExpected: true,
+		},
+		{
+			name:         "ws_normal_closure_not_preempted",
+			err:          coderws.CloseError{Code: coderws.StatusNormalClosure},
+			wantReason:   openAIWSIngressTurnAbortReasonClientClosed,
+			wantExpected: true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err)
+			require.Equal(t, tt.wantReason, reason)
+			require.Equal(t, tt.wantExpected, expected)
+		})
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Classification priority: client_preempted is checked before context.Canceled
+// ---------------------------------------------------------------------------
+
+func TestClassifyAbortReason_ClientPreempted_PriorityOverContextCanceled(t *testing.T) {
+	t.Parallel()
+
+	// errOpenAIWSClientPreempted never also matches context.Canceled today, but should
+	// a wrapped context.Canceled ever appear, the client_preempted check must run first.
+	reason, _ := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason,
+		"client_preempted 检测应优先于 context.Canceled")
+}
+
+// ---------------------------------------------------------------------------
+// Table-driven tests for openAIWSIngressTurnAbortDispositionForReason (incl. client_preempted)
+// ---------------------------------------------------------------------------
+
+func TestDisposition_AllReasons_IncludeClientPreempted(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		reason   openAIWSIngressTurnAbortReason
+		wantDisp openAIWSIngressTurnAbortDisposition
+	}{
+		{openAIWSIngressTurnAbortReasonPreviousResponse, openAIWSIngressTurnAbortDispositionContinueTurn},
+		{openAIWSIngressTurnAbortReasonToolOutput, openAIWSIngressTurnAbortDispositionContinueTurn},
+		{openAIWSIngressTurnAbortReasonUpstreamError, openAIWSIngressTurnAbortDispositionContinueTurn},
+		{openAIWSIngressTurnAbortReasonClientPreempted, openAIWSIngressTurnAbortDispositionContinueTurn},
+		{openAIWSIngressTurnAbortReasonContextCanceled, openAIWSIngressTurnAbortDispositionCloseGracefully},
+		{openAIWSIngressTurnAbortReasonClientClosed, openAIWSIngressTurnAbortDispositionCloseGracefully},
+		{openAIWSIngressTurnAbortReasonUnknown, openAIWSIngressTurnAbortDispositionFailRequest},
+		{openAIWSIngressTurnAbortReasonContextDeadline, openAIWSIngressTurnAbortDispositionFailRequest},
+		{openAIWSIngressTurnAbortReasonWriteUpstream, openAIWSIngressTurnAbortDispositionFailRequest},
+		{openAIWSIngressTurnAbortReasonReadUpstream, openAIWSIngressTurnAbortDispositionFailRequest},
+		{openAIWSIngressTurnAbortReasonWriteClient, openAIWSIngressTurnAbortDispositionFailRequest},
+		{openAIWSIngressTurnAbortReasonContinuationUnavailable, openAIWSIngressTurnAbortDispositionFailRequest},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(string(tt.reason), func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tt.wantDisp, openAIWSIngressTurnAbortDispositionForReason(tt.reason))
+		})
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Interaction between isOpenAIWSIngressTurnRetryable and client_preempted
+// ---------------------------------------------------------------------------
+
+func TestIsRetryable_ClientPreempted_NotRetryable(t *testing.T) {
+	t.Parallel()
+
+	// client_preempted has its own recovery path (ContinueTurn) and must skip generic retry
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		false,
+		nil,
+	)
+	require.False(t, isOpenAIWSIngressTurnRetryable(turnErr),
+		"client_preempted 不应被标记为 retryable")
+}
+
+func TestIsRetryable_ClientPreempted_WroteDownstream(t *testing.T) {
+	t.Parallel()
+
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		true,
+		nil,
+	)
+	require.False(t, isOpenAIWSIngressTurnRetryable(turnErr),
+		"client_preempted wroteDownstream=true 不应被标记为 retryable")
+}
+
+// ---------------------------------------------------------------------------
+// Unit-level tests for clientMsgCh / clientReadErrCh behavior inside sendAndRelay
+// ---------------------------------------------------------------------------
+
+func TestClientMsgCh_BufferedOne(t *testing.T) {
+	t.Parallel()
+
+	// Semantics of clientMsgCh (buffered 1): the reader goroutine must not block in
+	// the gap between sendAndRelay returning and advanceToNextClientTurn picking up.
+	ch := make(chan []byte, 1)
+
+	// non-blocking write
+	select {
+	case ch <- []byte(`{"type":"response.create"}`):
+		// ok
+	default:
+		t.Fatal("buffered(1) channel should not block on first write")
+	}
+
+	// a second write must block
+	select {
+	case ch <- []byte(`{"type":"response.create"}`):
+		t.Fatal("buffered(1) channel should block on second write")
+	default:
+		// expected
+	}
+}
+
+func TestClientReadErrCh_BufferedOne(t *testing.T) {
+	t.Parallel()
+
+	ch := make(chan error, 1)
+
+	// non-blocking write
+	select {
+	case ch <- io.EOF:
+	default:
+		t.Fatal("buffered(1) channel should not block on first write")
+	}
+
+	// a second write must block
+	select {
+	case ch <- io.EOF:
+		t.Fatal("buffered(1) channel should block on second write")
+	default:
+		// expected
+	}
+}
+
+func TestClientMsgCh_CloseSignalsClosed(t *testing.T) {
+	t.Parallel()
+
+	ch := make(chan []byte, 1)
+	close(ch)
+
+	msg, ok := <-ch
+	require.False(t, ok, "closed channel should return ok=false")
+	require.Nil(t, msg)
+}
+
+// ---------------------------------------------------------------------------
+// Behavior tests for the client-preempt stash (nextClientPreemptedPayload)
+// ---------------------------------------------------------------------------
+
+func TestPreemptedPayload_ConsumedOnce(t *testing.T) {
+	t.Parallel()
+
+	// simulate how advanceToNextClientTurn consumes a stashed message
+	var nextPreempted []byte
+	nextPreempted = []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+
+	// first consumption
+	require.NotNil(t, nextPreempted)
+	msg := nextPreempted
+	nextPreempted = nil
+
+	require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(msg))
+	require.Nil(t, nextPreempted, "消费后应置空")
+}
+
+func TestPreemptedPayload_NilFallsBackToChannel(t *testing.T) {
+	t.Parallel()
+
+	// simulate advanceToNextClientTurn falling back to the channel when no stash exists
+	var nextPreempted []byte
+	clientMsgCh := make(chan []byte, 1)
+	clientMsgCh <- []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+
+	var nextClientMessage []byte
+	if nextPreempted != nil {
+		nextClientMessage = nextPreempted
+	} else {
+		msg, ok := <-clientMsgCh
+		require.True(t, ok)
+		nextClientMessage = msg
+	}
+
+	require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(nextClientMessage))
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select path: pumpEventCh closed → goto pumpClosed
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_PumpClosed_GoToPumpClosed(t *testing.T) {
+	t.Parallel()
+
+	// simulate pumpEventCh being closed
+	pumpEventCh := make(chan openAIWSUpstreamPumpEvent)
+	close(pumpEventCh)
+
+	evt, ok := <-pumpEventCh
+	require.False(t, ok, "closed pumpEventCh should return ok=false")
+	require.Nil(t, evt.message)
+	require.Nil(t, evt.err)
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select path: message arrives on clientMsgCh → client preempt
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_ClientPreempt_ReturnsCorrectError(t *testing.T) {
+	t.Parallel()
+
+	// simulate the turnError produced after the select receives a preempting client message
+	preemptPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+
+	// the error sendAndRelay would return
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		false,
+		nil, // buildPartialResult returns nil when no usage was seen
+	)
+
+	// verify classification
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+	require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+	require.True(t, expected)
+
+	// verify disposition
+	disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+	// verify the stashed message stays available for advanceToNextClientTurn
+	require.NotEmpty(t, preemptPayload)
+}
+
+func TestSelectLoop_ClientPreempt_WithPartialUsage(t *testing.T) {
+	t.Parallel()
+
+	// simulate preemption after the upstream already emitted some tokens
+	partial := &OpenAIForwardResult{
+		RequestID: "resp_interrupted",
+		Usage: OpenAIUsage{
+			InputTokens:  500,
+			OutputTokens: 200,
+		},
+	}
+
+	turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+		"client_preempted",
+		errOpenAIWSClientPreempted,
+		true, // downstream already written
+		partial,
+	)
+
+	require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted))
+	require.True(t, openAIWSIngressTurnWroteDownstream(turnErr))
+
+	got, ok := OpenAIWSIngressTurnPartialResult(turnErr)
+	require.True(t, ok)
+	require.Equal(t, "resp_interrupted", got.RequestID)
+	require.Equal(t, 500, got.Usage.InputTokens)
+	require.Equal(t, 200, got.Usage.OutputTokens)
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select path: clientMsgCh closed → nil channel
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_ClientMsgChClosed_NilChannelPreventsReselect(t *testing.T) {
+	t.Parallel()
+
+	// after clientMsgCh closes it is set to nil; later selects must never pick it again
+	var clientMsgCh chan []byte
+	clientMsgCh = make(chan []byte, 1)
+	close(clientMsgCh)
+
+	// first read: closed
+	_, ok := <-clientMsgCh
+	require.False(t, ok)
+
+	// set to nil
+	clientMsgCh = nil
+
+	// a select on a nil channel is never chosen (and never panics)
+	select {
+	case <-clientMsgCh:
+		t.Fatal("nil channel should never be selected")
+	default:
+		// expected: a nil channel never participates in select
+	}
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select path: client disconnect reported via clientReadErrCh
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_ClientReadErr_DisconnectSetsDrain(t *testing.T) {
+	t.Parallel()
+
+	// classify read errors that represent a client disconnect
+	disconnectErrors := []error{
+		io.EOF,
+		coderws.CloseError{Code: coderws.StatusNormalClosure},
+		coderws.CloseError{Code: coderws.StatusGoingAway},
+	}
+
+	for _, readErr := range disconnectErrors {
+		require.True(t, isOpenAIWSClientDisconnectError(readErr),
+			"error %v should be classified as client disconnect", readErr)
+	}
+}
+
+func TestSelectLoop_ClientReadErr_NonDisconnect(t *testing.T) {
+	t.Parallel()
+
+	// non-disconnect errors must not trigger the drain path
+	nonDisconnectErrors := []error{
+		errors.New("tls handshake timeout"),
+		coderws.CloseError{Code: coderws.StatusPolicyViolation},
+	}
+
+	for _, readErr := range nonDisconnectErrors {
+		require.False(t, isOpenAIWSClientDisconnectError(readErr),
+			"error %v should not be classified as client disconnect", readErr)
+	}
+}
+
+func TestSelectLoop_ClientReadErr_NilChannelsAfterError(t *testing.T) {
+	t.Parallel()
+
+	// simulate nil-ing both channels after receiving from clientReadErrCh
+	clientMsgCh := make(chan []byte, 1)
+	clientReadErrCh := make(chan error, 1)
+
+	clientReadErrCh <- io.EOF
+
+	// consume the error
+	readErr := <-clientReadErrCh
+	require.Error(t, readErr)
+
+	// simulate the nil-out performed after the select case in the real code
+	var nilMsgCh chan []byte
+	var nilErrCh chan error
+	nilMsgCh = nil
+	nilErrCh = nil
+
+	// verify nil-channel behavior
+	_ = clientMsgCh // unused in this test
+
+	select {
+	case <-nilMsgCh:
+		t.Fatal("nil channel should never be selected")
+	case <-nilErrCh:
+		t.Fatal("nil channel should never be selected")
+	default:
+		// expected
+	}
+}
+
+func TestAdvanceConsumePendingClientReadErr(t *testing.T) {
+ t.Parallel()
+
+ require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(nil))
+
+ var pendingErr error
+ require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr))
+
+ sourceErr := errors.New("custom read error")
+ pendingErr = sourceErr
+
+ gotErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingErr)
+ require.Error(t, gotErr)
+ require.ErrorIs(t, gotErr, sourceErr)
+ require.Nil(t, pendingErr, "pending error should be consumed once")
+ require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr))
+}
+
+func TestAdvanceClientReadUnavailable(t *testing.T) {
+ t.Parallel()
+
+ var nilMsgCh chan []byte
+ var nilErrCh chan error
+ require.True(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, nilErrCh))
+
+ msgCh := make(chan []byte, 1)
+ require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, nilErrCh))
+
+ errCh := make(chan error, 1)
+ require.False(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, errCh))
+ require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, errCh))
+}
+
+// ---------------------------------------------------------------------------
+// advanceToNextClientTurn channel 读取路径测试
+// ---------------------------------------------------------------------------
+
+func TestAdvance_ClientMsgCh_ClosedReturnsExit(t *testing.T) {
+ t.Parallel()
+
+ // clientMsgCh 关闭意味着客户端读取 goroutine 已退出,应返回 exit=true
+ ch := make(chan []byte, 1)
+ close(ch)
+
+ _, ok := <-ch
+ require.False(t, ok, "should signal goroutine exit")
+}
+
+func TestAdvance_ClientReadErrCh_DisconnectReturnsExit(t *testing.T) {
+ t.Parallel()
+
+ // 断连错误应返回 exit=true
+ ch := make(chan error, 1)
+ ch <- io.EOF
+
+ readErr := <-ch
+ require.True(t, isOpenAIWSClientDisconnectError(readErr))
+}
+
+func TestAdvance_ClientReadErrCh_NonDisconnectReturnsError(t *testing.T) {
+ t.Parallel()
+
+ // 非断连错误应返回 error
+ ch := make(chan error, 1)
+ errCustom := errors.New("custom read error")
+ ch <- errCustom
+
+ readErr := <-ch
+ require.False(t, isOpenAIWSClientDisconnectError(readErr))
+ require.Equal(t, errCustom, readErr)
+}
+
+// ---------------------------------------------------------------------------
+// 持久客户端读取 goroutine 行为测试
+// ---------------------------------------------------------------------------
+
+func TestPersistentReader_NormalMessage(t *testing.T) {
+ t.Parallel()
+
+ // 模拟正常消息的推送和消费
+ clientMsgCh := make(chan []byte, 1)
+
+ // 模拟 goroutine 写入
+ go func() {
+ clientMsgCh <- []byte(`{"type":"response.create"}`)
+ }()
+
+ msg := <-clientMsgCh
+ require.Equal(t, `{"type":"response.create"}`, string(msg))
+}
+
+func TestPersistentReader_ErrorSendsToErrCh(t *testing.T) {
+ t.Parallel()
+
+ clientReadErrCh := make(chan error, 1)
+
+ // 模拟 goroutine 发送错误
+ go func() {
+ clientReadErrCh <- io.EOF
+ }()
+
+ readErr := <-clientReadErrCh
+ require.Equal(t, io.EOF, readErr)
+}
+
+func TestPersistentReader_ContextCancel(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ clientMsgCh := make(chan []byte, 1)
+
+ // 填满 buffer
+ clientMsgCh <- []byte("first")
+
+ // 模拟 goroutine 尝试写入已满的 channel
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ select {
+ case clientMsgCh <- []byte("second"):
+ // 不应到达
+ case <-ctx.Done():
+ // 正确退出
+ return
+ }
+ }()
+
+ // 取消 context
+ cancel()
+ <-done
+}
+
+func TestPersistentReader_ClosesMsgChOnExit(t *testing.T) {
+ t.Parallel()
+
+ clientMsgCh := make(chan []byte, 1)
+
+ // 模拟 goroutine 退出时关闭 channel
+ go func() {
+ defer close(clientMsgCh)
+ // 模拟读取错误后退出
+ }()
+
+ // 等待 channel 关闭
+ _, ok := <-clientMsgCh
+ require.False(t, ok, "channel should be closed when goroutine exits")
+}
+
+// ---------------------------------------------------------------------------
+// client_preempted 与其他 abort reason 的正交性验证
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_OrthogonalWithPreviousResponseNotFound(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ prevErr := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("not found"),
+ false,
+ )
+
+ // client_preempted 不会被误判为 previous_response_not_found
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(preemptErr))
+ // previous_response_not_found 不会被误判为 client_preempted
+ require.False(t, errors.Is(prevErr, errOpenAIWSClientPreempted))
+}
+
+func TestClientPreempted_OrthogonalWithToolOutputNotFound(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ toolErr := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("tool output not found"),
+ false,
+ )
+
+ require.False(t, isOpenAIWSIngressToolOutputNotFound(preemptErr))
+ require.False(t, errors.Is(toolErr, errOpenAIWSClientPreempted))
+}
+
+func TestClientPreempted_OrthogonalWithUpstreamError(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ upstreamErr := wrapOpenAIWSIngressTurnError(
+ "upstream_error_event",
+ errors.New("upstream error"),
+ false,
+ )
+
+ require.False(t, isOpenAIWSIngressUpstreamErrorEvent(preemptErr))
+ require.False(t, errors.Is(upstreamErr, errOpenAIWSClientPreempted))
+}
+
+func TestClientPreempted_OrthogonalWithContinuationUnavailable(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ require.False(t, isOpenAIWSContinuationUnavailableCloseError(preemptErr))
+}
+
+func TestClientPreempted_NotClientDisconnect(t *testing.T) {
+ t.Parallel()
+
+ require.False(t, isOpenAIWSClientDisconnectError(errOpenAIWSClientPreempted),
+ "client_preempted should not be classified as client disconnect")
+}
+
+// ---------------------------------------------------------------------------
+// recordOpenAIWSTurnAbort 指标兼容性测试
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_RecordAbortArgs(t *testing.T) {
+ t.Parallel()
+
+ // 验证 classify 返回的 (reason, expected) 值与 recordOpenAIWSTurnAbort 兼容
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+
+ // expected=true 表示这是预期行为,不应触发告警
+ assert.True(t, expected, "client_preempted 应标记为 expected,不触发告警")
+}
+
+// ---------------------------------------------------------------------------
+// shouldFlushOpenAIWSBufferedEventsOnError 与 client_preempted 场景
+// ---------------------------------------------------------------------------
+
+func TestShouldFlushBufferedEvents_ClientPreempted(t *testing.T) {
+ t.Parallel()
+
+ // client_preempted 场景下 clientDisconnected=false(客户端仍在),
+ // 是否 flush 取决于 reqStream 和 wroteDownstream
+ tests := []struct {
+ name string
+ reqStream bool
+ wroteDownstream bool
+ wantFlush bool
+ }{
+ {"stream_wrote", true, true, true},
+ {"stream_not_wrote", true, false, false},
+ {"not_stream_wrote", false, true, false},
+ {"not_stream_not_wrote", false, false, false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := shouldFlushOpenAIWSBufferedEventsOnError(tt.reqStream, tt.wroteDownstream, false)
+ require.Equal(t, tt.wantFlush, got)
+ })
+ }
+}
diff --git a/backend/internal/service/openai_ws_client_test.go b/backend/internal/service/openai_ws_client_test.go
index a88d62665..fd8d8c76f 100644
--- a/backend/internal/service/openai_ws_client_test.go
+++ b/backend/internal/service/openai_ws_client_test.go
@@ -1,11 +1,15 @@
package service
import (
+ "context"
"fmt"
"net/http"
+ "net/http/httptest"
+ "strings"
"testing"
"time"
+ coderws "github.com/coder/websocket"
"github.com/stretchr/testify/require"
)
@@ -110,3 +114,146 @@ func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.
require.NotNil(t, transport)
require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout)
}
+
+func TestCoderOpenAIWSClientDialer_Dial_EmptyURL(t *testing.T) {
+ t.Parallel()
+
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ conn, status, headers, err := impl.Dial(context.Background(), " ", nil, "")
+ require.Error(t, err)
+ require.Nil(t, conn)
+ require.Equal(t, 0, status)
+ require.Nil(t, headers)
+}
+
+func TestCoderOpenAIWSClientDialer_Dial_InvalidProxyURL(t *testing.T) {
+ t.Parallel()
+
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ conn, status, headers, err := impl.Dial(context.Background(), "ws://example.com", nil, "://bad-proxy")
+ require.Error(t, err)
+ require.Nil(t, conn)
+ require.Equal(t, 0, status)
+ require.Nil(t, headers)
+}
+
+func TestCoderOpenAIWSClientConn_NilGuards(t *testing.T) {
+ t.Parallel()
+
+ var nilConn *coderOpenAIWSClientConn
+ require.ErrorIs(t, nilConn.WriteJSON(context.Background(), map[string]any{"a": 1}), errOpenAIWSConnClosed)
+ _, err := nilConn.ReadMessage(context.Background())
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, _, err = nilConn.ReadFrame(context.Background())
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.ErrorIs(t, nilConn.WriteFrame(context.Background(), coderws.MessageText, []byte("x")), errOpenAIWSConnClosed)
+ require.ErrorIs(t, nilConn.Ping(context.Background()), errOpenAIWSConnClosed)
+ require.NoError(t, nilConn.Close())
+
+ empty := &coderOpenAIWSClientConn{}
+ var nilCtx context.Context
+ require.ErrorIs(t, empty.WriteJSON(nilCtx, map[string]any{"a": 1}), errOpenAIWSConnClosed)
+ _, err = empty.ReadMessage(nilCtx)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ _, _, err = empty.ReadFrame(nilCtx)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.ErrorIs(t, empty.WriteFrame(nilCtx, coderws.MessageText, []byte("x")), errOpenAIWSConnClosed)
+ require.ErrorIs(t, empty.Ping(nilCtx), errOpenAIWSConnClosed)
+ require.NoError(t, empty.Close())
+}
+
+func TestCoderOpenAIWSClientDialer_DialAndConnWrappers(t *testing.T) {
+ t.Parallel()
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, nil)
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+ _, _, err = conn.Read(readCtx)
+ cancelRead()
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+
+ writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
+ if err = conn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.output_text.delta","delta":"ok"}`)); err != nil {
+ cancelWrite()
+ serverErrCh <- err
+ return
+ }
+ if err = conn.Write(writeCtx, coderws.MessageBinary, []byte{0x01, 0x02, 0x03}); err != nil {
+ cancelWrite()
+ serverErrCh <- err
+ return
+ }
+ cancelWrite()
+
+ readCtx2, cancelRead2 := context.WithTimeout(r.Context(), 3*time.Second)
+ _, payload, err := conn.Read(readCtx2)
+ cancelRead2()
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ if len(payload) == 0 {
+ serverErrCh <- fmt.Errorf("expected client payload")
+ return
+ }
+ serverErrCh <- nil
+ }))
+ defer wsServer.Close()
+
+ dialer := newDefaultOpenAIWSClientDialer()
+ impl, ok := dialer.(*coderOpenAIWSClientDialer)
+ require.True(t, ok)
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ conn, status, headers, err := impl.Dial(
+ dialCtx,
+ "ws"+strings.TrimPrefix(wsServer.URL, "http"),
+ http.Header{"User-Agent": []string{"unit-test-agent/1.0"}},
+ "",
+ )
+ cancelDial()
+ require.NoError(t, err)
+ require.NotNil(t, conn)
+ require.Equal(t, 0, status)
+ _ = headers // 成功建连时状态码为 0;headers 仅用于握手失败诊断。
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+ require.NoError(t, conn.WriteJSON(ctx, map[string]any{"type": "response.create"}))
+
+ msg, err := conn.ReadMessage(ctx)
+ require.NoError(t, err)
+ require.JSONEq(t, `{"type":"response.output_text.delta","delta":"ok"}`, string(msg))
+
+ typedConn, ok := conn.(*coderOpenAIWSClientConn)
+ require.True(t, ok)
+ msgType, payload, err := typedConn.ReadFrame(ctx)
+ require.NoError(t, err)
+ require.Equal(t, coderws.MessageBinary, msgType)
+ require.Equal(t, []byte{0x01, 0x02, 0x03}, payload)
+
+ require.NoError(t, typedConn.WriteFrame(ctx, coderws.MessageText, []byte(`{"client":"ack"}`)))
+ pingCtx, cancelPing := context.WithTimeout(context.Background(), 100*time.Millisecond)
+ _ = typedConn.Ping(pingCtx)
+ cancelPing()
+ require.NoError(t, typedConn.Close())
+ require.NoError(t, <-serverErrCh)
+}
diff --git a/backend/internal/service/openai_ws_common.go b/backend/internal/service/openai_ws_common.go
new file mode 100644
index 000000000..d3ce904f9
--- /dev/null
+++ b/backend/internal/service/openai_ws_common.go
@@ -0,0 +1,54 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+var (
+ errOpenAIWSConnClosed = errors.New("openai ws connection closed")
+ errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
+ errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
+)
+
+const (
+ openAIWSConnHealthCheckTO = 2 * time.Second
+)
+
+type openAIWSDialError struct {
+ StatusCode int
+ ResponseHeaders http.Header
+ Err error
+}
+
+func (e *openAIWSDialError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.StatusCode > 0 {
+ return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
+ }
+ return fmt.Sprintf("openai ws dial failed: %v", e.Err)
+}
+
+func (e *openAIWSDialError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+func cloneHeader(h http.Header) http.Header {
+ if h == nil {
+ return nil
+ }
+ cloned := make(http.Header, len(h))
+ for k, values := range h {
+ copied := make([]string, len(values))
+ copy(copied, values)
+ cloned[k] = copied
+ }
+ return cloned
+}
diff --git a/backend/internal/service/openai_ws_common_test.go b/backend/internal/service/openai_ws_common_test.go
new file mode 100644
index 000000000..211883cb1
--- /dev/null
+++ b/backend/internal/service/openai_ws_common_test.go
@@ -0,0 +1,44 @@
+package service
+
+import (
+ "errors"
+ "net/http"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSDialError_Behavior(t *testing.T) {
+ t.Parallel()
+
+ var nilErr *openAIWSDialError
+ require.Equal(t, "", nilErr.Error())
+ require.Nil(t, nilErr.Unwrap())
+
+ err := &openAIWSDialError{
+ StatusCode: 429,
+ Err: errors.New("too many requests"),
+ }
+ require.Contains(t, err.Error(), "status=429")
+ require.ErrorIs(t, err.Unwrap(), err.Err)
+}
+
+func TestCloneHeader_DeepCopy(t *testing.T) {
+ t.Parallel()
+
+ require.Nil(t, cloneHeader(nil))
+
+ origin := http.Header{
+ "X-Request-Id": []string{"req-1"},
+ "Set-Cookie": []string{"a=1", "b=2"},
+ }
+ cloned := cloneHeader(origin)
+ require.Equal(t, origin.Get("X-Request-Id"), cloned.Get("X-Request-Id"))
+ require.Equal(t, origin.Values("Set-Cookie"), cloned.Values("Set-Cookie"))
+
+ // 修改拷贝不应污染原 header。
+ cloned.Set("X-Request-Id", "req-2")
+ cloned["Set-Cookie"][0] = "a=9"
+ require.Equal(t, "req-1", origin.Get("X-Request-Id"))
+ require.Equal(t, "a=1", origin.Values("Set-Cookie")[0])
+}
diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go
index ce06f6a21..f840034e2 100644
--- a/backend/internal/service/openai_ws_fallback_test.go
+++ b/backend/internal/service/openai_ws_fallback_test.go
@@ -3,12 +3,15 @@ package service
import (
"context"
"errors"
+ "io"
"net/http"
+ "net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -108,6 +111,154 @@ func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
require.Equal(t, "previous_response_not_found", reason)
require.True(t, recoverable)
+
+ // tool_output_not_found: 用户按 ESC 取消 function_call 后重新发送消息
+ reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool output found for function call call_zXKPiNecBmIAoKeW9o2pNMvo.","param":"input"}}`))
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw("", "invalid_request_error", "No tool output found for function call call_abc123.")
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw(
+ "",
+ "invalid_request_error",
+ "No tool call found for function call output with call_id call_abc123.",
+ )
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ // reasoning orphaned items should reuse tool_output_not_found recovery path.
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw(
+ "",
+ "invalid_request_error",
+ "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.",
+ )
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw(
+ "",
+ "invalid_request_error",
+ "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.",
+ )
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+}
+
+func TestClassifyOpenAIWSErrorEventFromRaw_AllBranches(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ codeRaw string
+ errTypeRaw string
+ msgRaw string
+ wantReason string
+ wantRecover bool
+ }{
+ {
+ name: "code_upgrade_required",
+ codeRaw: "upgrade_required",
+ wantReason: "upgrade_required",
+ wantRecover: true,
+ },
+ {
+ name: "code_ws_unsupported",
+ codeRaw: "websocket_not_supported",
+ wantReason: "ws_unsupported",
+ wantRecover: true,
+ },
+ {
+ name: "code_ws_connection_limit",
+ codeRaw: "websocket_connection_limit_reached",
+ wantReason: "ws_connection_limit_reached",
+ wantRecover: true,
+ },
+ {
+ name: "msg_upgrade_required",
+ msgRaw: "status 426 upgrade required",
+ wantReason: "upgrade_required",
+ wantRecover: true,
+ },
+ {
+ name: "err_type_upgrade",
+ errTypeRaw: "gateway_upgrade_error",
+ wantReason: "upgrade_required",
+ wantRecover: true,
+ },
+ {
+ name: "msg_ws_unsupported",
+ msgRaw: "websocket is unsupported in this region",
+ wantReason: "ws_unsupported",
+ wantRecover: true,
+ },
+ {
+ name: "msg_ws_connection_limit",
+ msgRaw: "websocket connection limit exceeded",
+ wantReason: "ws_connection_limit_reached",
+ wantRecover: true,
+ },
+ {
+ name: "msg_previous_response_not_found_variant",
+ msgRaw: "previous response is not found",
+ wantReason: "previous_response_not_found",
+ wantRecover: true,
+ },
+ {
+ name: "msg_no_tool_output",
+ msgRaw: "No tool output found for function call call_abc.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "msg_no_tool_call_for_function_call_output",
+ msgRaw: "No tool call found for function call output with call_id call_abc.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "msg_reasoning_missing_following",
+ msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "msg_reasoning_missing_preceding",
+ msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "server_error_by_type",
+ errTypeRaw: "server_error",
+ wantReason: "upstream_error_event",
+ wantRecover: true,
+ },
+ {
+ name: "server_error_by_code",
+ codeRaw: "server_error",
+ wantReason: "upstream_error_event",
+ wantRecover: true,
+ },
+ {
+ name: "unknown_event_error",
+ codeRaw: "other",
+ errTypeRaw: "other",
+ msgRaw: "other",
+ wantReason: "event_error",
+ wantRecover: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ reason, recoverable := classifyOpenAIWSErrorEventFromRaw(tt.codeRaw, tt.errTypeRaw, tt.msgRaw)
+ require.Equal(t, tt.wantReason, reason)
+ require.Equal(t, tt.wantRecover, recoverable)
+ })
+ }
}
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
@@ -197,12 +348,39 @@ func TestOpenAIWSRetryTotalBudget(t *testing.T) {
require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
}
+func TestOpenAIWSRetryContextError(t *testing.T) {
+ require.NoError(t, openAIWSRetryContextError(context.TODO()))
+ require.NoError(t, openAIWSRetryContextError(context.Background()))
+
+ canceledCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err := openAIWSRetryContextError(canceledCtx)
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+
+ var fallbackErr *openAIWSFallbackError
+ require.ErrorAs(t, err, &fallbackErr)
+ require.Equal(t, "retry_context_canceled", fallbackErr.Reason)
+}
+
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
+ require.Equal(t, "service_restart", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusServiceRestart}))
+ require.Equal(t, "try_again_later", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusTryAgainLater}))
require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
}
+func TestClassifyOpenAIWSIngressReadErrorClass(t *testing.T) {
+ require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(nil))
+ require.Equal(t, "context_canceled", classifyOpenAIWSIngressReadErrorClass(context.Canceled))
+ require.Equal(t, "deadline_exceeded", classifyOpenAIWSIngressReadErrorClass(context.DeadlineExceeded))
+ require.Equal(t, "service_restart", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusServiceRestart}))
+ require.Equal(t, "try_again_later", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusTryAgainLater}))
+ require.Equal(t, "upstream_closed", classifyOpenAIWSIngressReadErrorClass(io.EOF))
+ require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(errors.New("tls handshake timeout")))
+}
+
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
svc := &OpenAIGatewayService{cfg: &config.Config{}}
svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
@@ -239,6 +417,119 @@ func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
}
+func TestWriteOpenAIWSV1UnsupportedResponse_TracksOps(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+
+ svc := &OpenAIGatewayService{}
+ account := &Account{
+ ID: 42,
+ Name: "acc-ws-v1",
+ Platform: PlatformOpenAI,
+ }
+
+ err := svc.writeOpenAIWSV1UnsupportedResponse(c, account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "openai ws v1 is temporarily unsupported")
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "invalid_request_error")
+ require.Contains(t, rec.Body.String(), "temporarily unsupported")
+
+ rawStatus, ok := c.Get(OpsUpstreamStatusCodeKey)
+ require.True(t, ok)
+ require.Equal(t, http.StatusBadRequest, rawStatus)
+
+ rawMsg, ok := c.Get(OpsUpstreamErrorMessageKey)
+ require.True(t, ok)
+ require.Equal(t, "openai ws v1 is temporarily unsupported; use ws v2", rawMsg)
+
+ rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
+ require.True(t, ok)
+ events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
+ require.True(t, ok)
+ require.Len(t, events, 1)
+ require.Equal(t, account.ID, events[0].AccountID)
+ require.Equal(t, account.Platform, events[0].Platform)
+ require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
+ require.Equal(t, "ws_error", events[0].Kind)
+}
+
+func TestIsOpenAIWSStreamWriteDisconnectError(t *testing.T) {
+ require.False(t, isOpenAIWSStreamWriteDisconnectError(nil, nil))
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("writer failed"), ctx))
+
+ require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("broken pipe"), context.Background()))
+ require.True(t, isOpenAIWSStreamWriteDisconnectError(io.EOF, context.Background()))
+
+ require.False(t, isOpenAIWSStreamWriteDisconnectError(errors.New("template execute failed"), context.Background()))
+}
+
+func TestShouldFlushOpenAIWSBufferedEventsOnError(t *testing.T) {
+ require.True(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, false))
+ require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, false, false))
+ require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, true))
+ require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(false, true, false))
+}
+
+func TestCloneOpenAIWSJSONRawString(t *testing.T) {
+ require.Nil(t, cloneOpenAIWSJSONRawString(""))
+ require.Nil(t, cloneOpenAIWSJSONRawString(" "))
+
+ raw := `{"id":"resp_1","type":"response"}`
+ cloned := cloneOpenAIWSJSONRawString(raw)
+ require.Equal(t, raw, string(cloned))
+ require.Equal(t, len(raw), len(cloned))
+}
+
+func TestOpenAIWSAbortMetricsSnapshot(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true)
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true)
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonWriteUpstream, false)
+ svc.recordOpenAIWSTurnAbortRecovered()
+
+ snapshot := svc.SnapshotOpenAIWSAbortMetrics()
+ require.Equal(t, int64(1), snapshot.TurnAbortRecoveredTotal)
+
+ getTotal := func(reason string, expected bool) int64 {
+ for _, point := range snapshot.TurnAbortTotal {
+ if point.Reason == reason && point.Expected == expected {
+ return point.Total
+ }
+ }
+ return 0
+ }
+ require.Equal(t, int64(2), getTotal(string(openAIWSIngressTurnAbortReasonUpstreamError), true))
+ require.Equal(t, int64(1), getTotal(string(openAIWSIngressTurnAbortReasonWriteUpstream), false))
+}
+
+func TestOpenAIWSPerformanceMetricsSnapshot_ContainsAbortMetrics(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonClientClosed, true)
+ svc.recordOpenAIWSTurnAbortRecovered()
+
+ snapshot := svc.SnapshotOpenAIWSPerformanceMetrics()
+ require.Equal(t, int64(1), snapshot.Abort.TurnAbortRecoveredTotal)
+ require.Equal(t, int64(0), snapshot.Passthrough.SemanticMutationTotal)
+ require.GreaterOrEqual(t, snapshot.Passthrough.UsageParseFailureTotal, int64(0))
+
+ found := false
+ for _, point := range snapshot.Abort.TurnAbortTotal {
+ if point.Reason == string(openAIWSIngressTurnAbortReasonClientClosed) && point.Expected {
+ require.Equal(t, int64(1), point.Total)
+ found = true
+ break
+ }
+ }
+ require.True(t, found)
+}
+
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
svc := &OpenAIGatewayService{cfg: &config.Config{}}
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 74ba472f4..dd91055ac 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -6,17 +6,18 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"math/rand"
- "net"
"net/http"
"net/url"
+ "runtime/debug"
"sort"
+ "strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
coderws "github.com/coder/websocket"
"github.com/gin-gonic/gin"
@@ -26,8 +27,10 @@ import (
)
const (
- openAIWSBetaV1Value = "responses_websockets=2026-02-04"
- openAIWSBetaV2Value = "responses_websockets=2026-02-06"
+ openAIWSBetaV1Value = "responses_websockets=2026-02-04"
+ openAIWSBetaV2Value = "responses_websockets=2026-02-06"
+ openAIWSConnIDPrefixLegacy = "oa_ws_"
+ openAIWSConnIDPrefixCtx = "ctxws_"
openAIWSTurnStateHeader = "x-codex-turn-state"
openAIWSTurnMetadataHeader = "x-codex-turn-metadata"
@@ -55,879 +58,100 @@ const (
openAIWSStoreDisabledConnModeOff = "off"
openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found"
+ openAIWSIngressStageToolOutputNotFound = "tool_output_not_found"
openAIWSMaxPrevResponseIDDeletePasses = 8
+ openAIWSIngressReplayInputMaxBytes = 512 * 1024
+ openAIWSContinuationUnavailableReason = "upstream continuation connection is unavailable; please restart the conversation"
+ openAIWSAutoAbortedToolOutputValue = `{"error":"tool call aborted by gateway"}`
+ openAIWSClientReadIdleTimeoutDefault = 30 * time.Minute
+ openAIWSPassthroughIdleTimeoutDefault = time.Hour
+ openAIWSIngressClientDisconnectDrainTimeout = 5 * time.Second
+ openAIWSUpstreamPumpInfoMinAlive = 100 * time.Millisecond
)
-var openAIWSLogValueReplacer = strings.NewReplacer(
- "error", "err",
- "fallback", "fb",
- "warning", "warnx",
- "failed", "fail",
-)
-
-var openAIWSIngressPreflightPingIdle = 20 * time.Second
-
-// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。
-type openAIWSFallbackError struct {
- Reason string
- Err error
-}
-
-func (e *openAIWSFallbackError) Error() string {
- if e == nil {
- return ""
- }
- if e.Err == nil {
- return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
- }
- return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
-}
-
-func (e *openAIWSFallbackError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.Err
-}
-
-func wrapOpenAIWSFallback(reason string, err error) error {
- return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
-}
-
-// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。
-type OpenAIWSClientCloseError struct {
- statusCode coderws.StatusCode
- reason string
- err error
-}
-
-type openAIWSIngressTurnError struct {
- stage string
- cause error
- wroteDownstream bool
-}
-
-func (e *openAIWSIngressTurnError) Error() string {
- if e == nil {
- return ""
- }
- if e.cause == nil {
- return strings.TrimSpace(e.stage)
- }
- return e.cause.Error()
-}
-
-func (e *openAIWSIngressTurnError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.cause
-}
-
-func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
- if cause == nil {
- return nil
- }
- return &openAIWSIngressTurnError{
- stage: strings.TrimSpace(stage),
- cause: cause,
- wroteDownstream: wroteDownstream,
- }
-}
-
-func isOpenAIWSIngressTurnRetryable(err error) bool {
- var turnErr *openAIWSIngressTurnError
- if !errors.As(err, &turnErr) || turnErr == nil {
- return false
- }
- if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
- return false
- }
- if turnErr.wroteDownstream {
- return false
- }
- switch turnErr.stage {
- case "write_upstream", "read_upstream":
- return true
- default:
- return false
- }
-}
-
-func openAIWSIngressTurnRetryReason(err error) string {
- var turnErr *openAIWSIngressTurnError
- if !errors.As(err, &turnErr) || turnErr == nil {
- return "unknown"
- }
- if turnErr.stage == "" {
- return "unknown"
- }
- return turnErr.stage
-}
-
-func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
- var turnErr *openAIWSIngressTurnError
- if !errors.As(err, &turnErr) || turnErr == nil {
- return false
- }
- if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
- return false
- }
- return !turnErr.wroteDownstream
-}
-
-// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
-func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
- return &OpenAIWSClientCloseError{
- statusCode: statusCode,
- reason: strings.TrimSpace(reason),
- err: err,
- }
-}
-
-func (e *OpenAIWSClientCloseError) Error() string {
- if e == nil {
- return ""
- }
- if e.err == nil {
- return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
- }
- return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
-}
-
-func (e *OpenAIWSClientCloseError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.err
-}
-
-func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
- if e == nil {
- return coderws.StatusInternalError
- }
- return e.statusCode
-}
-
-func (e *OpenAIWSClientCloseError) Reason() string {
- if e == nil {
- return ""
- }
- return strings.TrimSpace(e.reason)
-}
-
-// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
-type OpenAIWSIngressHooks struct {
- BeforeTurn func(turn int) error
- AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
-}
-
-func normalizeOpenAIWSLogValue(value string) string {
- trimmed := strings.TrimSpace(value)
- if trimmed == "" {
- return "-"
- }
- return openAIWSLogValueReplacer.Replace(trimmed)
-}
-
-func truncateOpenAIWSLogValue(value string, maxLen int) string {
- normalized := normalizeOpenAIWSLogValue(value)
- if normalized == "-" || maxLen <= 0 {
- return normalized
- }
- if len(normalized) <= maxLen {
- return normalized
- }
- return normalized[:maxLen] + "..."
-}
-
-func openAIWSHeaderValueForLog(headers http.Header, key string) string {
- if headers == nil {
- return "-"
- }
- return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
-}
-
-func hasOpenAIWSHeader(headers http.Header, key string) bool {
- if headers == nil {
- return false
- }
- return strings.TrimSpace(headers.Get(key)) != ""
-}
-
-type openAIWSSessionHeaderResolution struct {
- SessionID string
- ConversationID string
- SessionSource string
- ConversationSource string
-}
-
-func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
- resolution := openAIWSSessionHeaderResolution{
- SessionSource: "none",
- ConversationSource: "none",
- }
- if c != nil && c.Request != nil {
- if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
- resolution.SessionID = sessionID
- resolution.SessionSource = "header_session_id"
- }
- if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
- resolution.ConversationID = conversationID
- resolution.ConversationSource = "header_conversation_id"
- if resolution.SessionID == "" {
- resolution.SessionID = conversationID
- resolution.SessionSource = "header_conversation_id"
- }
- }
- }
-
- cacheKey := strings.TrimSpace(promptCacheKey)
- if cacheKey != "" {
- if resolution.SessionID == "" {
- resolution.SessionID = cacheKey
- resolution.SessionSource = "prompt_cache_key"
- }
- }
- return resolution
-}
-
-func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
- if idx <= openAIWSEventLogHeadLimit {
- return true
- }
- if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
- return true
- }
- if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
- return true
- }
- return false
-}
-
-func shouldLogOpenAIWSBufferedEvent(idx int) bool {
- if idx <= openAIWSBufferLogHeadLimit {
- return true
- }
- if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
- return true
- }
- return false
-}
-
-func openAIWSEventMayContainModel(eventType string) bool {
- switch eventType {
- case "response.created",
- "response.in_progress",
- "response.completed",
- "response.done",
- "response.failed",
- "response.incomplete",
- "response.cancelled",
- "response.canceled":
- return true
- default:
- trimmed := strings.TrimSpace(eventType)
- if trimmed == eventType {
- return false
- }
- switch trimmed {
- case "response.created",
- "response.in_progress",
- "response.completed",
- "response.done",
- "response.failed",
- "response.incomplete",
- "response.cancelled",
- "response.canceled":
- return true
- default:
- return false
- }
- }
-}
-
-func openAIWSEventMayContainToolCalls(eventType string) bool {
- eventType = strings.TrimSpace(eventType)
- if eventType == "" {
- return false
- }
- if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") {
- return true
- }
- switch eventType {
- case "response.output_item.added", "response.output_item.done", "response.completed", "response.done":
- return true
- default:
- return false
- }
-}
-
-func openAIWSEventShouldParseUsage(eventType string) bool {
- return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
-}
-
-func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
- if len(message) == 0 {
- return "", "", gjson.Result{}
- }
- values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
- eventType = strings.TrimSpace(values[0].String())
- if id := strings.TrimSpace(values[1].String()); id != "" {
- responseID = id
- } else {
- responseID = strings.TrimSpace(values[2].String())
- }
- return eventType, responseID, values[3]
-}
-
-func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
- if len(message) == 0 {
- return false
- }
- return bytes.Contains(message, []byte(`"tool_calls"`)) ||
- bytes.Contains(message, []byte(`"tool_call"`)) ||
- bytes.Contains(message, []byte(`"function_call"`))
-}
-
-func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
- if usage == nil || len(message) == 0 {
- return
- }
- values := gjson.GetManyBytes(
- message,
- "response.usage.input_tokens",
- "response.usage.output_tokens",
- "response.usage.input_tokens_details.cached_tokens",
- )
- usage.InputTokens = int(values[0].Int())
- usage.OutputTokens = int(values[1].Int())
- usage.CacheReadInputTokens = int(values[2].Int())
-}
-
-func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
- if len(message) == 0 {
- return "", "", ""
- }
- values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
- return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
-}
-
-func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
- code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
- errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
- errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
- return code, errType, errMessage
-}
-
-func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
- if len(message) == 0 {
- return "-", "-", "-"
- }
- return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
-}
-
-func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
- if len(payload) == 0 {
- return "-"
- }
- type keySize struct {
- Key string
- Size int
- }
- sizes := make([]keySize, 0, len(payload))
- for key, value := range payload {
- size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
- sizes = append(sizes, keySize{Key: key, Size: size})
- }
- sort.Slice(sizes, func(i, j int) bool {
- if sizes[i].Size == sizes[j].Size {
- return sizes[i].Key < sizes[j].Key
- }
- return sizes[i].Size > sizes[j].Size
- })
-
- if topN <= 0 || topN > len(sizes) {
- topN = len(sizes)
- }
- parts := make([]string, 0, topN)
- for idx := 0; idx < topN; idx++ {
- item := sizes[idx]
- parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
- }
- return strings.Join(parts, ",")
-}
-
-func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
- if depth <= 0 {
- return -1
- }
- switch v := value.(type) {
- case nil:
- return 0
- case string:
- return len(v)
- case []byte:
- return len(v)
- case bool:
- return 1
- case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
- return 8
- case float32, float64:
- return 8
- case map[string]any:
- if len(v) == 0 {
- return 2
- }
- total := 2
- count := 0
- for key, item := range v {
- count++
- if count > openAIWSPayloadSizeEstimateMaxItems {
- return -1
- }
- itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
- if itemSize < 0 {
- return -1
- }
- total += len(key) + itemSize + 3
- if total > openAIWSPayloadSizeEstimateMaxBytes {
- return -1
- }
- }
- return total
- case []any:
- if len(v) == 0 {
- return 2
- }
- total := 2
- limit := len(v)
- if limit > openAIWSPayloadSizeEstimateMaxItems {
- return -1
- }
- for i := 0; i < limit; i++ {
- itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
- if itemSize < 0 {
- return -1
- }
- total += itemSize + 1
- if total > openAIWSPayloadSizeEstimateMaxBytes {
- return -1
- }
- }
- return total
- default:
- raw, err := json.Marshal(v)
- if err != nil {
- return -1
- }
- if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
- return -1
- }
- return len(raw)
- }
-}
-
-func openAIWSPayloadString(payload map[string]any, key string) string {
- if len(payload) == 0 {
- return ""
- }
- raw, ok := payload[key]
- if !ok {
- return ""
- }
- switch v := raw.(type) {
- case nil:
- return ""
- case string:
- return strings.TrimSpace(v)
- case []byte:
- return strings.TrimSpace(string(v))
- default:
- return ""
- }
-}
-
-func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
- if len(payload) == 0 || strings.TrimSpace(key) == "" {
- return ""
- }
- return strings.TrimSpace(gjson.GetBytes(payload, key).String())
-}
-
-func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
- if len(payload) == 0 || strings.TrimSpace(key) == "" {
- return defaultValue
- }
- value := gjson.GetBytes(payload, key)
- if !value.Exists() {
- return defaultValue
- }
- if value.Type != gjson.True && value.Type != gjson.False {
- return defaultValue
- }
- return value.Bool()
-}
-
-func openAIWSSessionHashesFromID(sessionID string) (string, string) {
- return deriveOpenAISessionHashes(sessionID)
-}
-
-func extractOpenAIWSImageURL(value any) string {
- switch v := value.(type) {
- case string:
- return strings.TrimSpace(v)
- case map[string]any:
- if raw, ok := v["url"].(string); ok {
- return strings.TrimSpace(raw)
- }
- }
- return ""
-}
-
-func summarizeOpenAIWSInput(input any) string {
- items, ok := input.([]any)
- if !ok || len(items) == 0 {
- return "-"
- }
-
- itemCount := len(items)
- textChars := 0
- imageDataURLs := 0
- imageDataURLChars := 0
- imageRemoteURLs := 0
-
- handleContentItem := func(contentItem map[string]any) {
- contentType, _ := contentItem["type"].(string)
- switch strings.TrimSpace(contentType) {
- case "input_text", "output_text", "text":
- if text, ok := contentItem["text"].(string); ok {
- textChars += len(text)
- }
- case "input_image":
- imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
- if imageURL == "" {
- return
- }
- if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
- imageDataURLs++
- imageDataURLChars += len(imageURL)
- return
- }
- imageRemoteURLs++
- }
- }
-
- handleInputItem := func(inputItem map[string]any) {
- if content, ok := inputItem["content"].([]any); ok {
- for _, rawContent := range content {
- contentItem, ok := rawContent.(map[string]any)
- if !ok {
- continue
- }
- handleContentItem(contentItem)
- }
- return
- }
-
- itemType, _ := inputItem["type"].(string)
- switch strings.TrimSpace(itemType) {
- case "input_text", "output_text", "text":
- if text, ok := inputItem["text"].(string); ok {
- textChars += len(text)
- }
- case "input_image":
- imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
- if imageURL == "" {
- return
- }
- if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
- imageDataURLs++
- imageDataURLChars += len(imageURL)
- return
- }
- imageRemoteURLs++
- }
- }
-
- for _, rawItem := range items {
- inputItem, ok := rawItem.(map[string]any)
- if !ok {
- continue
- }
- handleInputItem(inputItem)
- }
-
- return fmt.Sprintf(
- "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
- itemCount,
- textChars,
- imageDataURLs,
- imageDataURLChars,
- imageRemoteURLs,
- )
-}
-
-func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
- if len(payload) == 0 || strings.TrimSpace(key) == "" {
- return
- }
- if _, exists := payload[key]; !exists {
- return
- }
- delete(payload, key)
- *removed = append(*removed, key)
-}
-
-// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段,
-// 避免重试成功却改变原始请求语义。
-// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。
-func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
- if len(payload) == 0 {
- return "empty", nil
- }
- if attempt <= 1 {
- return "full", nil
- }
-
- removed := make([]string, 0, 2)
- if attempt >= 2 {
- dropOpenAIWSPayloadKey(payload, "include", &removed)
- }
-
- if len(removed) == 0 {
- return "full", nil
- }
- sort.Strings(removed)
- return "trim_optional_fields", removed
-}
-
-func logOpenAIWSModeInfo(format string, args ...any) {
- logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
-}
-
-func isOpenAIWSModeDebugEnabled() bool {
- return logger.L().Core().Enabled(zap.DebugLevel)
-}
-
-func logOpenAIWSModeDebug(format string, args ...any) {
- if !isOpenAIWSModeDebugEnabled() {
- return
- }
- logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
-}
-
-func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
- if err == nil {
- return
- }
- logger.L().Warn(
- "openai.ws_bind_response_account_failed",
- zap.Int64("group_id", groupID),
- zap.Int64("account_id", accountID),
- zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
- zap.Error(err),
- )
-}
-
-func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
- if err == nil {
- return "-", "-"
- }
- statusCode := coderws.CloseStatus(err)
- if statusCode == -1 {
- return "-", "-"
- }
- closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
- closeReason := "-"
- var closeErr coderws.CloseError
- if errors.As(err, &closeErr) {
- reasonText := strings.TrimSpace(closeErr.Reason)
- if reasonText != "" {
- closeReason = normalizeOpenAIWSLogValue(reasonText)
- }
- }
- return normalizeOpenAIWSLogValue(closeStatus), closeReason
-}
+type openAIWSIngressTurnAbortReason string
-func unwrapOpenAIWSDialBaseError(err error) error {
- if err == nil {
- return nil
- }
- var dialErr *openAIWSDialError
- if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
- return dialErr.Err
- }
- return err
-}
+const (
+ openAIWSIngressTurnAbortReasonUnknown openAIWSIngressTurnAbortReason = "unknown"
+
+ openAIWSIngressTurnAbortReasonClientClosed openAIWSIngressTurnAbortReason = "client_closed"
+ openAIWSIngressTurnAbortReasonContextCanceled openAIWSIngressTurnAbortReason = "ctx_canceled"
+ openAIWSIngressTurnAbortReasonContextDeadline openAIWSIngressTurnAbortReason = "ctx_deadline_exceeded"
+ openAIWSIngressTurnAbortReasonPreviousResponse openAIWSIngressTurnAbortReason = openAIWSIngressStagePreviousResponseNotFound
+ openAIWSIngressTurnAbortReasonToolOutput openAIWSIngressTurnAbortReason = openAIWSIngressStageToolOutputNotFound
+ openAIWSIngressTurnAbortReasonUpstreamError openAIWSIngressTurnAbortReason = "upstream_error_event"
+ openAIWSIngressTurnAbortReasonWriteUpstream openAIWSIngressTurnAbortReason = "write_upstream"
+ openAIWSIngressTurnAbortReasonReadUpstream openAIWSIngressTurnAbortReason = "read_upstream"
+ openAIWSIngressTurnAbortReasonWriteClient openAIWSIngressTurnAbortReason = "write_client"
+ openAIWSIngressTurnAbortReasonContinuationUnavailable openAIWSIngressTurnAbortReason = "continuation_unavailable"
+ openAIWSIngressTurnAbortReasonClientPreempted openAIWSIngressTurnAbortReason = "client_preempted"
+ openAIWSIngressTurnAbortReasonUpstreamRestart openAIWSIngressTurnAbortReason = "upstream_restart"
+)
-func openAIWSDialRespHeaderForLog(err error, key string) string {
- var dialErr *openAIWSDialError
- if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
- return "-"
- }
- return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
-}
+type openAIWSIngressTurnAbortDisposition string
-func classifyOpenAIWSDialError(err error) string {
- if err == nil {
- return "-"
- }
- baseErr := unwrapOpenAIWSDialBaseError(err)
- if baseErr == nil {
- return "-"
- }
- if errors.Is(baseErr, context.DeadlineExceeded) {
- return "ctx_deadline_exceeded"
- }
- if errors.Is(baseErr, context.Canceled) {
- return "ctx_canceled"
- }
- var netErr net.Error
- if errors.As(baseErr, &netErr) && netErr.Timeout() {
- return "net_timeout"
- }
- if status := coderws.CloseStatus(baseErr); status != -1 {
- return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
- }
- message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
- switch {
- case strings.Contains(message, "handshake not finished"):
- return "handshake_not_finished"
- case strings.Contains(message, "bad handshake"):
- return "bad_handshake"
- case strings.Contains(message, "connection refused"):
- return "connection_refused"
- case strings.Contains(message, "no such host"):
- return "dns_not_found"
- case strings.Contains(message, "tls"):
- return "tls_error"
- case strings.Contains(message, "i/o timeout"):
- return "io_timeout"
- case strings.Contains(message, "context deadline exceeded"):
- return "ctx_deadline_exceeded"
- default:
- return "dial_error"
- }
-}
+const (
+ openAIWSIngressTurnAbortDispositionFailRequest openAIWSIngressTurnAbortDisposition = "fail_request"
+ openAIWSIngressTurnAbortDispositionContinueTurn openAIWSIngressTurnAbortDisposition = "continue_turn"
+ openAIWSIngressTurnAbortDispositionCloseGracefully openAIWSIngressTurnAbortDisposition = "close_gracefully"
+)
-func summarizeOpenAIWSDialError(err error) (
- statusCode int,
- dialClass string,
- closeStatus string,
- closeReason string,
- respServer string,
- respVia string,
- respCFRay string,
- respRequestID string,
-) {
- dialClass = "-"
- closeStatus = "-"
- closeReason = "-"
- respServer = "-"
- respVia = "-"
- respCFRay = "-"
- respRequestID = "-"
- if err == nil {
- return
- }
- var dialErr *openAIWSDialError
- if errors.As(err, &dialErr) && dialErr != nil {
- statusCode = dialErr.StatusCode
- respServer = openAIWSDialRespHeaderForLog(err, "server")
- respVia = openAIWSDialRespHeaderForLog(err, "via")
- respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
- respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
- }
- dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
- closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
- return
+// openAIWSUpstreamPumpEvent 是上游事件读取泵传递给主 goroutine 的消息载体。
+type openAIWSUpstreamPumpEvent struct {
+ message []byte
+ err error
}
-func isOpenAIWSClientDisconnectError(err error) bool {
- if err == nil {
- return false
- }
- if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
- return true
- }
- switch coderws.CloseStatus(err) {
- case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
- return true
- }
- message := strings.ToLower(strings.TrimSpace(err.Error()))
- if message == "" {
- return false
- }
- return strings.Contains(message, "failed to read frame header: eof") ||
- strings.Contains(message, "unexpected eof") ||
- strings.Contains(message, "use of closed network connection") ||
- strings.Contains(message, "connection reset by peer") ||
- strings.Contains(message, "broken pipe")
-}
+const (
+ // openAIWSUpstreamPumpBufferSize 是上游事件读取泵的缓冲 channel 大小。
+ // 缓冲允许上游读取和客户端写入并发执行,吸收客户端写入延迟波动。
+ openAIWSUpstreamPumpBufferSize = 16
+)
-func classifyOpenAIWSReadFallbackReason(err error) string {
- if err == nil {
- return "read_event"
- }
- switch coderws.CloseStatus(err) {
- case coderws.StatusPolicyViolation:
- return "policy_violation"
- case coderws.StatusMessageTooBig:
- return "message_too_big"
- default:
- return "read_event"
- }
-}
+var openAIWSLogValueReplacer = strings.NewReplacer(
+ "error", "err",
+ "fallback", "fb",
+ "warning", "warnx",
+ "failed", "fail",
+)
-func sortedKeys(m map[string]any) []string {
- if len(m) == 0 {
- return nil
- }
- keys := make([]string, 0, len(m))
- for k := range m {
- keys = append(keys, k)
- }
- sort.Strings(keys)
- return keys
-}
+var openAIWSIngressPreflightPingIdle = 20 * time.Second
-func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
+func (s *OpenAIGatewayService) getOpenAIWSIngressContextPool() *openAIWSIngressContextPool {
if s == nil {
return nil
}
- s.openaiWSPoolOnce.Do(func() {
- if s.openaiWSPool == nil {
- s.openaiWSPool = newOpenAIWSConnPool(s.cfg)
+ s.openaiWSIngressCtxOnce.Do(func() {
+ if s.openaiWSIngressCtxPool == nil {
+ pool := newOpenAIWSIngressContextPool(s.cfg)
+ // Ensure the scheduler (and its runtime stats) are initialized
+ // before wiring load-aware signals into the context pool.
+ _ = s.getOpenAIAccountScheduler()
+ pool.schedulerStats = s.openaiAccountStats
+ s.openaiWSIngressCtxPool = pool
}
})
- return s.openaiWSPool
-}
-
-func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
- pool := s.getOpenAIWSConnPool()
- if pool == nil {
- return OpenAIWSPoolMetricsSnapshot{}
- }
- return pool.SnapshotMetrics()
+ return s.openaiWSIngressCtxPool
}
type OpenAIWSPerformanceMetricsSnapshot struct {
- Pool OpenAIWSPoolMetricsSnapshot `json:"pool"`
- Retry OpenAIWSRetryMetricsSnapshot `json:"retry"`
- Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
+ Retry OpenAIWSRetryMetricsSnapshot `json:"retry"`
+ Abort OpenAIWSAbortMetricsSnapshot `json:"abort"`
+ Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
+ Passthrough openaiwsv2.MetricsSnapshot `json:"passthrough"`
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot {
- pool := s.getOpenAIWSConnPool()
+ ingressPool := s.getOpenAIWSIngressContextPool()
snapshot := OpenAIWSPerformanceMetricsSnapshot{
- Retry: s.SnapshotOpenAIWSRetryMetrics(),
+ Retry: s.SnapshotOpenAIWSRetryMetrics(),
+ Abort: s.SnapshotOpenAIWSAbortMetrics(),
+ Passthrough: openaiwsv2.SnapshotMetrics(),
}
- if pool == nil {
+ if ingressPool == nil {
return snapshot
}
- snapshot.Pool = pool.SnapshotMetrics()
- snapshot.Transport = pool.SnapshotTransportMetrics()
+ snapshot.Transport = ingressPool.SnapshotTransportMetrics()
return snapshot
}
@@ -943,6 +167,18 @@ func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore {
return s.openaiWSStateStore
}
+func (s *OpenAIGatewayService) getOpenAIWSPassthroughDialer() openAIWSClientDialer {
+ if s == nil {
+ return nil
+ }
+ s.openaiWSPassthroughDialerOnce.Do(func() {
+ if s.openaiWSPassthroughDialer == nil {
+ s.openaiWSPassthroughDialer = newDefaultOpenAIWSClientDialer()
+ }
+ })
+ return s.openaiWSPassthroughDialer
+}
+
func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration {
if s != nil && s.cfg != nil {
seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds
@@ -967,6 +203,20 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return 15 * time.Minute
}
+func (s *OpenAIGatewayService) openAIWSClientReadIdleTimeout() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds) * time.Second
+ }
+ return openAIWSClientReadIdleTimeoutDefault
+}
+
+func (s *OpenAIGatewayService) openAIWSPassthroughIdleTimeout() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds) * time.Second
+ }
+ return openAIWSPassthroughIdleTimeoutDefault
+}
+
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
@@ -1147,8 +397,14 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
headers.Set("user-agent", codexCLIUserAgent)
}
- if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) {
+ userAgentIsCodexCLI := openai.IsCodexCLIRequest(headers.Get("user-agent"))
+ if account != nil && account.Type == AccountTypeOAuth && !userAgentIsCodexCLI {
headers.Set("user-agent", codexCLIUserAgent)
+ userAgentIsCodexCLI = true
+ }
+ if account != nil && account.Type == AccountTypeOAuth && userAgentIsCodexCLI {
+ // 保持 OAuth 握手头的一致性:Codex 风格 UA 必须搭配 codex_cli_rs originator。
+ headers.Set("originator", "codex_cli_rs")
}
return headers, sessionResolution
@@ -1202,448 +458,85 @@ func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) {
}
}
-func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
- if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
- return true
- }
- if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
- return true
- }
- return false
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
- if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
- return true
- }
- if len(reqBody) == 0 {
- return false
- }
- rawStore, ok := reqBody["store"]
- if !ok {
- return false
- }
- storeEnabled, ok := rawStore.(bool)
- if !ok {
- return false
- }
- return !storeEnabled
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
- if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
- return true
- }
- if len(reqBody) == 0 {
- return false
- }
- storeValue := gjson.GetBytes(reqBody, "store")
- if !storeValue.Exists() {
- return false
- }
- if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
- return false
- }
- return !storeValue.Bool()
-}
-
-func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
- if s == nil || s.cfg == nil {
- return openAIWSStoreDisabledConnModeStrict
- }
- mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
- switch mode {
- case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
- return mode
- case "":
- // 兼容旧配置:仅配置了布尔开关时按旧语义推导。
- if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
- return openAIWSStoreDisabledConnModeStrict
- }
- return openAIWSStoreDisabledConnModeOff
- default:
- return openAIWSStoreDisabledConnModeStrict
- }
-}
-
-func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
- switch mode {
- case openAIWSStoreDisabledConnModeOff:
- return false
- case openAIWSStoreDisabledConnModeAdaptive:
- reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
- switch reason {
- case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
- return true
- default:
- return false
- }
- default:
- return true
- }
-}
-
-func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
- return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
-}
-
-func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
- payload []byte,
- deleteFn func([]byte, string) ([]byte, error),
-) ([]byte, bool, error) {
- if len(payload) == 0 {
- return payload, false, nil
- }
- if !gjson.GetBytes(payload, "previous_response_id").Exists() {
- return payload, false, nil
- }
- if deleteFn == nil {
- deleteFn = sjson.DeleteBytes
- }
-
- updated := payload
- for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
- gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
- next, err := deleteFn(updated, "previous_response_id")
- if err != nil {
- return payload, false, err
- }
- updated = next
- }
- return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
-}
-
-func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
- normalizedPrevID := strings.TrimSpace(previousResponseID)
- if len(payload) == 0 || normalizedPrevID == "" {
- return payload, nil
- }
- updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
- if err == nil {
- return updated, nil
- }
-
- var reqBody map[string]any
- if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
- return nil, err
- }
- reqBody["previous_response_id"] = normalizedPrevID
- rebuilt, marshalErr := json.Marshal(reqBody)
- if marshalErr != nil {
- return nil, marshalErr
- }
- return rebuilt, nil
-}
-
-func shouldInferIngressFunctionCallOutputPreviousResponseID(
- storeDisabled bool,
- turn int,
- hasFunctionCallOutput bool,
- currentPreviousResponseID string,
- expectedPreviousResponseID string,
-) bool {
- if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
- return false
- }
- if strings.TrimSpace(currentPreviousResponseID) != "" {
- return false
- }
- return strings.TrimSpace(expectedPreviousResponseID) != ""
-}
-
-func alignStoreDisabledPreviousResponseID(
- payload []byte,
- expectedPreviousResponseID string,
-) ([]byte, bool, error) {
- if len(payload) == 0 {
- return payload, false, nil
- }
- expected := strings.TrimSpace(expectedPreviousResponseID)
- if expected == "" {
- return payload, false, nil
- }
- current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
- if current == "" || current == expected {
- return payload, false, nil
- }
-
- withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
- if dropErr != nil {
- return payload, false, dropErr
- }
- if !removed {
- return payload, false, nil
- }
- updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
- if setErr != nil {
- return payload, false, setErr
- }
- return updated, true, nil
-}
-
-func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
- if len(payload) == 0 {
- return nil
- }
- cloned := make([]byte, len(payload))
- copy(cloned, payload)
- return cloned
-}
-
-func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
- if items == nil {
- return nil
- }
- cloned := make([]json.RawMessage, 0, len(items))
- for idx := range items {
- cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
- }
- return cloned
-}
-
-func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
- trimmed := bytes.TrimSpace(raw)
- if len(trimmed) == 0 {
- return nil, errors.New("json is empty")
- }
- var decoded any
- if err := json.Unmarshal(trimmed, &decoded); err != nil {
- return nil, err
- }
- return json.Marshal(decoded)
-}
-
-func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
- normalized, err := normalizeOpenAIWSJSONForCompare(raw)
- if err != nil {
- return bytes.TrimSpace(raw)
- }
- return normalized
-}
-
-func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
- if len(payload) == 0 {
- return nil, errors.New("payload is empty")
- }
- var decoded map[string]any
- if err := json.Unmarshal(payload, &decoded); err != nil {
- return nil, err
- }
- delete(decoded, "input")
- delete(decoded, "previous_response_id")
- return json.Marshal(decoded)
-}
-
-func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
- if len(payload) == 0 {
- return nil, false, nil
- }
- inputValue := gjson.GetBytes(payload, "input")
- if !inputValue.Exists() {
- return nil, false, nil
- }
- if inputValue.Type == gjson.JSON {
- raw := strings.TrimSpace(inputValue.Raw)
- if strings.HasPrefix(raw, "[") {
- var items []json.RawMessage
- if err := json.Unmarshal([]byte(raw), &items); err != nil {
- return nil, true, err
- }
- return items, true, nil
- }
- return []json.RawMessage{json.RawMessage(raw)}, true, nil
- }
- if inputValue.Type == gjson.String {
- encoded, _ := json.Marshal(inputValue.String())
- return []json.RawMessage{encoded}, true, nil
- }
- return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
-}
-
-func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
- previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
- if prevErr != nil {
- return false, prevErr
- }
- currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
- if currentErr != nil {
- return false, currentErr
- }
- if !previousExists && !currentExists {
- return true, nil
- }
- if !previousExists {
- return len(currentItems) == 0, nil
- }
- if !currentExists {
- return len(previousItems) == 0, nil
- }
- if len(currentItems) < len(previousItems) {
- return false, nil
- }
-
- for idx := range previousItems {
- previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
- currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
- if !bytes.Equal(previousNormalized, currentNormalized) {
- return false, nil
- }
- }
- return true, nil
-}
-
-func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
- if len(prefix) == 0 {
- return true
- }
- if len(items) < len(prefix) {
- return false
- }
- for idx := range prefix {
- previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
- currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
- if !bytes.Equal(previousNormalized, currentNormalized) {
- return false
- }
- }
- return true
-}
-
-func buildOpenAIWSReplayInputSequence(
- previousFullInput []json.RawMessage,
- previousFullInputExists bool,
- currentPayload []byte,
- hasPreviousResponseID bool,
-) ([]json.RawMessage, bool, error) {
- currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
- if currentErr != nil {
- return nil, false, currentErr
- }
- if !hasPreviousResponseID {
- return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
- }
- if !previousFullInputExists {
- return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
- }
- if !currentExists || len(currentItems) == 0 {
- return cloneOpenAIWSRawMessages(previousFullInput), true, nil
- }
- if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
- return cloneOpenAIWSRawMessages(currentItems), true, nil
- }
- merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
- merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
- merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
- return merged, true, nil
-}
-
-func setOpenAIWSPayloadInputSequence(
- payload []byte,
- fullInput []json.RawMessage,
- fullInputExists bool,
-) ([]byte, error) {
- if !fullInputExists {
- return payload, nil
- }
- // Preserve [] vs null semantics when input exists but is empty.
- inputForMarshal := fullInput
- if inputForMarshal == nil {
- inputForMarshal = []json.RawMessage{}
+func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
+ if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
+ return true
}
- inputRaw, marshalErr := json.Marshal(inputForMarshal)
- if marshalErr != nil {
- return nil, marshalErr
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
+ return true
}
- return sjson.SetRawBytes(payload, "input", inputRaw)
+ return false
}
-func shouldKeepIngressPreviousResponseID(
- previousPayload []byte,
- currentPayload []byte,
- lastTurnResponseID string,
- hasFunctionCallOutput bool,
-) (bool, string, error) {
- if hasFunctionCallOutput {
- return true, "has_function_call_output", nil
- }
- currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
- if currentPreviousResponseID == "" {
- return false, "missing_previous_response_id", nil
+func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ return true
}
- expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
- if expectedPreviousResponseID == "" {
- return false, "missing_last_turn_response_id", nil
+ if len(reqBody) == 0 {
+ return false
}
- if currentPreviousResponseID != expectedPreviousResponseID {
- return false, "previous_response_id_mismatch", nil
+ rawStore, ok := reqBody["store"]
+ if !ok {
+ return false
}
- if len(previousPayload) == 0 {
- return false, "missing_previous_turn_payload", nil
+ storeEnabled, ok := rawStore.(bool)
+ if !ok {
+ return false
}
+ return !storeEnabled
+}
- previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
- if previousComparableErr != nil {
- return false, "non_input_compare_error", previousComparableErr
- }
- currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
- if currentComparableErr != nil {
- return false, "non_input_compare_error", currentComparableErr
+func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ return true
}
- if !bytes.Equal(previousComparable, currentComparable) {
- return false, "non_input_changed", nil
+ if len(reqBody) == 0 {
+ return false
}
- return true, "strict_incremental_ok", nil
-}
-
-type openAIWSIngressPreviousTurnStrictState struct {
- nonInputComparable []byte
-}
-
-func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
- if len(payload) == 0 {
- return nil, nil
+ storeValue := gjson.GetBytes(reqBody, "store")
+ if !storeValue.Exists() {
+ return false
}
- nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
- if nonInputErr != nil {
- return nil, nonInputErr
+ if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
+ return false
}
- return &openAIWSIngressPreviousTurnStrictState{
- nonInputComparable: nonInputComparable,
- }, nil
+ return !storeValue.Bool()
}
-func shouldKeepIngressPreviousResponseIDWithStrictState(
- previousState *openAIWSIngressPreviousTurnStrictState,
- currentPayload []byte,
- lastTurnResponseID string,
- hasFunctionCallOutput bool,
-) (bool, string, error) {
- if hasFunctionCallOutput {
- return true, "has_function_call_output", nil
- }
- currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
- if currentPreviousResponseID == "" {
- return false, "missing_previous_response_id", nil
- }
- expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
- if expectedPreviousResponseID == "" {
- return false, "missing_last_turn_response_id", nil
- }
- if currentPreviousResponseID != expectedPreviousResponseID {
- return false, "previous_response_id_mismatch", nil
+func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
+ if s == nil || s.cfg == nil {
+ return openAIWSStoreDisabledConnModeStrict
}
- if previousState == nil {
- return false, "missing_previous_turn_payload", nil
+ mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
+ switch mode {
+ case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
+ return mode
+ case "":
+ // 兼容旧配置:仅配置了布尔开关时按旧语义推导。
+ if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
+ return openAIWSStoreDisabledConnModeStrict
+ }
+ return openAIWSStoreDisabledConnModeOff
+ default:
+ return openAIWSStoreDisabledConnModeStrict
}
+}
- currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
- if currentComparableErr != nil {
- return false, "non_input_compare_error", currentComparableErr
- }
- if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
- return false, "non_input_changed", nil
+func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
+ switch mode {
+ case openAIWSStoreDisabledConnModeOff:
+ return false
+ case openAIWSStoreDisabledConnModeAdaptive:
+ reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
+ switch reason {
+ case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
+ return true
+ default:
+ return false
+ }
+ default:
+ return true
}
- return true, "strict_incremental_ok", nil
}
func (s *OpenAIGatewayService) forwardOpenAIWSV2(
@@ -1660,7 +553,20 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
startTime time.Time,
attempt int,
lastFailureReason string,
-) (*OpenAIForwardResult, error) {
+) (result *OpenAIForwardResult, err error) {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.LegacyPrintf(
+ "service.openai_ws_forwarder",
+ "[OpenAIWS] recovered panic in forwardOpenAIWSV2: panic=%v stack=%s",
+ recovered,
+ string(debug.Stack()),
+ )
+ err = fmt.Errorf("openai ws panic recovered: %v", recovered)
+ result = nil
+ }
+ }()
+
if s == nil || account == nil {
return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil"))
}
@@ -1669,16 +575,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if err != nil {
return nil, wrapOpenAIWSFallback("build_ws_url", err)
}
- wsHost := "-"
- wsPath := "-"
- if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil {
- if h := strings.TrimSpace(parsed.Host); h != "" {
- wsHost = normalizeOpenAIWSLogValue(h)
- }
- if p := strings.TrimSpace(parsed.Path); p != "" {
- wsPath = normalizeOpenAIWSLogValue(p)
- }
- }
+ wsHost, wsPath := openAIWSHostPathForLogFromURL(wsURL)
logOpenAIWSModeDebug(
"dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s",
account.ID,
@@ -1738,7 +635,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
stateStore := s.getOpenAIWSStateStore()
groupID := getOpenAIGroupIDFromContext(c)
- sessionHash := s.GenerateSessionHash(c, nil)
+ sessionHash := s.GenerateSessionHashWithFallback(c, nil, openAIWSIngressFallbackSessionSeedFromContext(c))
if sessionHash == "" {
var legacySessionHash string
sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey)
@@ -1751,9 +648,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
preferredConnID := ""
if stateStore != nil && previousResponseID != "" {
- if connID, ok := stateStore.GetResponseConn(previousResponseID); ok {
- preferredConnID = connID
- }
+ preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, previousResponseID)
}
storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account)
if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" {
@@ -1800,18 +695,35 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout())
defer acquireCancel()
- lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{
- Account: account,
- WSURL: wsURL,
- Headers: wsHeaders,
- PreferredConnID: preferredConnID,
- ForceNewConn: forceNewConn,
+ ingressCtxPool := s.getOpenAIWSIngressContextPool()
+ if ingressCtxPool == nil {
+ return nil, wrapOpenAIWSFallback("ctx_pool_unavailable", errors.New("openai ws ingress context pool is nil"))
+ }
+ sessionHashForCtx := strings.TrimSpace(sessionHash)
+ if sessionHashForCtx == "" {
+ sessionHashForCtx = fmt.Sprintf("httpws:%d:%d", account.ID, startTime.UnixNano())
+ }
+ if forceNewConn {
+ sessionHashForCtx = fmt.Sprintf("%s:retry:%d", sessionHashForCtx, attempt)
+ }
+ ownerID := fmt.Sprintf("httpws_%d_%d", account.ID, attempt)
+ lease, err := ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: groupID,
+ SessionHash: sessionHashForCtx,
+ OwnerID: ownerID,
+ WSURL: wsURL,
+ Headers: cloneHeader(wsHeaders),
ProxyURL: func() string {
if account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}(),
+ Turn: 1,
+ HasPreviousResponseID: previousResponseID != "",
+ StrictAffinity: previousResponseID != "",
+ StoreDisabled: storeDisabled,
})
if err != nil {
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err)
@@ -1971,6 +883,11 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
clientDisconnected := false
+ var downstreamWriteErr error
+ var requestCtx context.Context
+ if c != nil && c.Request != nil {
+ requestCtx = c.Request.Context()
+ }
flushBatchSize := s.openAIWSEventFlushBatchSize()
flushInterval := s.openAIWSEventFlushInterval()
pendingFlushEvents := 0
@@ -1988,23 +905,41 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
pendingFlushEvents = 0
lastFlushAt = time.Now()
}
+ var sseFrameBuf []byte
emitStreamMessage := func(message []byte, forceFlush bool) {
- if clientDisconnected {
+ if clientDisconnected || downstreamWriteErr != nil {
return
}
- frame := make([]byte, 0, len(message)+8)
- frame = append(frame, "data: "...)
- frame = append(frame, message...)
- frame = append(frame, '\n', '\n')
- _, wErr := c.Writer.Write(frame)
+ sseFrameBuf = sseFrameBuf[:0]
+ sseFrameBuf = append(sseFrameBuf, "data: "...)
+ sseFrameBuf = append(sseFrameBuf, message...)
+ sseFrameBuf = append(sseFrameBuf, '\n', '\n')
+ _, wErr := c.Writer.Write(sseFrameBuf)
if wErr == nil {
wroteDownstream = true
pendingFlushEvents++
flushStreamWriter(forceFlush)
return
}
- clientDisconnected = true
- logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID)
+ if isOpenAIWSStreamWriteDisconnectError(wErr, requestCtx) {
+ clientDisconnected = true
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d conn_id=%s",
+ account.ID,
+ connID,
+ )
+ return
+ }
+ downstreamWriteErr = wErr
+ setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(wErr.Error()), "")
+ logOpenAIWSModeInfo(
+ "stream_write_fail account_id=%d conn_id=%s wrote_downstream=%v cause=%s",
+ account.ID,
+ connID,
+ wroteDownstream,
+ truncateOpenAIWSLogValue(wErr.Error(), openAIWSLogValueMaxLen),
+ )
}
flushBufferedStreamEvents := func(reason string) {
if len(bufferedStreamEvents) == 0 {
@@ -2013,6 +948,9 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
flushed := len(bufferedStreamEvents)
for _, buffered := range bufferedStreamEvents {
emitStreamMessage(buffered, false)
+ if downstreamWriteErr != nil {
+ break
+ }
}
bufferedStreamEvents = bufferedStreamEvents[:0]
flushStreamWriter(true)
@@ -2170,8 +1108,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw)
setOpsUpstreamError(c, statusCode, errMsg, "")
if reqStream && !clientDisconnected {
- flushBufferedStreamEvents("error_event")
+ if shouldFlushOpenAIWSBufferedEventsOnError(reqStream, wroteDownstream, clientDisconnected) {
+ flushBufferedStreamEvents("error_event")
+ } else {
+ bufferedStreamEvents = bufferedStreamEvents[:0]
+ }
emitStreamMessage(message, true)
+ if downstreamWriteErr != nil {
+ lease.MarkBroken()
+ return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr)
+ }
}
if !reqStream {
c.JSON(statusCode, gin.H{
@@ -2207,10 +1153,14 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
} else {
flushBufferedStreamEvents(eventType)
emitStreamMessage(message, isTerminalEvent)
+ if downstreamWriteErr != nil {
+ lease.MarkBroken()
+ return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr)
+ }
}
} else {
if responseField.Exists() && responseField.Type == gjson.JSON {
- finalResponse = []byte(responseField.Raw)
+ finalResponse = cloneOpenAIWSJSONRawString(responseField.Raw)
}
}
@@ -2253,7 +1203,14 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if responseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
- stateStore.BindResponseConn(responseID, lease.ConnID(), ttl)
+ if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok {
+ stateStore.BindResponseConn(responseID, connID, ttl)
+ }
+ if sessionHash != "" && shouldPersistOpenAIWSLastResponseID(lastEventType) {
+ stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL())
+ } else if sessionHash != "" {
+ stateStore.DeleteSessionLastResponseID(groupID, sessionHash)
+ }
}
if stateStore != nil && storeDisabled && sessionHash != "" {
stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL())
@@ -2282,14 +1239,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
)
return &OpenAIForwardResult{
- RequestID: responseID,
- Usage: *usage,
- Model: originalModel,
- ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
- Stream: reqStream,
- OpenAIWSMode: true,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
+ RequestID: responseID,
+ Usage: *usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModeCtxPool,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ TerminalEventType: strings.TrimSpace(lastEventType),
}, nil
}
@@ -2301,9 +1260,27 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
clientConn *coderws.Conn,
account *Account,
token string,
+ firstClientMessageType coderws.MessageType,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
-) error {
+) (err error) {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ const panicCloseReason = "internal websocket proxy panic"
+ logger.LegacyPrintf(
+ "service.openai_ws_forwarder",
+ "[OpenAIWS] recovered panic in ProxyResponsesWebSocketFromClient: panic=%v stack=%s",
+ recovered,
+ string(debug.Stack()),
+ )
+ err = NewOpenAIWSClientCloseError(
+ coderws.StatusInternalError,
+ panicCloseReason,
+ fmt.Errorf("panic recovered: %v", recovered),
+ )
+ }
+ }()
+
if s == nil {
return errors.New("service is nil")
}
@@ -2322,33 +1299,78 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
- ingressMode := OpenAIWSIngressModeShared
- if modeRouterV2Enabled {
- ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
- if ingressMode == OpenAIWSIngressModeOff {
- return NewOpenAIWSClientCloseError(
- coderws.StatusPolicyViolation,
- "websocket mode is disabled for this account",
- nil,
- )
- }
+ if !modeRouterV2Enabled {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode requires mode_router_v2 with ctx_pool/passthrough",
+ nil,
+ )
+ }
+ ingressMode := account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
+ logOpenAIWSModeInfo(
+ "ingress_ws_validate account_id=%d ingress_mode=%s transport=%s",
+ account.ID,
+ normalizeOpenAIWSLogValue(string(ingressMode)),
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ )
+ if ingressMode == OpenAIWSIngressModeOff {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode is disabled for this account",
+ nil,
+ )
+ }
+ if ingressMode != OpenAIWSIngressModeCtxPool && ingressMode != OpenAIWSIngressModePassthrough {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode only supports ctx_pool/passthrough",
+ nil,
+ )
}
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
- dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated
-
+ firstModel, firstPreviousResponseID, firstPreviousResponseIDKind := ResolveOpenAIWSFirstMessageMeta(c, firstClientMessage)
+ if firstModel == "" {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "model is required in first response.create payload",
+ nil,
+ )
+ }
+ if ingressMode == OpenAIWSIngressModeCtxPool &&
+ firstPreviousResponseID != "" &&
+ firstPreviousResponseIDKind == OpenAIPreviousResponseIDKindMessageID {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "previous_response_id must be a response.id (resp_*), not a message id",
+ nil,
+ )
+ }
+ if ingressMode == OpenAIWSIngressModePassthrough {
+ return s.proxyResponsesWebSocketV2Passthrough(ctx, c, clientConn, account, token, firstClientMessageType, firstClientMessage, hooks, wsDecision)
+ }
+ // Ingress ws_v2 请求天然是 Codex 会话语义,ctx_pool 是否启用仅由账号 mode 决定。
+ ctxPoolMode := ingressMode == OpenAIWSIngressModeCtxPool
+ ctxPoolSessionScope := ""
+ if ctxPoolMode {
+ ctxPoolSessionScope = openAIWSIngressSessionScopeFromContext(c)
+ }
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
}
- wsHost := "-"
- wsPath := "-"
- if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil {
- wsHost = normalizeOpenAIWSLogValue(parsedURL.Host)
- wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
- }
+ wsHost, wsPath := openAIWSHostPathForLogFromURL(wsURL)
debugEnabled := isOpenAIWSModeDebugEnabled()
+ logOpenAIWSModeInfo(
+ "ingress_ws_session_init account_id=%d ws_host=%s ws_path=%s ctx_pool=%v session_scope=%s debug=%v",
+ account.ID,
+ wsHost,
+ wsPath,
+ ctxPoolMode,
+ truncateOpenAIWSLogValue(ctxPoolSessionScope, openAIWSIDValueMaxLen),
+ debugEnabled,
+ )
type openAIWSClientPayload struct {
payloadRaw []byte
@@ -2475,32 +1497,64 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
stateStore := s.getOpenAIWSStateStore()
groupID := getOpenAIGroupIDFromContext(c)
- sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash)
- if turnState == "" && stateStore != nil && sessionHash != "" {
+ fallbackSessionSeed := openAIWSIngressFallbackSessionSeedFromContext(c)
+ legacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, firstPayload.rawForHash, fallbackSessionSeed))
+ sessionHash := legacySessionHash
+ if ctxPoolMode {
+ sessionHash = openAIWSApplySessionScope(legacySessionHash, ctxPoolSessionScope)
+ }
+ resolveSessionTurnState := func() (string, bool) {
+ if stateStore == nil || sessionHash == "" {
+ return "", false
+ }
if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
+ return savedTurnState, true
+ }
+ if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash {
+ return "", false
+ }
+ return stateStore.GetSessionTurnState(groupID, legacySessionHash)
+ }
+ resolveSessionLastResponseID := func() (string, bool) {
+ if stateStore == nil || sessionHash == "" {
+ return "", false
+ }
+ if savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, sessionHash); ok {
+ return strings.TrimSpace(savedResponseID), true
+ }
+ if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash {
+ return "", false
+ }
+ savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, legacySessionHash)
+ return strings.TrimSpace(savedResponseID), ok
+ }
+ if turnState == "" && stateStore != nil && sessionHash != "" {
+ if savedTurnState, ok := resolveSessionTurnState(); ok {
turnState = savedTurnState
}
}
+ sessionLastResponseID := ""
+ if stateStore != nil && sessionHash != "" {
+ if savedResponseID, ok := resolveSessionLastResponseID(); ok {
+ sessionLastResponseID = savedResponseID
+ }
+ }
preferredConnID := ""
if stateStore != nil && firstPayload.previousResponseID != "" {
- if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok {
- preferredConnID = connID
- }
+ preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, firstPayload.previousResponseID)
}
storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account)
storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
- if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" {
- if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
- preferredConnID = connID
- }
- }
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
- baseAcquireReq := openAIWSAcquireRequest{
- Account: account,
+ baseAcquireReq := struct {
+ WSURL string
+ Headers http.Header
+ ProxyURL string
+ }{
WSURL: wsURL,
Headers: wsHeaders,
ProxyURL: func() string {
@@ -2509,21 +1563,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
return ""
}(),
- ForceNewConn: false,
}
- pool := s.getOpenAIWSConnPool()
- if pool == nil {
- return errors.New("openai ws conn pool is nil")
+
+ ingressCtxPool := s.getOpenAIWSIngressContextPool()
+ if ingressCtxPool == nil {
+ return errors.New("openai ws ingress context pool is nil")
}
logOpenAIWSModeInfo(
- "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
+ "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
wsHost,
wsPath,
normalizeOpenAIWSLogValue(ingressMode),
+ ctxPoolMode,
storeDisabled,
sessionHash != "",
firstPayload.previousResponseID != "",
@@ -2531,7 +1586,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if debugEnabled {
logOpenAIWSModeDebug(
- "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v",
+ "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v ctx_pool_mode=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
@@ -2540,6 +1595,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
sessionHash != "",
firstPayload.previousResponseID != "",
storeDisabled,
+ ctxPoolMode,
)
}
if firstPayload.previousResponseID != "" {
@@ -2566,15 +1622,37 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
acquireTimeout = 30 * time.Second
}
- acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) {
- req := cloneOpenAIWSAcquireRequest(baseAcquireReq)
- req.PreferredConnID = strings.TrimSpace(preferred)
- req.ForcePreferredConn = forcePreferredConn
- // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。
- req.ForceNewConn = dedicatedMode
+ ownerID := fmt.Sprintf("cliws_%p", clientConn)
+ acquireTurnLease := func(
+ turn int,
+ preferred string,
+ forcePreferredConn bool,
+ hasPreviousResponseID bool,
+ ) (openAIWSIngressUpstreamLease, error) {
acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout)
- lease, acquireErr := pool.Acquire(acquireCtx, req)
- acquireCancel()
+ defer acquireCancel()
+
+ var (
+ lease openAIWSIngressUpstreamLease
+ acquireErr error
+ )
+ sessionHashForCtx := strings.TrimSpace(sessionHash)
+ if sessionHashForCtx == "" {
+ sessionHashForCtx = fmt.Sprintf("conn:%s", ownerID)
+ }
+ lease, acquireErr = ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: groupID,
+ SessionHash: sessionHashForCtx,
+ OwnerID: ownerID,
+ WSURL: baseAcquireReq.WSURL,
+ Headers: cloneHeader(baseAcquireReq.Headers),
+ ProxyURL: baseAcquireReq.ProxyURL,
+ Turn: turn,
+ HasPreviousResponseID: hasPreviousResponseID,
+ StrictAffinity: forcePreferredConn,
+ StoreDisabled: storeDisabled,
+ })
if acquireErr != nil {
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr)
logOpenAIWSModeInfo(
@@ -2597,14 +1675,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
- if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
- return nil, NewOpenAIWSClientCloseError(
- coderws.StatusPolicyViolation,
- "upstream continuation connection is unavailable; please restart the conversation",
- acquireErr,
- )
- }
- if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) {
+ if errors.Is(acquireErr, context.DeadlineExceeded) ||
+ errors.Is(acquireErr, errOpenAIWSConnQueueFull) ||
+ errors.Is(acquireErr, errOpenAIWSIngressContextBusy) {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
@@ -2627,7 +1700,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
baseAcquireReq.Headers = updatedHeaders
}
logOpenAIWSModeInfo(
- "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s",
+ "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s ctx_pool_mode=%v schedule_layer=%s stickiness_level=%s migration_used=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
@@ -2635,6 +1708,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
lease.ConnPickDuration().Milliseconds(),
lease.QueueWaitDuration().Milliseconds(),
truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
+ ctxPoolMode,
+ normalizeOpenAIWSLogValue(lease.ScheduleLayer()),
+ normalizeOpenAIWSLogValue(lease.StickinessLevel()),
+ lease.MigrationUsed(),
)
return lease, nil
}
@@ -2646,8 +1723,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
readClientMessage := func() ([]byte, error) {
- msgType, payload, readErr := clientConn.Read(ctx)
+ readCtx := ctx
+ if idleTimeout := s.openAIWSClientReadIdleTimeout(); idleTimeout > 0 {
+ var cancel context.CancelFunc
+ readCtx, cancel = context.WithTimeout(ctx, idleTimeout)
+ defer cancel()
+ }
+ msgType, payload, readErr := clientConn.Read(readCtx)
if readErr != nil {
+ if readCtx != nil && readCtx.Err() == context.DeadlineExceeded && (ctx == nil || ctx.Err() == nil) {
+ return nil, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "client websocket idle timeout",
+ readErr,
+ )
+ }
return nil, readErr
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
@@ -2660,12 +1750,73 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return payload, nil
}
- sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
+ // 持久客户端读取 goroutine:将客户端消息推送到 channel,
+ // 使 sendAndRelay 可以通过 select 同时监听上游事件和客户端新请求。
+ var nextClientPreemptedPayload []byte
+ var pendingClientReadErr error
+ clientMsgCh := make(chan []byte, 1)
+ clientReadErrCh := make(chan error, 1)
+ go func() {
+ defer close(clientMsgCh)
+ for {
+ msg, err := readClientMessage()
+ if err != nil {
+ select {
+ case clientReadErrCh <- err:
+ case <-ctx.Done():
+ }
+ return
+ }
+ select {
+ case clientMsgCh <- msg:
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ sendAndRelay := func(turn int, lease openAIWSIngressUpstreamLease, payload []byte, payloadBytes int, originalModel string, expectedPreviousResponseID string) (*OpenAIForwardResult, error) {
if lease == nil {
return nil, errors.New("upstream websocket lease is nil")
}
turnStart := time.Now()
wroteDownstream := false
+ reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
+ turnExpectedPreviousResponseID := strings.TrimSpace(expectedPreviousResponseID)
+ turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
+ turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
+ turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0
+ turnHasToolOutputContext := openAIWSHasToolCallContextInPayload(payload) ||
+ openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload, turnFunctionCallOutputCallIDs)
+
+ // 预防性检测:必须在发往上游之前执行,避免“先写上游再本地失败”造成 turn 状态错位。
+ // 在 store_disabled 模式下,若 function_call_output 既没有 previous_response_id,
+ // 也没有可关联的 tool_call/item_reference 上下文,则必然会触发 tool_output_not_found。
+ // 提前返回可恢复错误,由外层 recoverIngressPrevResponseNotFound 执行 context replay 重试。
+ if shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ turnStoreDisabled,
+ turnHasFunctionCallOutput,
+ turnPreviousResponseID,
+ turnHasToolOutputContext,
+ ) {
+ logOpenAIWSModeInfo(
+ "ingress_ws_proactive_tool_output_reject account_id=%d turn=%d conn_id=%s has_tool_output_context=%v previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ turnHasToolOutputContext,
+ truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
+ )
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+ }
if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil {
return nil, wrapOpenAIWSIngressTurnError(
"write_upstream",
@@ -2686,12 +1837,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
responseID := ""
usage := OpenAIUsage{}
var firstTokenMs *int
- reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
- turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
- turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
- turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
- turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
- turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+
+ turnPendingFunctionCallIDSet := make(map[string]struct{}, 4)
eventCount := 0
tokenEventCount := 0
terminalEventCount := 0
@@ -2699,8 +1846,31 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
lastEventType := ""
needModelReplace := false
clientDisconnected := false
+ clientDisconnectDrainDeadline := time.Time{}
+ terminateOnErrorEvent := false
+ terminateOnErrorMessage := ""
mappedModel := ""
var mappedModelBytes []byte
+ buildPartialResult := func(terminalEventType string) *OpenAIForwardResult {
+ if usage.InputTokens <= 0 &&
+ usage.OutputTokens <= 0 &&
+ usage.CacheCreationInputTokens <= 0 &&
+ usage.CacheReadInputTokens <= 0 {
+ return nil
+ }
+ return &OpenAIForwardResult{
+ RequestID: responseID,
+ Usage: usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModeCtxPool,
+ Duration: time.Since(turnStart),
+ FirstTokenMs: firstTokenMs,
+ TerminalEventType: strings.TrimSpace(terminalEventType),
+ }
+ }
if originalModel != "" {
mappedModel = account.GetMappedModel(originalModel)
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
@@ -2711,18 +1881,190 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModelBytes = []byte(mappedModel)
}
}
+ // 启动上游事件读取泵:解耦上游读取和客户端写入,允许二者并发执行。
+ // 读取 goroutine 将上游事件推送到缓冲 channel,主 goroutine 从 channel 消费并处理/转发。
+ // 缓冲 channel 允许上游在客户端写入阻塞时继续读取后续事件,降低端到端延迟。
+ pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize)
+ pumpCtx, pumpCancel := context.WithCancel(ctx)
+ defer pumpCancel()
+ pumpStartedAt := time.Now()
+ go func() {
+ defer func() {
+ close(pumpEventCh)
+ if pumpCtx.Err() == nil {
+ return
+ }
+ pumpAlive := time.Since(pumpStartedAt)
+ if pumpAlive >= openAIWSUpstreamPumpInfoMinAlive {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ pumpAlive.Milliseconds(),
+ )
+ return
+ }
+ logOpenAIWSModeDebug(
+ "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ pumpAlive.Milliseconds(),
+ )
+ }()
+ for {
+ msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, s.openAIWSReadTimeout())
+ select {
+ case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}:
+ case <-pumpCtx.Done():
+ return
+ }
+ if readErr != nil {
+ return
+ }
+ // 检测终端/错误事件以终止读取泵。
+ evtType, _ := parseOpenAIWSEventType(msg)
+ if isOpenAIWSTerminalEvent(evtType) || evtType == "error" {
+ return
+ }
+ }
+ }()
+ var drainTimer *time.Timer
+ defer func() {
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+ }()
for {
- upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
- if readErr != nil {
+ var evt openAIWSUpstreamPumpEvent
+ var evtOk bool
+ select {
+ case evt, evtOk = <-pumpEventCh:
+ if !evtOk {
+ goto pumpClosed
+ }
+ case preemptMsg, ok := <-clientMsgCh:
+ if !ok {
+ // 客户端读取 goroutine 退出,置空 channel 防止再次 select
+ clientMsgCh = nil
+ continue
+ }
+ // 客户端抢占:暂存新请求,取消上游转发,返回让外层切换到下一 turn
+ nextClientPreemptedPayload = preemptMsg
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_preempt account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ )
+ pumpCancel()
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted", errOpenAIWSClientPreempted,
+ wroteDownstream, buildPartialResult("client_preempted"))
+ case readErr := <-clientReadErrCh:
+ // 客户端断连:立即取消上游 pump 并释放连接。
+ // Codex CLI 在 ESC 取消后会关闭旧 WebSocket 并新建连接发送下一条消息,
+ // 继续排水只会延迟新连接获取上游 lease,因此这里直接终止。
+ if isOpenAIWSClientDisconnectError(readErr) {
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ closeStatus,
+ truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
+ )
+ pumpCancel()
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_immediate",
+ fmt.Errorf("client disconnected (read): %w", readErr),
+ wroteDownstream,
+ buildPartialResult("client_disconnected"),
+ )
+ }
+ pendingClientReadErr = readErr
+ cause := "-"
+ if readErr != nil {
+ cause = truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen)
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_read_error_deferred account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ cause,
+ )
+ clientMsgCh = nil
+ clientReadErrCh = nil
+ continue
+ }
+ // 排水超时检查:客户端已断连且排水截止时间已过,终止读取。
+ if clientDisconnected && !clientDisconnectDrainDeadline.IsZero() && time.Now().After(clientDisconnectDrainDeadline) {
+ pumpCancel()
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(),
+ )
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+ wroteDownstream,
+ buildPartialResult("client_disconnected_drain_timeout"),
+ )
+ }
+ upstreamMessage := evt.message
+ if evt.err != nil {
+ readErr := evt.err
+ if clientDisconnected {
+ // 排水期间读取失败(上游关闭或读取泵被取消),按排水超时处理。
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(),
+ )
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+ wroteDownstream,
+ buildPartialResult("client_disconnected_drain_timeout"),
+ )
+ }
+ readErrClass := classifyOpenAIWSIngressReadErrorClass(readErr)
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_read_error account_id=%d turn=%d conn_id=%s class=%s close_status=%s close_reason=%s events_received=%d wrote_downstream=%v response_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(readErrClass),
+ closeStatus,
+ closeReason,
+ eventCount,
+ wroteDownstream,
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
+ )
lease.MarkBroken()
- return nil, wrapOpenAIWSIngressTurnError(
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
"read_upstream",
fmt.Errorf("read upstream websocket event: %w", readErr),
wroteDownstream,
+ buildPartialResult("read_upstream"),
)
}
- eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage)
+ eventType, eventResponseID := parseOpenAIWSEventType(upstreamMessage)
if responseID == "" && eventResponseID != "" {
responseID = eventResponseID
}
@@ -2737,15 +2079,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ recoveryEnabled := s.openAIWSIngressPreviousResponseRecoveryEnabled()
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
+ recoveryEnabled &&
+ (turnPreviousResponseID != "" || (turnHasFunctionCallOutput && turnExpectedPreviousResponseID != "")) &&
+ !wroteDownstream
+ // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call
+ // (用户在 Codex CLI 按 ESC 取消后重新发送消息),需要移除 previous_response_id 后重放。
+ recoverableToolOutputNotFound := fallbackReason == openAIWSIngressStageToolOutputNotFound &&
+ recoveryEnabled &&
turnPreviousResponseID != "" &&
- !turnHasFunctionCallOutput &&
- s.openAIWSIngressPreviousResponseRecoveryEnabled() &&
!wroteDownstream
- if recoverablePrevNotFound {
+ recoverableContextMismatch := recoverablePrevNotFound || recoverableToolOutputNotFound
+ if recoverableContextMismatch {
// 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。
logOpenAIWSModeInfo(
- "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
+ "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
@@ -2757,12 +2106,17 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(ingressMode),
+ ctxPoolMode,
turnStoreDisabled,
turnPromptCacheKey != "",
+ turnHasFunctionCallOutput,
+ recoveryEnabled,
+ wroteDownstream,
)
} else {
logOpenAIWSModeInfo(
- "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
+ "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
@@ -2774,24 +2128,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(ingressMode),
+ ctxPoolMode,
turnStoreDisabled,
turnPromptCacheKey != "",
+ turnHasFunctionCallOutput,
+ recoveryEnabled,
+ wroteDownstream,
)
}
- // previous_response_not_found 在 ingress 模式支持单次恢复重试:
+ // previous_response_not_found / tool_output_not_found 在 ingress 模式支持单次恢复重试:
// 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。
- if recoverablePrevNotFound {
+ if recoverableContextMismatch {
lease.MarkBroken()
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
- errMsg = "previous response not found"
+ if fallbackReason == openAIWSIngressStageToolOutputNotFound {
+ errMsg = "no tool output found for function call"
+ } else {
+ errMsg = "previous response not found"
+ }
}
- return nil, wrapOpenAIWSIngressTurnError(
- openAIWSIngressStagePreviousResponseNotFound,
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ fallbackReason,
errors.New(errMsg),
false,
+ buildPartialResult(fallbackReason),
)
}
+ terminateOnErrorEvent = true
+ terminateOnErrorMessage = strings.TrimSpace(errMsgRaw)
+ if terminateOnErrorMessage == "" {
+ terminateOnErrorMessage = "upstream websocket error"
+ }
}
isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent {
@@ -2808,6 +2177,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
}
+ if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) {
+ for _, callID := range openAIWSExtractPendingFunctionCallIDsFromEvent(upstreamMessage) {
+ turnPendingFunctionCallIDSet[callID] = struct{}{}
+ }
+ }
if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
@@ -2820,27 +2194,47 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
if err := writeClientMessage(upstreamMessage); err != nil {
if isOpenAIWSClientDisconnectError(err) {
- clientDisconnected = true
+ // 客户端断连:立即取消上游 pump 并释放连接。
+ // 不再排水等待,以便新连接能尽快获取上游 lease。
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err)
logOpenAIWSModeInfo(
- "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
+ "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
closeStatus,
truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
)
+ pumpCancel()
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_immediate",
+ fmt.Errorf("client disconnected (write): %w", err),
+ wroteDownstream,
+ buildPartialResult("client_disconnected"),
+ )
} else {
- return nil, wrapOpenAIWSIngressTurnError(
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
"write_client",
fmt.Errorf("write client websocket event: %w", err),
wroteDownstream,
+ buildPartialResult("write_client"),
)
}
} else {
wroteDownstream = true
}
}
+ if terminateOnErrorEvent {
+ // WS ingress 中的 error 事件应立即终止当前 turn,避免继续阻塞在下一次上游 read。
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "upstream_error_event",
+ errors.New(terminateOnErrorMessage),
+ wroteDownstream,
+ buildPartialResult("upstream_error_event"),
+ )
+ }
if isTerminalEvent {
// 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。
if clientDisconnected {
@@ -2852,7 +2246,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
if debugEnabled {
logOpenAIWSModeDebug(
- "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v",
+ "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v has_function_call_output=%v pending_function_call_ids=%d",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
@@ -2865,20 +2259,47 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
firstTokenMsValue,
clientDisconnected,
+ turnHasFunctionCallOutput,
+ len(turnPendingFunctionCallIDSet),
)
}
+ pendingFunctionCallIDs := make([]string, 0, len(turnPendingFunctionCallIDSet))
+ for callID := range turnPendingFunctionCallIDSet {
+ pendingFunctionCallIDs = append(pendingFunctionCallIDs, callID)
+ }
+ sort.Strings(pendingFunctionCallIDs)
return &OpenAIForwardResult{
- RequestID: responseID,
- Usage: usage,
- Model: originalModel,
- ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
- Stream: reqStream,
- OpenAIWSMode: true,
- Duration: time.Since(turnStart),
- FirstTokenMs: firstTokenMs,
+ RequestID: responseID,
+ Usage: usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModeCtxPool,
+ Duration: time.Since(turnStart),
+ FirstTokenMs: firstTokenMs,
+ TerminalEventType: strings.TrimSpace(eventType),
+ PendingFunctionCallIDs: pendingFunctionCallIDs,
}, nil
}
}
+ pumpClosed:
+ // 读取泵 channel 关闭但未收到终端事件:
+ // - 客户端已断连:按排水超时收尾,避免误判为 read_upstream。
+ // - 其他场景:按上游读取异常处理。
+ lease.MarkBroken()
+ if clientDisconnected {
+ return nil, openAIWSIngressPumpClosedTurnError(
+ true,
+ wroteDownstream,
+ buildPartialResult("client_disconnected_drain_timeout"),
+ )
+ }
+ return nil, openAIWSIngressPumpClosedTurnError(
+ false,
+ wroteDownstream,
+ buildPartialResult("read_upstream"),
+ )
}
currentPayload := firstPayload.payloadRaw
@@ -2890,50 +2311,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != ""
}
- var sessionLease *openAIWSConnLease
+ var sessionLease openAIWSIngressUpstreamLease
sessionConnID := ""
- pinnedSessionConnID := ""
- unpinSessionConn := func(connID string) {
- connID = strings.TrimSpace(connID)
- if connID == "" || pinnedSessionConnID != connID {
- return
- }
- pool.UnpinConn(account.ID, connID)
- pinnedSessionConnID = ""
- }
- pinSessionConn := func(connID string) {
- if !storeDisabled {
- return
- }
- connID = strings.TrimSpace(connID)
- if connID == "" || pinnedSessionConnID == connID {
+ unpinSessionConn := func(_ string) {}
+ pinSessionConn := func(_ string) {}
+ releaseSessionLease := func() {
+ if sessionLease == nil {
return
}
- if pinnedSessionConnID != "" {
- pool.UnpinConn(account.ID, pinnedSessionConnID)
- pinnedSessionConnID = ""
- }
- if pool.PinConn(account.ID, connID) {
- pinnedSessionConnID = connID
+ unpinSessionConn(sessionConnID)
+ sessionLease.Release()
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "ingress_ws_upstream_released account_id=%d conn_id=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ )
}
}
- releaseSessionLease := func() {
+ yieldSessionLease := func() {
if sessionLease == nil {
return
}
- if dedicatedMode {
- // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
- sessionLease.MarkBroken()
- }
unpinSessionConn(sessionConnID)
- sessionLease.Release()
+ sessionLease.Yield()
if debugEnabled {
logOpenAIWSModeDebug(
- "ingress_ws_upstream_released account_id=%d conn_id=%s",
+ "ingress_ws_upstream_yielded account_id=%d conn_id=%s",
account.ID,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
)
}
+ sessionLease = nil
+ sessionConnID = ""
}
defer releaseSessionLease()
@@ -2941,7 +2351,17 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnRetry := 0
turnPrevRecoveryTried := false
lastTurnFinishedAt := time.Time{}
- lastTurnResponseID := ""
+ lastTurnResponseID := sessionLastResponseID
+ clearSessionLastResponseID := func() {
+ lastTurnResponseID = ""
+ if stateStore == nil || sessionHash == "" {
+ return
+ }
+ stateStore.DeleteSessionLastResponseID(groupID, sessionHash)
+ if ctxPoolMode && legacySessionHash != "" && legacySessionHash != sessionHash {
+ stateStore.DeleteSessionLastResponseID(groupID, legacySessionHash)
+ }
+ }
lastTurnPayload := []byte(nil)
var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState
lastTurnReplayInput := []json.RawMessage(nil)
@@ -2953,6 +2373,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if sessionLease == nil {
return
}
+ resetStart := time.Now()
+ resetConnID := sessionConnID
if markBroken {
sessionLease.MarkBroken()
}
@@ -2960,12 +2382,194 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
sessionLease = nil
sessionConnID = ""
preferredConnID = ""
+ if elapsed := time.Since(resetStart); elapsed > 100*time.Millisecond {
+ logOpenAIWSModeInfo(
+ "ingress_ws_reset_session_lease_slow account_id=%d conn_id=%s mark_broken=%v elapsed_ms=%d",
+ account.ID,
+ truncateOpenAIWSLogValue(resetConnID, openAIWSIDValueMaxLen),
+ markBroken,
+ elapsed.Milliseconds(),
+ )
+ }
}
recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool {
- if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) {
+ isPrevNotFound := isOpenAIWSIngressPreviousResponseNotFound(relayErr)
+ isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(relayErr)
+ if !isPrevNotFound && !isToolOutputMissing {
return false
}
if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
+ skipReason := "already_tried"
+ if !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
+ skipReason = "recovery_disabled"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skipped account_id=%d turn=%d conn_id=%s reason=%s is_prev_not_found=%v is_tool_output_missing=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(skipReason),
+ isPrevNotFound,
+ isToolOutputMissing,
+ )
+ return false
+ }
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
+ // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call
+ // (用户在 Codex CLI 按 ESC 取消了 function_call 后重新发送消息)。
+ // 对齐/保持 previous_response_id 无法解决问题,直接跳到 drop 分支移除后重放。
+ if isToolOutputMissing {
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ )
+ turnPrevRecoveryTried = true
+ updatedPayload := currentPayload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
+ if dropErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=drop_error",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ return false
+ }
+ if removed {
+ updatedPayload = dropped
+ }
+ }
+ // previous_response_id 已不存在或已移除,继续执行 setOpenAIWSPayloadInputSequence
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false
+ }
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ clearSessionLastResponseID()
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
+ if hasFunctionCallOutput {
+ turnPrevRecoveryTried = true
+ expectedPrev := strings.TrimSpace(lastTurnResponseID)
+ if currentPreviousResponseID == "" && expectedPrev != "" {
+ updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
+ if setPrevErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=set_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ }
+ }
+ alignedPayload, aligned, alignErr := alignStoreDisabledPreviousResponseID(currentPayload, expectedPrev)
+ if alignErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=align_previous_response_id_error cause=%s previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(alignErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else if aligned {
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ alignedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=align_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ }
+ // function_call_output 与 previous_response_id 语义绑定:
+ // function_call_output 引用了前一个 response 中的 call_id,
+ // 移除 previous_response_id 但保留 function_call_output 会导致上游报错
+ // "No tool call found for function call output with call_id ..."。
+ // 此场景在网关层不可恢复,返回 false 走 abort 路径通知客户端,
+ // 客户端收到错误后会重置并发送完整请求(不带 previous_response_id)。
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=abort_function_call_unrecoverable previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ )
return false
}
if isStrictAffinityTurn(currentPayload) {
@@ -3018,12 +2622,26 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
currentPayload = updatedWithInput
currentPayloadBytes = len(updatedWithInput)
+ clearSessionLastResponseID()
resetSessionLease(true)
skipBeforeTurn = true
return true
}
retryIngressTurn := func(relayErr error, turn int, connID string) bool {
if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 {
+ retrySkipReason := "not_retryable"
+ if turnRetry >= 1 {
+ retrySkipReason = "retry_exhausted"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_retry_skipped account_id=%d turn=%d conn_id=%s reason=%s retry_count=%d err_stage=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(retrySkipReason),
+ turnRetry,
+ truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen),
+ )
return false
}
if isStrictAffinityTurn(currentPayload) {
@@ -3048,6 +2666,160 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = true
return true
}
+ advanceToNextClientTurn := func(turn int, connID string) (bool, error) {
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_wait_client account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ var nextClientMessage []byte
+ if nextClientPreemptedPayload != nil {
+ nextClientMessage = nextClientPreemptedPayload
+ nextClientPreemptedPayload = nil
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_use_preempted_payload account_id=%d turn=%d conn_id=%s payload_bytes=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ len(nextClientMessage),
+ )
+ } else {
+ if pendingReadErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingClientReadErr); pendingReadErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(pendingReadErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false, pendingReadErr
+ }
+ if openAIWSAdvanceClientReadUnavailable(clientMsgCh, clientReadErrCh) {
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_read_unavailable account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ return false, fmt.Errorf("read client websocket request: %w", errOpenAIWSAdvanceClientReadUnavailable)
+ }
+ select {
+ case msg, ok := <-clientMsgCh:
+ if !ok {
+ // 客户端读取 goroutine 已退出
+ return true, nil
+ }
+ nextClientMessage = msg
+ case readErr := <-clientReadErrCh:
+ if isOpenAIWSClientDisconnectError(readErr) {
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ closeStatus,
+ truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
+ )
+ return true, nil
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false, fmt.Errorf("read client websocket request: %w", readErr)
+ }
+ }
+
+ nextPayload, parseErr := parseClientPayload(nextClientMessage)
+ if parseErr != nil {
+ return false, parseErr
+ }
+ nextStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(nextPayload.payloadRaw, account)
+ nextLegacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, nextPayload.rawForHash, fallbackSessionSeed))
+ nextSessionHash := nextLegacySessionHash
+ if ctxPoolMode {
+ nextSessionHash = openAIWSApplySessionScope(nextLegacySessionHash, ctxPoolSessionScope)
+ }
+ if sessionHash == "" && nextSessionHash != "" {
+ sessionHash = nextSessionHash
+ legacySessionHash = nextLegacySessionHash
+ if stateStore != nil {
+ if turnState == "" {
+ if savedTurnState, ok := resolveSessionTurnState(); ok {
+ turnState = savedTurnState
+ }
+ }
+ if lastTurnResponseID == "" {
+ if savedResponseID, ok := resolveSessionLastResponseID(); ok {
+ lastTurnResponseID = savedResponseID
+ }
+ }
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_session_hash_backfill account_id=%d turn=%d next_turn=%d conn_id=%s session_hash=%s has_turn_state=%v has_last_response_id=%v store_disabled=%v",
+ account.ID,
+ turn,
+ turn+1,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(sessionHash, 12),
+ turnState != "",
+ strings.TrimSpace(lastTurnResponseID) != "",
+ nextStoreDisabled,
+ )
+ }
+ if nextPayload.promptCacheKey != "" {
+ // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
+ // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
+ updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
+ baseAcquireReq.Headers = updatedHeaders
+ }
+ if nextPayload.previousResponseID != "" {
+ expectedPrev := strings.TrimSpace(lastTurnResponseID)
+ chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
+ nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
+ logOpenAIWSModeInfo(
+ "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
+ account.ID,
+ turn,
+ turn+1,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ chainedFromLast,
+ nextPayload.promptCacheKey != "",
+ storeDisabled,
+ )
+ }
+ if stateStore != nil && nextPayload.previousResponseID != "" {
+ if stickyConnID := openAIWSPreferredConnIDFromResponse(stateStore, nextPayload.previousResponseID); stickyConnID != "" {
+ if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
+ logOpenAIWSModeInfo(
+ "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
+ )
+ } else {
+ preferredConnID = stickyConnID
+ }
+ }
+ }
+ currentPayload = nextPayload.payloadRaw
+ currentOriginalModel = nextPayload.originalModel
+ currentPayloadBytes = nextPayload.payloadBytes
+ storeDisabled = nextStoreDisabled
+ if !storeDisabled {
+ unpinSessionConn(sessionConnID)
+ }
+ return false, nil
+ }
for {
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
if err := hooks.BeforeTurn(turn); err != nil {
@@ -3057,39 +2829,78 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
- hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
- // store=false + function_call_output 场景必须有续链锚点。
- // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
- if shouldInferIngressFunctionCallOutputPreviousResponseID(
- storeDisabled,
+ if expectedPrev == "" && stateStore != nil && sessionHash != "" {
+ if savedResponseID, ok := resolveSessionLastResponseID(); ok {
+ expectedPrev = savedResponseID
+ if expectedPrev != "" {
+ lastTurnResponseID = expectedPrev
+ }
+ }
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_begin account_id=%d turn=%d conn_id=%s previous_response_id=%s expected_previous_response_id=%s store_disabled=%v has_session_lease=%v",
+ account.ID,
turn,
- hasFunctionCallOutput,
- currentPreviousResponseID,
- expectedPrev,
- ) {
- updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
- if setPrevErr != nil {
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ storeDisabled,
+ sessionLease != nil,
+ )
+ pendingExpectedCallIDs := []string(nil)
+ if storeDisabled && expectedPrev != "" && stateStore != nil {
+ if pendingCallIDs, ok := stateStore.GetResponsePendingToolCalls(groupID, expectedPrev); ok {
+ pendingExpectedCallIDs = openAIWSNormalizeCallIDs(pendingCallIDs)
+ }
+ }
+ normalized := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: account.ID,
+ turn: turn,
+ connID: sessionConnID,
+ currentPayload: currentPayload,
+ currentPayloadBytes: currentPayloadBytes,
+ currentPreviousResponseID: currentPreviousResponseID,
+ expectedPreviousResponse: expectedPrev,
+ pendingExpectedCallIDs: pendingExpectedCallIDs,
+ })
+ currentPayload = normalized.currentPayload
+ currentPayloadBytes = normalized.currentPayloadBytes
+ currentPreviousResponseID = normalized.currentPreviousResponseID
+ expectedPrev = normalized.expectedPreviousResponseID
+ pendingExpectedCallIDs = normalized.pendingExpectedCallIDs
+ currentFunctionCallOutputCallIDs := normalized.functionCallOutputCallIDs
+ hasFunctionCallOutput := normalized.hasFunctionCallOutputCallID
+
+ // 当客户端发送 function_call_output 但未携带 previous_response_id 时,
+ // 主动注入 Gateway 跟踪的 lastTurnResponseID。
+ // 在 store_disabled 模式下,上游需要 previous_response_id 来关联 function_call_output 与 response,
+ // 否则会返回 "No tool call found for function call output" 错误。
+ if shouldInferIngressFunctionCallOutputPreviousResponseID(storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev) {
+ injectedPayload, injectErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
+ if injectErr != nil {
logOpenAIWSModeInfo(
- "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
+ "ingress_ws_inject_prev_response_id_fail account_id=%d turn=%d conn_id=%s cause=%s expected_previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(injectErr.Error(), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
)
} else {
- currentPayload = updatedPayload
- currentPayloadBytes = len(updatedPayload)
- currentPreviousResponseID = expectedPrev
logOpenAIWSModeInfo(
- "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s",
+ "ingress_ws_inject_prev_response_id account_id=%d turn=%d conn_id=%s injected_previous_response_id=%s has_function_call_output=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ hasFunctionCallOutput,
)
+ currentPayload = injectedPayload
+ currentPayloadBytes = len(injectedPayload)
+ currentPreviousResponseID = expectedPrev
}
}
+
nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence(
lastTurnReplayInput,
lastTurnReplayInputExists,
@@ -3120,6 +2931,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentPayload,
lastTurnResponseID,
hasFunctionCallOutput,
+ pendingExpectedCallIDs,
+ currentFunctionCallOutputCallIDs,
)
} else {
shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID(
@@ -3127,6 +2940,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentPayload,
lastTurnResponseID,
hasFunctionCallOutput,
+ pendingExpectedCallIDs,
+ currentFunctionCallOutputCallIDs,
)
}
if strictErr != nil {
@@ -3196,9 +3011,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
}
forcePreferredConn := isStrictAffinityTurn(currentPayload)
+ hasPreviousResponseIDForAcquire := currentPreviousResponseID != ""
if sessionLease == nil {
- acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
+ acquiredLease, acquireErr := acquireTurnLease(
+ turn,
+ preferredConnID,
+ forcePreferredConn,
+ hasPreviousResponseIDForAcquire,
+ )
if acquireErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_acquire_lease_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen),
+ )
return fmt.Errorf("acquire upstream websocket: %w", acquireErr)
}
sessionLease = acquiredLease
@@ -3224,64 +3052,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
)
- if forcePreferredConn {
- if !turnPrevRecoveryTried && currentPreviousResponseID != "" {
- updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
- if dropErr != nil || !removed {
- reason := "not_removed"
- if dropErr != nil {
- reason = "drop_error"
- }
- logOpenAIWSModeInfo(
- "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- normalizeOpenAIWSLogValue(reason),
- truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
- )
- } else {
- updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
- updatedPayload,
- currentTurnReplayInput,
- currentTurnReplayInputExists,
- )
- if setInputErr != nil {
- logOpenAIWSModeInfo(
- "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
- )
- } else {
- logOpenAIWSModeInfo(
- "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
- )
- turnPrevRecoveryTried = true
- currentPayload = updatedWithInput
- currentPayloadBytes = len(updatedWithInput)
- resetSessionLease(true)
- skipBeforeTurn = true
- continue
- }
- }
- }
- resetSessionLease(true)
- return NewOpenAIWSClientCloseError(
- coderws.StatusPolicyViolation,
- "upstream continuation connection is unavailable; please restart the conversation",
- pingErr,
- )
- }
+ // preflight ping 失败:直接重连,不修改 payload
resetSessionLease(true)
-
- acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
+ acquiredLease, acquireErr := acquireTurnLease(
+ turn,
+ preferredConnID,
+ forcePreferredConn,
+ currentPreviousResponseID != "",
+ )
if acquireErr != nil {
return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr)
}
@@ -3315,7 +3093,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
}
- result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
+ result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, expectedPrev)
if relayErr != nil {
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
continue
@@ -3327,11 +3105,108 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if unwrapped := errors.Unwrap(relayErr); unwrapped != nil {
finalErr = unwrapped
}
+ abortReason, abortExpected := classifyOpenAIWSIngressTurnAbortReason(relayErr)
+ s.recordOpenAIWSTurnAbort(abortReason, abortExpected)
+ logOpenAIWSIngressTurnAbort(account.ID, turn, connID, abortReason, abortExpected, finalErr)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turn, nil, finalErr)
}
- sessionLease.MarkBroken()
- return finalErr
+ switch openAIWSIngressTurnAbortDispositionForReason(abortReason) {
+ case openAIWSIngressTurnAbortDispositionContinueTurn:
+ switch abortReason {
+ case openAIWSIngressTurnAbortReasonClientPreempted:
+ // 客户端抢占:不通知 error(客户端已发出新请求,不需要旧 turn 的错误事件),
+ // 保留上一轮 response_id(被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链)。
+ preemptRecoverStart := time.Now()
+ resetSessionLease(true)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_preempt_recover account_id=%d turn=%d conn_id=%s reset_elapsed_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ time.Since(preemptRecoverStart).Milliseconds(),
+ )
+ case openAIWSIngressTurnAbortReasonUpstreamRestart:
+ // 上游重启(1012/1013):连接级关闭,客户端未收到任何终端事件,
+ // 始终补发 error 事件(无论 wroteDownstream 状态),避免客户端永远等待响应。
+ abortMessage := "upstream service restarting, please retry"
+ if finalErr != nil {
+ abortMessage = finalErr.Error()
+ }
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+ if writeErr := writeClientMessage(errorEvent); writeErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_restart_notify_failed account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_restart_notified account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ }
+ resetSessionLease(true)
+ clearSessionLastResponseID()
+ default:
+ // turn 级终止:当前 turn 失败,但客户端 WS 会话继续。
+ // 这样可与 Codex 客户端语义对齐:后续 turn 允许在新上游连接上继续进行。
+ //
+ // 关键修复:若未向客户端写入过任何数据 (wroteDownstream=false),
+ // 必须补发 error 事件通知客户端本轮失败,否则客户端会一直等待响应,
+ // 而服务端在 advanceToNextClientTurn 中等待客户端下一条消息 → 双向死锁。
+ if !openAIWSIngressTurnWroteDownstream(relayErr) {
+ abortMessage := "turn failed: " + string(abortReason)
+ if finalErr != nil {
+ abortMessage = finalErr.Error()
+ }
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+ if writeErr := writeClientMessage(errorEvent); writeErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_abort_notify_failed account_id=%d turn=%d conn_id=%s reason=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(string(abortReason)),
+ truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_abort_notified account_id=%d turn=%d conn_id=%s reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(string(abortReason)),
+ )
+ }
+ }
+ resetSessionLease(true)
+ clearSessionLastResponseID()
+ }
+ turnRetry = 0
+ turnPrevRecoveryTried = false
+ exit, advanceErr := advanceToNextClientTurn(turn, connID)
+ if advanceErr != nil {
+ return advanceErr
+ }
+ if exit {
+ return nil
+ }
+ s.recordOpenAIWSTurnAbortRecovered()
+ turn++
+ continue
+ case openAIWSIngressTurnAbortDispositionCloseGracefully:
+ resetSessionLease(true)
+ clearSessionLastResponseID()
+ return nil
+ case openAIWSIngressTurnAbortDispositionFailRequest:
+ sessionLease.MarkBroken()
+ return finalErr
+ }
}
turnRetry = 0
turnPrevRecoveryTried = false
@@ -3343,7 +3218,27 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("websocket turn result is nil")
}
responseID := strings.TrimSpace(result.RequestID)
- lastTurnResponseID = responseID
+ persistLastResponseID := responseID != "" && shouldPersistOpenAIWSLastResponseID(result.TerminalEventType)
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d persist_response_id=%v has_function_call_output=%v pending_function_calls=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ result.Duration.Milliseconds(),
+ persistLastResponseID,
+ hasFunctionCallOutput,
+ len(result.PendingFunctionCallIDs),
+ )
+ if persistLastResponseID {
+ lastTurnResponseID = responseID
+ } else if responseID != "" && len(result.PendingFunctionCallIDs) > 0 {
+ // response 未 completed/done(如 incomplete/failed/cancelled),但包含未完成的 function_call。
+ // 保留 response_id 以便下一个 turn 的 function_call_output 能够关联。
+ lastTurnResponseID = responseID
+ } else {
+ clearSessionLastResponseID()
+ }
lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload)
lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput)
lastTurnReplayInputExists = currentTurnReplayInputExists
@@ -3361,84 +3256,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
lastTurnStrictState = nextStrictState
}
+ if stateStore != nil &&
+ expectedPrev != "" &&
+ currentPreviousResponseID == expectedPrev &&
+ (hasFunctionCallOutput || len(pendingExpectedCallIDs) > 0) {
+ stateStore.DeleteResponsePendingToolCalls(groupID, expectedPrev)
+ }
+
if responseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
- stateStore.BindResponseConn(responseID, connID, ttl)
- }
- if stateStore != nil && storeDisabled && sessionHash != "" {
- stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL())
+ if poolConnID, ok := normalizeOpenAIWSPreferredConnID(connID); ok {
+ stateStore.BindResponseConn(responseID, poolConnID, ttl)
+ }
+ if pendingFunctionCallIDs := openAIWSNormalizeCallIDs(result.PendingFunctionCallIDs); len(pendingFunctionCallIDs) > 0 {
+ stateStore.BindResponsePendingToolCalls(groupID, responseID, pendingFunctionCallIDs, ttl)
+ } else {
+ stateStore.DeleteResponsePendingToolCalls(groupID, responseID)
+ }
+ if sessionHash != "" && persistLastResponseID {
+ stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL())
+ }
}
if connID != "" {
preferredConnID = connID
}
+ yieldSessionLease()
- nextClientMessage, readErr := readClientMessage()
- if readErr != nil {
- if isOpenAIWSClientDisconnectError(readErr) {
- closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
- logOpenAIWSModeInfo(
- "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
- account.ID,
- truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
- closeStatus,
- truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
- )
- return nil
- }
- return fmt.Errorf("read client websocket request: %w", readErr)
- }
-
- nextPayload, parseErr := parseClientPayload(nextClientMessage)
- if parseErr != nil {
- return parseErr
- }
- if nextPayload.promptCacheKey != "" {
- // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
- // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
- updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
- baseAcquireReq.Headers = updatedHeaders
- }
- if nextPayload.previousResponseID != "" {
- expectedPrev := strings.TrimSpace(lastTurnResponseID)
- chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
- nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
- logOpenAIWSModeInfo(
- "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
- account.ID,
- turn,
- turn+1,
- truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
- normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
- truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
- chainedFromLast,
- nextPayload.promptCacheKey != "",
- storeDisabled,
- )
- }
- if stateStore != nil && nextPayload.previousResponseID != "" {
- if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok {
- if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
- logOpenAIWSModeInfo(
- "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
- )
- } else {
- preferredConnID = stickyConnID
- }
- }
+ exit, advanceErr := advanceToNextClientTurn(turn, connID)
+ if advanceErr != nil {
+ return advanceErr
}
- currentPayload = nextPayload.payloadRaw
- currentOriginalModel = nextPayload.originalModel
- currentPayloadBytes = nextPayload.payloadBytes
- storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
- if !storeDisabled {
- unpinSessionConn(sessionConnID)
+ if exit {
+ return nil
}
turn++
}
@@ -3452,7 +3302,7 @@ func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool {
// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。
func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
ctx context.Context,
- lease *openAIWSConnLease,
+ lease openAIWSIngressUpstreamLease,
decision OpenAIWSProtocolDecision,
payload map[string]any,
previousResponseID string,
@@ -3540,7 +3390,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr)
}
- eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message)
+ eventType, eventResponseID := parseOpenAIWSEventType(message)
if eventType == "" {
continue
}
@@ -3595,7 +3445,9 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
if prewarmResponseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl))
- stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl)
+ if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok {
+ stateStore.BindResponseConn(prewarmResponseID, connID, ttl)
+ }
}
logOpenAIWSModeInfo(
"prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d",
@@ -3609,111 +3461,6 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
return nil
}
-func payloadAsJSON(payload map[string]any) string {
- return string(payloadAsJSONBytes(payload))
-}
-
-func payloadAsJSONBytes(payload map[string]any) []byte {
- if len(payload) == 0 {
- return []byte("{}")
- }
- body, err := json.Marshal(payload)
- if err != nil {
- return []byte("{}")
- }
- return body
-}
-
-func isOpenAIWSTerminalEvent(eventType string) bool {
- switch strings.TrimSpace(eventType) {
- case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
- return true
- default:
- return false
- }
-}
-
-func isOpenAIWSTokenEvent(eventType string) bool {
- eventType = strings.TrimSpace(eventType)
- if eventType == "" {
- return false
- }
- switch eventType {
- case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
- return false
- }
- if strings.Contains(eventType, ".delta") {
- return true
- }
- if strings.HasPrefix(eventType, "response.output_text") {
- return true
- }
- if strings.HasPrefix(eventType, "response.output") {
- return true
- }
- return eventType == "response.completed" || eventType == "response.done"
-}
-
-func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
- if len(message) == 0 {
- return message
- }
- if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
- return message
- }
- if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
- return message
- }
- modelValues := gjson.GetManyBytes(message, "model", "response.model")
- replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
- replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
- if !replaceModel && !replaceResponseModel {
- return message
- }
- updated := message
- if replaceModel {
- if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
- updated = next
- }
- }
- if replaceResponseModel {
- if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
- updated = next
- }
- }
- return updated
-}
-
-func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
- if usage == nil || len(body) == 0 {
- return
- }
- values := gjson.GetManyBytes(
- body,
- "usage.input_tokens",
- "usage.output_tokens",
- "usage.input_tokens_details.cached_tokens",
- )
- usage.InputTokens = int(values[0].Int())
- usage.OutputTokens = int(values[1].Int())
- usage.CacheReadInputTokens = int(values[2].Int())
-}
-
-func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
- if c == nil {
- return 0
- }
- value, exists := c.Get("api_key")
- if !exists {
- return 0
- }
- apiKey, ok := value.(*APIKey)
- if !ok || apiKey == nil || apiKey.GroupID == nil {
- return 0
- }
- return *apiKey.GroupID
-}
-
// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。
// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。
func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
@@ -3792,164 +3539,3 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
}
return nil, nil
}
-
-func classifyOpenAIWSAcquireError(err error) string {
- if err == nil {
- return "acquire_conn"
- }
- var dialErr *openAIWSDialError
- if errors.As(err, &dialErr) {
- switch dialErr.StatusCode {
- case 426:
- return "upgrade_required"
- case 401, 403:
- return "auth_failed"
- case 429:
- return "upstream_rate_limited"
- }
- if dialErr.StatusCode >= 500 {
- return "upstream_5xx"
- }
- return "dial_failed"
- }
- if errors.Is(err, errOpenAIWSConnQueueFull) {
- return "conn_queue_full"
- }
- if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
- return "preferred_conn_unavailable"
- }
- if errors.Is(err, context.DeadlineExceeded) {
- return "acquire_timeout"
- }
- return "acquire_conn"
-}
-
-func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
- code := strings.ToLower(strings.TrimSpace(codeRaw))
- errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
- msg := strings.ToLower(strings.TrimSpace(msgRaw))
-
- switch code {
- case "upgrade_required":
- return "upgrade_required", true
- case "websocket_not_supported", "websocket_unsupported":
- return "ws_unsupported", true
- case "websocket_connection_limit_reached":
- return "ws_connection_limit_reached", true
- case "previous_response_not_found":
- return "previous_response_not_found", true
- }
- if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
- return "upgrade_required", true
- }
- if strings.Contains(errType, "upgrade") {
- return "upgrade_required", true
- }
- if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
- return "ws_unsupported", true
- }
- if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
- return "ws_connection_limit_reached", true
- }
- if strings.Contains(msg, "previous_response_not_found") ||
- (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
- return "previous_response_not_found", true
- }
- if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
- return "upstream_error_event", true
- }
- return "event_error", false
-}
-
-func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
- if len(message) == 0 {
- return "event_error", false
- }
- return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
-}
-
-func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
- code := strings.ToLower(strings.TrimSpace(codeRaw))
- errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
- switch {
- case strings.Contains(errType, "invalid_request"),
- strings.Contains(code, "invalid_request"),
- strings.Contains(code, "bad_request"),
- code == "previous_response_not_found":
- return http.StatusBadRequest
- case strings.Contains(errType, "authentication"),
- strings.Contains(code, "invalid_api_key"),
- strings.Contains(code, "unauthorized"):
- return http.StatusUnauthorized
- case strings.Contains(errType, "permission"),
- strings.Contains(code, "forbidden"):
- return http.StatusForbidden
- case strings.Contains(errType, "rate_limit"),
- strings.Contains(code, "rate_limit"),
- strings.Contains(code, "insufficient_quota"):
- return http.StatusTooManyRequests
- default:
- return http.StatusBadGateway
- }
-}
-
-func openAIWSErrorHTTPStatus(message []byte) int {
- if len(message) == 0 {
- return http.StatusBadGateway
- }
- codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
- return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
-}
-
-func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
- if s == nil || s.cfg == nil {
- return 30 * time.Second
- }
- seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
- if seconds <= 0 {
- return 0
- }
- return time.Duration(seconds) * time.Second
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
- if s == nil || accountID <= 0 {
- return false
- }
- cooldown := s.openAIWSFallbackCooldown()
- if cooldown <= 0 {
- return false
- }
- rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
- if !ok || rawUntil == nil {
- return false
- }
- until, ok := rawUntil.(time.Time)
- if !ok || until.IsZero() {
- s.openaiWSFallbackUntil.Delete(accountID)
- return false
- }
- if time.Now().Before(until) {
- return true
- }
- s.openaiWSFallbackUntil.Delete(accountID)
- return false
-}
-
-func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
- if s == nil || accountID <= 0 {
- return
- }
- cooldown := s.openAIWSFallbackCooldown()
- if cooldown <= 0 {
- return
- }
- s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
-}
-
-func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
- if s == nil || accountID <= 0 {
- return
- }
- s.openaiWSFallbackUntil.Delete(accountID)
-}
diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go
index bd03ab5a6..0bc2114fc 100644
--- a/backend/internal/service/openai_ws_forwarder_benchmark_test.go
+++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go
@@ -1,8 +1,10 @@
package service
import (
+ "context"
"fmt"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
@@ -125,3 +127,146 @@ func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
}
}
+
+// --- Optimization benchmarks ---
+
+var benchmarkOpenAIWSConnSink openAIWSClientConn
+
+func BenchmarkTouchLease_Full(b *testing.B) {
+ ctx := &openAIWSIngressContext{}
+ ttl := 10 * time.Minute
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ctx.touchLease(time.Now(), ttl)
+ }
+}
+
+func BenchmarkMaybeTouchLease_Throttled(b *testing.B) {
+ ctx := &openAIWSIngressContext{}
+ ttl := 10 * time.Minute
+ ctx.touchLease(time.Now(), ttl) // seed the initial touch
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ctx.maybeTouchLease(ttl)
+ }
+}
+
+func BenchmarkActiveConn_CachedPath(b *testing.B) {
+ conn := &benchmarkOpenAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn}
+ lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"}
+ // Prime the cache
+ _, _ = lease.activeConn()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSConnSink, _ = lease.activeConn()
+ }
+}
+
+func BenchmarkActiveConn_MutexPath(b *testing.B) {
+ conn := &benchmarkOpenAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn}
+ lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"}
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ lease.cachedConn = nil // force mutex path each iteration
+ benchmarkOpenAIWSConnSink, _ = lease.activeConn()
+ }
+}
+
+func BenchmarkParseOpenAIWSEventType_Lightweight(b *testing.B) {
+ event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ et, rid := parseOpenAIWSEventType(event)
+ benchmarkOpenAIWSStringSink = et
+ benchmarkOpenAIWSStringSink = rid
+ }
+}
+
+func BenchmarkParseOpenAIWSEventEnvelope_Full(b *testing.B) {
+ event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ et, rid, resp := parseOpenAIWSEventEnvelope(event)
+ benchmarkOpenAIWSStringSink = et
+ benchmarkOpenAIWSStringSink = rid
+ benchmarkOpenAIWSBoolSink = resp.Exists()
+ }
+}
+
+func BenchmarkSessionTurnStateKey_Strconv(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = openAIWSSessionTurnStateKey(int64(i%1000+1), "session_hash_bench")
+ }
+}
+
+func BenchmarkResponseAccountCacheKey_XXHash(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = openAIWSResponseAccountCacheKey(fmt.Sprintf("resp_bench_%d", i%1000))
+ }
+}
+
+func BenchmarkIsOpenAIWSTerminalEvent(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSBoolSink = isOpenAIWSTerminalEvent("response.completed")
+ }
+}
+
+func BenchmarkIsOpenAIWSTokenEvent(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSBoolSink = isOpenAIWSTokenEvent("response.output_text.delta")
+ }
+}
+
+func BenchmarkStateStore_ShardedBindGet(b *testing.B) {
+ storeAny := NewOpenAIWSStateStore(nil)
+ store, ok := storeAny.(*defaultOpenAIWSStateStore)
+ if !ok {
+ b.Fatal("expected *defaultOpenAIWSStateStore")
+ }
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ key := fmt.Sprintf("resp_%d", i%1000)
+ store.BindResponseConn(key, "conn_bench", time.Minute)
+ benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = store.GetResponseConn(key)
+ }
+}
+
+func BenchmarkDeriveOpenAISessionHash(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = deriveOpenAISessionHash("session-id-benchmark-value")
+ }
+}
+
+func BenchmarkDeriveOpenAILegacySessionHash(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = deriveOpenAILegacySessionHash("session-id-benchmark-value")
+ }
+}
+
+type benchmarkOpenAIWSNoopConn struct{}
+
+func (c *benchmarkOpenAIWSNoopConn) WriteJSON(_ context.Context, _ any) error { return nil }
+func (c *benchmarkOpenAIWSNoopConn) ReadMessage(_ context.Context) ([]byte, error) { return nil, nil }
+func (c *benchmarkOpenAIWSNoopConn) Ping(_ context.Context) error { return nil }
+func (c *benchmarkOpenAIWSNoopConn) Close() error { return nil }
diff --git a/backend/internal/service/openai_ws_forwarder_headers_test.go b/backend/internal/service/openai_ws_forwarder_headers_test.go
new file mode 100644
index 000000000..30b7216fe
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_headers_test.go
@@ -0,0 +1,119 @@
+package service
+
+import (
+ "net/http/httptest"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestBuildOpenAIWSHeaders_OAuthNormalization(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "chatgpt_account_id": "chatgpt_acc_1",
+ },
+ }
+
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/openai/v1/responses", nil)
+ c.Request.Header.Set("accept-language", "zh-CN")
+ c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ c.Request.Header.Set("session_id", "sess_hdr")
+ c.Request.Header.Set("conversation_id", "conv_hdr")
+
+ headers, resolution := svc.buildOpenAIWSHeaders(
+ c,
+ account,
+ "test_token",
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ false,
+ "turn_state_1",
+ "turn_meta_1",
+ "prompt_cache_fallback",
+ )
+
+ require.Equal(t, "Bearer test_token", headers.Get("authorization"))
+ require.Equal(t, "zh-CN", headers.Get("accept-language"))
+ require.Equal(t, "sess_hdr", headers.Get("session_id"))
+ require.Equal(t, "conv_hdr", headers.Get("conversation_id"))
+ require.Equal(t, "turn_state_1", headers.Get(openAIWSTurnStateHeader))
+ require.Equal(t, "turn_meta_1", headers.Get(openAIWSTurnMetadataHeader))
+ require.Equal(t, "chatgpt_acc_1", headers.Get("chatgpt-account-id"))
+ require.Equal(t, openAIWSBetaV2Value, headers.Get("OpenAI-Beta"))
+ require.Equal(t, codexCLIUserAgent, headers.Get("user-agent"))
+ require.Equal(t, "codex_cli_rs", headers.Get("originator"))
+
+ require.Equal(t, "sess_hdr", resolution.SessionID)
+ require.Equal(t, "conv_hdr", resolution.ConversationID)
+ require.Equal(t, "header_session_id", resolution.SessionSource)
+ require.Equal(t, "header_conversation_id", resolution.ConversationSource)
+}
+
+func TestBuildOpenAIWSHeaders_APIKeyForceCodexAndV1(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Gateway.ForceCodexCLI = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Credentials: map[string]any{
+ "user_agent": "my-custom-ua/1.0",
+ },
+ }
+
+ headers, resolution := svc.buildOpenAIWSHeaders(
+ nil,
+ account,
+ "token_apikey",
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocket},
+ false,
+ "",
+ "",
+ "pcache_1",
+ )
+
+ require.Equal(t, "Bearer token_apikey", headers.Get("authorization"))
+ require.Equal(t, openAIWSBetaV1Value, headers.Get("OpenAI-Beta"))
+ require.Equal(t, codexCLIUserAgent, headers.Get("user-agent"))
+ require.Equal(t, "", headers.Get("originator"))
+ require.Equal(t, "pcache_1", headers.Get("session_id"))
+ require.Equal(t, "pcache_1", resolution.SessionID)
+ require.Equal(t, "prompt_cache_key", resolution.SessionSource)
+}
+
+func TestBuildOpenAIWSHeaders_OAuthCodexCLIInputKeepsOriginator(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ svc := &OpenAIGatewayService{cfg: &config.Config{}}
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ c.Request = httptest.NewRequest("GET", "/openai/v1/responses", nil)
+ c.Request.Header.Set("User-Agent", codexCLIUserAgent)
+
+ headers, _ := svc.buildOpenAIWSHeaders(
+ c,
+ account,
+ "token_oauth",
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ true,
+ "",
+ "",
+ "",
+ )
+
+ require.Equal(t, codexCLIUserAgent, headers.Get("user-agent"))
+ require.Equal(t, "codex_cli_rs", headers.Get("originator"))
+}
diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
index 761676038..807e97255 100644
--- a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
+++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
@@ -2,9 +2,12 @@ package service
import (
"net/http"
+ "net/http/httptest"
"testing"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
func TestParseOpenAIWSEventEnvelope(t *testing.T) {
@@ -31,6 +34,15 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
require.Equal(t, 3, usage.CacheReadInputTokens)
}
+func TestOpenAIWSEventShouldParseUsage_TerminalEvents(t *testing.T) {
+ require.True(t, openAIWSEventShouldParseUsage("response.completed"))
+ require.True(t, openAIWSEventShouldParseUsage("response.done"))
+ require.True(t, openAIWSEventShouldParseUsage("response.failed"))
+ // After removing TrimSpace, callers must provide pre-trimmed input.
+ require.False(t, openAIWSEventShouldParseUsage(" response.done "))
+ require.False(t, openAIWSEventShouldParseUsage("response.in_progress"))
+}
+
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
@@ -53,11 +65,62 @@ func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
}
func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
+ require.False(t, openAIWSMessageLikelyContainsToolCalls(nil))
require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)))
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`)))
+ require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"tool_call"}}`)))
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`)))
}
+func TestOpenAIWSExtractPendingFunctionCallIDsFromEvent(t *testing.T) {
+ callIDs := openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{
+ "type":"response.output_item.added",
+ "response":{"id":"resp_1"},
+ "item":{"type":"function_call","call_id":"call_a"}
+ }`))
+ require.Equal(t, []string{"call_a"}, callIDs)
+
+ callIDs = openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{
+ "type":"response.completed",
+ "response":{
+ "id":"resp_2",
+ "output":[
+ {"type":"function_call","call_id":"call_b"},
+ {"type":"message","content":[{"type":"output_text","text":"ok"}]},
+ {"type":"function_call","call_id":"call_c"}
+ ]
+ }
+ }`))
+ require.Equal(t, []string{"call_b", "call_c"}, callIDs)
+}
+
+func TestOpenAIWSExtractFunctionCallOutputCallIDsFromPayload(t *testing.T) {
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload([]byte(`{
+ "input":[
+ {"type":"input_text","text":"hi"},
+ {"type":"function_call_output","call_id":"call_2","output":"ok"},
+ {"type":"function_call_output","call_id":"call_1","output":"ok"},
+ {"type":"function_call_output","call_id":"call_2","output":"dup"}
+ ]
+ }`))
+ require.Equal(t, []string{"call_1", "call_2"}, callIDs)
+}
+
+func TestOpenAIWSInjectFunctionCallOutputItems(t *testing.T) {
+ updatedPayload, injected, err := openAIWSInjectFunctionCallOutputItems(
+ []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`),
+ []string{"call_1", "call_2", "call_1"},
+ openAIWSAutoAbortedToolOutputValue,
+ )
+ require.NoError(t, err)
+ require.Equal(t, 2, injected)
+ require.Equal(t, "input_text", gjson.GetBytes(updatedPayload, "input.0.type").String())
+ require.Equal(t, "function_call_output", gjson.GetBytes(updatedPayload, "input.1.type").String())
+ require.Equal(t, "call_1", gjson.GetBytes(updatedPayload, "input.1.call_id").String())
+ require.Equal(t, openAIWSAutoAbortedToolOutputValue, gjson.GetBytes(updatedPayload, "input.1.output").String())
+ require.Equal(t, "call_2", gjson.GetBytes(updatedPayload, "input.2.call_id").String())
+}
+
func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
@@ -71,3 +134,45 @@ func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`)
require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model")))
}
+
+func TestResolveOpenAIWSFirstMessageMeta_ContextPreferred(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(httptest.NewRecorder())
+ SetOpenAIWSFirstMessageMeta(c, "ctx_model", "resp_ctx", OpenAIPreviousResponseIDKindResponseID)
+
+ model, prevID, prevKind := ResolveOpenAIWSFirstMessageMeta(
+ c,
+ []byte(`{"model":"payload_model","previous_response_id":"resp_payload"}`),
+ )
+ require.Equal(t, "ctx_model", model)
+ require.Equal(t, "resp_ctx", prevID)
+ require.Equal(t, OpenAIPreviousResponseIDKindResponseID, prevKind)
+
+ require.NotPanics(t, func() {
+ SetOpenAIWSFirstMessageMeta(nil, "m", "resp_x", OpenAIPreviousResponseIDKindResponseID)
+ })
+}
+
+func TestResolveOpenAIWSFirstMessageMeta_FallbackParse(t *testing.T) {
+ model, prevID, prevKind := ResolveOpenAIWSFirstMessageMeta(
+ nil,
+ []byte(`{"model":"payload_model","previous_response_id":"resp_payload"}`),
+ )
+ require.Equal(t, "payload_model", model)
+ require.Equal(t, "resp_payload", prevID)
+ require.Equal(t, OpenAIPreviousResponseIDKindResponseID, prevKind)
+}
+
+func TestOpenAIWSHostPathForLogFromURL(t *testing.T) {
+ host, path := openAIWSHostPathForLogFromURL("wss://api.openai.com/v1/responses?stream=true")
+ require.Equal(t, "api.openai.com", host)
+ require.Equal(t, "/v1/responses", path)
+
+ host, path = openAIWSHostPathForLogFromURL("api.openai.com/v1/responses")
+ require.Equal(t, "api.openai.com", host)
+ require.Equal(t, "/v1/responses", path)
+
+ host, path = openAIWSHostPathForLogFromURL(" ")
+ require.Equal(t, "-", host)
+ require.Equal(t, "-", path)
+}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go
new file mode 100644
index 000000000..149c5765b
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go
@@ -0,0 +1,384 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func runIngressProxyWithFirstPayload(
+ t *testing.T,
+ svc *OpenAIGatewayService,
+ account *Account,
+ firstPayload string,
+) error {
+ t.Helper()
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, message, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", msgType, message, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(firstPayload))
+ cancelWrite()
+ require.NoError(t, err)
+
+ select {
+ case serverErr := <-serverErrCh:
+ return serverErr
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ return nil
+ }
+}
+
+func buildIngressPolicyTestConfig() *config.Config {
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
+ return cfg
+}
+
+func buildIngressPolicyTestService(cfg *config.Config) *OpenAIGatewayService {
+ return &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+}
+
+func buildIngressPolicyTestAccount(extra map[string]any) *Account {
+ return &Account{
+ ID: 442,
+ Name: "openai-ingress-policy",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: extra,
+ }
+}
+
+type openAIWSPassthroughProbeStateStore struct {
+ mu sync.Mutex
+ calls []string
+}
+
+func newOpenAIWSPassthroughProbeStateStore() *openAIWSPassthroughProbeStateStore {
+ return &openAIWSPassthroughProbeStateStore{
+ calls: make([]string, 0, 4),
+ }
+}
+
+func (s *openAIWSPassthroughProbeStateStore) record(method string) {
+ s.mu.Lock()
+ s.calls = append(s.calls, method)
+ s.mu.Unlock()
+}
+
+func (s *openAIWSPassthroughProbeStateStore) unexpectedErr(method string) error {
+ s.record(method)
+ return errors.New("passthrough must not call OpenAIWSStateStore." + method)
+}
+
+func (s *openAIWSPassthroughProbeStateStore) Calls() []string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ out := make([]string, len(s.calls))
+ copy(out, s.calls)
+ return out
+}
+
+func (s *openAIWSPassthroughProbeStateStore) BindResponseAccount(context.Context, int64, string, int64, time.Duration) error {
+ return s.unexpectedErr("BindResponseAccount")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) GetResponseAccount(context.Context, int64, string) (int64, error) {
+ return 0, s.unexpectedErr("GetResponseAccount")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) DeleteResponseAccount(context.Context, int64, string) error {
+ return s.unexpectedErr("DeleteResponseAccount")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) BindResponseConn(string, string, time.Duration) {
+ s.record("BindResponseConn")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) GetResponseConn(string) (string, bool) {
+ s.record("GetResponseConn")
+ return "", false
+}
+
+func (s *openAIWSPassthroughProbeStateStore) DeleteResponseConn(string) {
+ s.record("DeleteResponseConn")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) BindResponsePendingToolCalls(int64, string, []string, time.Duration) {
+ s.record("BindResponsePendingToolCalls")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) GetResponsePendingToolCalls(int64, string) ([]string, bool) {
+ s.record("GetResponsePendingToolCalls")
+ return nil, false
+}
+
+func (s *openAIWSPassthroughProbeStateStore) DeleteResponsePendingToolCalls(int64, string) {
+ s.record("DeleteResponsePendingToolCalls")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) BindSessionTurnState(int64, string, string, time.Duration) {
+ s.record("BindSessionTurnState")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) GetSessionTurnState(int64, string) (string, bool) {
+ s.record("GetSessionTurnState")
+ return "", false
+}
+
+func (s *openAIWSPassthroughProbeStateStore) DeleteSessionTurnState(int64, string) {
+ s.record("DeleteSessionTurnState")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) BindSessionLastResponseID(int64, string, string, time.Duration) {
+ s.record("BindSessionLastResponseID")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) GetSessionLastResponseID(int64, string) (string, bool) {
+ s.record("GetSessionLastResponseID")
+ return "", false
+}
+
+func (s *openAIWSPassthroughProbeStateStore) DeleteSessionLastResponseID(int64, string) {
+ s.record("DeleteSessionLastResponseID")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) BindSessionConn(int64, string, string, time.Duration) {
+ s.record("BindSessionConn")
+}
+
+func (s *openAIWSPassthroughProbeStateStore) GetSessionConn(int64, string) (string, bool) {
+ s.record("GetSessionConn")
+ return "", false
+}
+
+func (s *openAIWSPassthroughProbeStateStore) DeleteSessionConn(int64, string) {
+ s.record("DeleteSessionConn")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeRouterDisabledReturnsPolicyViolation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = false
+ svc := buildIngressPolicyTestService(cfg)
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "responses_websockets_v2_enabled": true,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Equal(t, "websocket mode requires mode_router_v2 with ctx_pool/passthrough", closeErr.Reason())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_CtxPoolRejectsMessageIDPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(
+ t,
+ svc,
+ account,
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
+ )
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughDoesNotRejectMessageIDPreviousResponseID(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ dialer := &openAIWSAlwaysFailDialer{}
+ svc.openaiWSPassthroughDialer = dialer
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(
+ t,
+ svc,
+ account,
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`,
+ )
+ require.Error(t, serverErr)
+ require.Contains(t, serverErr.Error(), "openai ws passthrough dial")
+ require.NotContains(t, serverErr.Error(), "previous_response_id must be a response.id")
+ require.Equal(t, 1, dialer.DialCount())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughEmptyModelFailsBeforeDial(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ dialer := &openAIWSAlwaysFailDialer{}
+ svc.openaiWSPassthroughDialer = dialer
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(
+ t,
+ svc,
+ account,
+ `{"type":"response.create","stream":false,"input":[]}`,
+ )
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Contains(t, closeErr.Reason(), "model is required")
+ require.Equal(t, 0, dialer.DialCount())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughFunctionCallOutputNoRecoveryReject(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ dialer := &openAIWSAlwaysFailDialer{}
+ svc.openaiWSPassthroughDialer = dialer
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(
+ t,
+ svc,
+ account,
+ `{"type":"response.create","model":"gpt-5.1","stream":false,"input":[{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`,
+ )
+ require.Error(t, serverErr)
+ require.Contains(t, serverErr.Error(), "openai ws passthrough dial")
+ require.NotContains(t, serverErr.Error(), "tool_output_not_found")
+ require.NotContains(t, serverErr.Error(), "previous_response_not_found")
+ require.Equal(t, 1, dialer.DialCount())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PassthroughDoesNotTouchStateStore(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ dialer := &openAIWSAlwaysFailDialer{}
+ svc.openaiWSPassthroughDialer = dialer
+ storeProbe := newOpenAIWSPassthroughProbeStateStore()
+ svc.openaiWSStateStore = storeProbe
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(
+ t,
+ svc,
+ account,
+ `{"type":"response.create","model":"gpt-5.1","stream":false}`,
+ )
+ require.Error(t, serverErr)
+ require.Contains(t, serverErr.Error(), "openai ws passthrough dial")
+ require.Equal(t, 1, dialer.DialCount())
+ require.Empty(t, storeProbe.Calls(), "passthrough 路径不应访问 StateStore")
+}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
deleted file mode 100644
index 5a3c12c39..000000000
--- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
+++ /dev/null
@@ -1,2483 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "sync"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- coderws "github.com/coder/websocket"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
-)
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 114,
- Name: "openai-ingress-session-lease",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- turnWSModeCh := make(chan bool, 2)
- hooks := &OpenAIWSIngressHooks{
- AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
- if turnErr == nil && result != nil {
- turnWSModeCh <- result.OpenAIWSMode
- }
- },
- }
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurnEvent := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String())
- require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`)
- secondTurnEvent := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String())
- require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String())
- require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
- require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
-
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- metrics := svc.SnapshotOpenAIWSPoolMetrics()
- require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease")
- require.Equal(t, 1, captureDialer.DialCount(), "同一 ingress 会话应保持同一上游连接")
- require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- upstreamConn1 := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- upstreamConn2 := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{upstreamConn1, upstreamConn2},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 441,
- Name: "openai-ingress-dedicated",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
- },
- }
-
- serverErrCh := make(chan error, 2)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- runSingleTurnSession := func(expectedResponseID string) {
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
- cancelWrite()
- require.NoError(t, err)
-
- readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
- msgType, event, readErr := clientConn.Read(readCtx)
- cancelRead()
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
-
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
- }
-
- runSingleTurnSession("resp_dedicated_1")
- runSingleTurnSession("resp_dedicated_2")
-
- require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: newOpenAIWSConnPool(cfg),
- }
-
- account := &Account{
- ID: 442,
- Name: "openai-ingress-off",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
- cancelWrite()
- require.NoError(t, err)
-
- select {
- case serverErr := <-serverErrCh:
- var closeErr *OpenAIWSClientCloseError
- require.ErrorAs(t, serverErr, &closeErr)
- require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
- require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason())
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 140,
- Name: "openai-ingress-prev-preflight-rewrite",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create")
- require.Len(t, captureConn.writes, 2)
- secondWrite := requestToJSONString(captureConn.writes[1])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create")
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文")
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) {
- gin.SetMode(gin.TestMode)
- prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
- openAIWSIngressPreflightPingIdle = 0
- defer func() {
- openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
- }()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSPreflightFailConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 142,
- Name: "openai-ingress-prev-strict-drop-before-ping",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 严格降级后预检换连超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连")
- require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮")
- require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1)
- secondWrite := requestToJSONString(secondWrites[0])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id")
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()))
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 143,
- Name: "openai-ingress-store-enabled-skip-strict",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 store=true 场景 websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 141,
- Name: "openai-ingress-prev-preflight-skip-fco",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 143,
- Name: "openai-ingress-fco-auto-prev",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_auto_prev_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 144,
- Name: "openai-ingress-fco-auto-prev-skip",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String())
- require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点")
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
- gin.SetMode(gin.TestMode)
- prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
- openAIWSIngressPreflightPingIdle = 0
- defer func() {
- openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
- }()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSPreflightFailConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 116,
- Name: "openai-ingress-preflight-ping",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
- require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连")
- require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn")
- require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) {
- gin.SetMode(gin.TestMode)
- prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
- openAIWSIngressPreflightPingIdle = 0
- defer func() {
- openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
- }()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSPreflightFailConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 121,
- Name: "openai-ingress-preflight-ping-strict-affinity",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放")
- require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮")
- require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1)
- secondWrite := requestToJSONString(secondWrites[0])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id")
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文")
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSWriteFailAfterFirstTurnConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 117,
- Name: "openai-ingress-write-retry",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
- var hooksMu sync.Mutex
- beforeTurnCalls := make(map[int]int)
- afterTurnCalls := make(map[int]int)
- hooks := &OpenAIWSIngressHooks{
- BeforeTurn: func(turn int) error {
- hooksMu.Lock()
- beforeTurnCalls[turn]++
- hooksMu.Unlock()
- return nil
- },
- AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) {
- hooksMu.Lock()
- afterTurnCalls[turn]++
- hooksMu.Unlock()
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
- require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连")
- hooksMu.Lock()
- beforeTurn1 := beforeTurnCalls[1]
- beforeTurn2 := beforeTurnCalls[2]
- afterTurn1 := afterTurnCalls[1]
- afterTurn2 := afterTurnCalls[2]
- hooksMu.Unlock()
- require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次")
- require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn")
- require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次")
- require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 118,
- Name: "openai-ingress-prev-recovery",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`)
- secondTurn := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String())
- require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试")
-
- firstConn.mu.Lock()
- firstWrites := append([]map[string]any(nil), firstConn.writes...)
- firstConn.mu.Unlock()
- require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求")
- require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id")
-
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求")
- require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 122,
- Name: "openai-ingress-prev-strict-layer2",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试")
-
- firstConn.mu.Lock()
- firstWrites := append([]map[string]any(nil), firstConn.writes...)
- firstConn.mu.Unlock()
- require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求")
- require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())
-
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次")
- secondWrite := requestToJSONString(secondWrites[0])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id")
- require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志")
- require.False(t, gjson.Get(secondWrite, "store").Bool())
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文")
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 120,
- Name: "openai-ingress-prev-recovery-once",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- // duplicate previous_response_id: 恢复重试时应删除所有重复键,避免再次 previous_response_not_found。
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次")
-
- firstConn.mu.Lock()
- firstWrites := append([]map[string]any(nil), firstConn.writes...)
- firstConn.mu.Unlock()
- require.Len(t, firstWrites, 2)
- require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())
-
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1)
- require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 119,
- Name: "openai-ingress-prev-validation",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`))
- cancelWrite()
- require.NoError(t, err)
-
- select {
- case serverErr := <-serverErrCh:
- require.Error(t, serverErr)
- var closeErr *OpenAIWSClientCloseError
- require.ErrorAs(t, serverErr, &closeErr)
- require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
- require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id")
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-}
-
-type openAIWSQueueDialer struct {
- mu sync.Mutex
- conns []openAIWSClientConn
- dialCount int
-}
-
-func (d *openAIWSQueueDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- d.mu.Lock()
- defer d.mu.Unlock()
- d.dialCount++
- if len(d.conns) == 0 {
- return nil, 503, nil, errors.New("no test conn")
- }
- conn := d.conns[0]
- if len(d.conns) > 1 {
- d.conns = d.conns[1:]
- }
- return conn, 0, nil, nil
-}
-
-func (d *openAIWSQueueDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-type openAIWSPreflightFailConn struct {
- mu sync.Mutex
- events [][]byte
- pingFails bool
- writeCount int
- pingCount int
-}
-
-func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error {
- c.mu.Lock()
- c.writeCount++
- c.mu.Unlock()
- return nil
-}
-
-func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if len(c.events) == 0 {
- return nil, io.EOF
- }
- event := c.events[0]
- c.events = c.events[1:]
- if len(c.events) == 0 {
- c.pingFails = true
- }
- return event, nil
-}
-
-func (c *openAIWSPreflightFailConn) Ping(context.Context) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.pingCount++
- if c.pingFails {
- return errors.New("preflight ping failed")
- }
- return nil
-}
-
-func (c *openAIWSPreflightFailConn) Close() error {
- return nil
-}
-
-func (c *openAIWSPreflightFailConn) WriteCount() int {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.writeCount
-}
-
-func (c *openAIWSPreflightFailConn) PingCount() int {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.pingCount
-}
-
-type openAIWSWriteFailAfterFirstTurnConn struct {
- mu sync.Mutex
- events [][]byte
- failOnWrite bool
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.failOnWrite {
- return errors.New("write failed on stale conn")
- }
- return nil
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if len(c.events) == 0 {
- return nil, io.EOF
- }
- event := c.events[0]
- c.events = c.events[1:]
- if len(c.events) == 0 {
- c.failOnWrite = true
- }
- return event, nil
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error {
- return nil
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) Close() error {
- return nil
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- // 多个上游事件:前几个为非 terminal 事件,最后一个为 terminal。
- // 第一个事件延迟 250ms 让客户端 RST 有时间传播,使 writeClientMessage 可靠失败。
- captureConn := &openAIWSCaptureConn{
- readDelays: []time.Duration{250 * time.Millisecond, 0, 0},
- events: [][]byte{
- []byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`),
- []byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 115,
- Name: "openai-ingress-client-disconnect",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "model_mapping": map[string]any{
- "custom-original-model": "gpt-5.1",
- },
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- resultCh := make(chan *OpenAIForwardResult, 1)
- hooks := &OpenAIWSIngressHooks{
- AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
- if turnErr == nil && result != nil {
- resultCh <- result
- }
- },
- }
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`))
- cancelWrite()
- require.NoError(t, err)
- // 立即关闭客户端,模拟客户端在 relay 期间断连。
- require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连")
-
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束")
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- select {
- case result := <-resultCh:
- require.Equal(t, "resp_ingress_disconnect", result.RequestID)
- require.Equal(t, 2, result.Usage.InputTokens)
- require.Equal(t, 1, result.Usage.OutputTokens)
- case <-time.After(2 * time.Second):
- t.Fatal("未收到断连后的 turn 结果回调")
- }
-}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go
index ff35cb01d..3129dae27 100644
--- a/backend/internal/service/openai_ws_forwarder_ingress_test.go
+++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go
@@ -6,10 +6,13 @@ import (
"errors"
"io"
"net"
+ "strings"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
@@ -34,6 +37,8 @@ func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
{name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
{name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
+ {name: "blank_message", err: errors.New(" "), want: false},
+ {name: "unmatched_message", err: errors.New("tls handshake timeout"), want: false},
}
for _, tt := range tests {
@@ -45,6 +50,423 @@ func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
}
}
+func TestOpenAIWSIngressFallbackSessionSeedFromContext(t *testing.T) {
+ t.Parallel()
+
+ require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(nil))
+
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(nil)
+ require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c))
+
+ c.Set("api_key", "not_api_key")
+ require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c))
+
+ groupID := int64(99)
+ c.Set("api_key", &APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &User{ID: 202},
+ })
+ require.Equal(t, "openai_ws_ingress:99:202:101", openAIWSIngressFallbackSessionSeedFromContext(c))
+
+ c.Set("api_key", &APIKey{
+ ID: 303,
+ User: nil,
+ })
+ require.Equal(t, "openai_ws_ingress:0:0:303", openAIWSIngressFallbackSessionSeedFromContext(c))
+}
+
+func TestClassifyOpenAIWSIngressTurnAbortReason(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ wantReason openAIWSIngressTurnAbortReason
+ wantExpected bool
+ }{
+ {
+ name: "nil",
+ err: nil,
+ wantReason: openAIWSIngressTurnAbortReasonUnknown,
+ wantExpected: false,
+ },
+ {
+ name: "context canceled",
+ err: context.Canceled,
+ wantReason: openAIWSIngressTurnAbortReasonContextCanceled,
+ wantExpected: true,
+ },
+ {
+ name: "context deadline",
+ err: context.DeadlineExceeded,
+ wantReason: openAIWSIngressTurnAbortReasonContextDeadline,
+ wantExpected: false,
+ },
+ {
+ name: "client close",
+ err: coderws.CloseError{Code: coderws.StatusNormalClosure},
+ wantReason: openAIWSIngressTurnAbortReasonClientClosed,
+ wantExpected: true,
+ },
+ {
+ name: "client close by eof",
+ err: io.EOF,
+ wantReason: openAIWSIngressTurnAbortReasonClientClosed,
+ wantExpected: true,
+ },
+ {
+ name: "previous response not found",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonPreviousResponse,
+ wantExpected: true,
+ },
+ {
+ name: "tool output not found",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("no tool output found"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonToolOutput,
+ wantExpected: true,
+ },
+ {
+ name: "upstream error event",
+ err: wrapOpenAIWSIngressTurnError(
+ "upstream_error_event",
+ errors.New("upstream error event"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamError,
+ wantExpected: true,
+ },
+ {
+ name: "write upstream",
+ err: wrapOpenAIWSIngressTurnError(
+ "write_upstream",
+ errors.New("write upstream fail"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonWriteUpstream,
+ wantExpected: false,
+ },
+ {
+ name: "read upstream",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ errors.New("read upstream fail"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonReadUpstream,
+ wantExpected: false,
+ },
+ {
+ name: "idle timeout stage",
+ err: wrapOpenAIWSIngressTurnError(
+ "idle_timeout",
+ errors.New("relay idle timeout"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonContextDeadline,
+ wantExpected: false,
+ },
+ {
+ name: "write client",
+ err: wrapOpenAIWSIngressTurnError(
+ "write_client",
+ errors.New("write client fail"),
+ true,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonWriteClient,
+ wantExpected: false,
+ },
+ {
+ name: "unknown turn stage",
+ err: wrapOpenAIWSIngressTurnError(
+ "some_unknown_stage",
+ errors.New("unknown stage fail"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUnknown,
+ wantExpected: false,
+ },
+ {
+ name: "continuation unavailable close",
+ err: NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ openAIWSContinuationUnavailableReason,
+ nil,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonContinuationUnavailable,
+ wantExpected: true,
+ },
+ {
+ name: "upstream restart 1012",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"},
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ wantExpected: true,
+ },
+ {
+ name: "upstream try again later 1013",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ coderws.CloseError{Code: coderws.StatusTryAgainLater, Reason: "try again later"},
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ wantExpected: true,
+ },
+ {
+ name: "upstream restart 1012 with wroteDownstream",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"},
+ true,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ wantExpected: true,
+ },
+ {
+ name: "1012 on non-read_upstream stage should not match",
+ err: wrapOpenAIWSIngressTurnError(
+ "write_upstream",
+ coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"},
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonWriteUpstream,
+ wantExpected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err)
+ require.Equal(t, tt.wantReason, reason)
+ require.Equal(t, tt.wantExpected, expected)
+ })
+ }
+}
+
+func TestClassifyOpenAIWSIngressTurnAbortReason_ClientDisconnectedDrainTimeout(t *testing.T) {
+ t.Parallel()
+
+ err := wrapOpenAIWSIngressTurnError(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(2*time.Second),
+ true,
+ )
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonContextCanceled, reason)
+ require.True(t, expected)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionCloseGracefully, openAIWSIngressTurnAbortDispositionForReason(reason))
+}
+
+func TestOpenAIWSIngressPumpClosedTurnError_ClientDisconnected(t *testing.T) {
+ t.Parallel()
+
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_partial",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ },
+ }
+ err := openAIWSIngressPumpClosedTurnError(true, true, partial)
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+
+ var turnErr *openAIWSIngressTurnError
+ require.ErrorAs(t, err, &turnErr)
+ require.Equal(t, "client_disconnected_drain_timeout", turnErr.stage)
+ require.True(t, turnErr.wroteDownstream)
+ require.NotNil(t, turnErr.partialResult)
+ require.Equal(t, partial.RequestID, turnErr.partialResult.RequestID)
+}
+
+func TestOpenAIWSIngressPumpClosedTurnError_ReadUpstream(t *testing.T) {
+ t.Parallel()
+
+ err := openAIWSIngressPumpClosedTurnError(false, false, nil)
+ require.Error(t, err)
+
+ var turnErr *openAIWSIngressTurnError
+ require.ErrorAs(t, err, &turnErr)
+ require.Equal(t, "read_upstream", turnErr.stage)
+ require.False(t, turnErr.wroteDownstream)
+ require.Nil(t, turnErr.partialResult)
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonReadUpstream, reason)
+ require.False(t, expected)
+}
+
+func TestOpenAIWSIngressPumpClosedTurnError_ClonesPartialResult(t *testing.T) {
+ t.Parallel()
+
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_original",
+ PendingFunctionCallIDs: []string{"call_a"},
+ }
+ err := openAIWSIngressPumpClosedTurnError(true, true, partial)
+ require.Error(t, err)
+
+ partial.RequestID = "resp_mutated"
+ partial.PendingFunctionCallIDs[0] = "call_b"
+
+ var turnErr *openAIWSIngressTurnError
+ require.ErrorAs(t, err, &turnErr)
+ require.NotNil(t, turnErr.partialResult)
+ require.Equal(t, "resp_original", turnErr.partialResult.RequestID)
+ require.Equal(t, []string{"call_a"}, turnErr.partialResult.PendingFunctionCallIDs)
+}
+
+func TestOpenAIWSIngressClientDisconnectedDrainTimeoutError_DefaultTimeout(t *testing.T) {
+ t.Parallel()
+
+ err := openAIWSIngressClientDisconnectedDrainTimeoutError(0)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), openAIWSIngressClientDisconnectDrainTimeout.String())
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+func TestOpenAIWSIngressResolveDrainReadTimeout(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+ tests := []struct {
+ name string
+ base time.Duration
+ deadline time.Time
+ want time.Duration
+ wantExpire bool
+ }{
+ {
+ name: "no_deadline_uses_base",
+ base: 15 * time.Second,
+ deadline: time.Time{},
+ want: 15 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "remaining_shorter_than_base",
+ base: 10 * time.Second,
+ deadline: now.Add(3 * time.Second),
+ want: 3 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "base_shorter_than_remaining",
+ base: 2 * time.Second,
+ deadline: now.Add(8 * time.Second),
+ want: 2 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "base_zero_uses_remaining",
+ base: 0,
+ deadline: now.Add(5 * time.Second),
+ want: 5 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "expired_deadline",
+ base: 10 * time.Second,
+ deadline: now.Add(-time.Millisecond),
+ want: 0,
+ wantExpire: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, expired := openAIWSIngressResolveDrainReadTimeout(tt.base, tt.deadline, now)
+ require.Equal(t, tt.want, got)
+ require.Equal(t, tt.wantExpire, expired)
+ })
+ }
+}
+
+func TestOpenAIWSIngressTurnAbortDispositionForReason(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ in openAIWSIngressTurnAbortReason
+ want openAIWSIngressTurnAbortDisposition
+ }{
+ {
+ name: "continue turn on previous response mismatch",
+ in: openAIWSIngressTurnAbortReasonPreviousResponse,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ {
+ name: "continue turn on tool output mismatch",
+ in: openAIWSIngressTurnAbortReasonToolOutput,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ {
+ name: "continue turn on upstream error event",
+ in: openAIWSIngressTurnAbortReasonUpstreamError,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ {
+ name: "close gracefully on context canceled",
+ in: openAIWSIngressTurnAbortReasonContextCanceled,
+ want: openAIWSIngressTurnAbortDispositionCloseGracefully,
+ },
+ {
+ name: "close gracefully on client closed",
+ in: openAIWSIngressTurnAbortReasonClientClosed,
+ want: openAIWSIngressTurnAbortDispositionCloseGracefully,
+ },
+ {
+ name: "default fail request on unknown reason",
+ in: openAIWSIngressTurnAbortReasonUnknown,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "default fail request on write upstream reason",
+ in: openAIWSIngressTurnAbortReasonWriteUpstream,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "default fail request on read upstream reason",
+ in: openAIWSIngressTurnAbortReasonReadUpstream,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "default fail request on write client reason",
+ in: openAIWSIngressTurnAbortReasonWriteClient,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "continue turn on upstream restart",
+ in: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, openAIWSIngressTurnAbortDispositionForReason(tt.in))
+ })
+ }
+}
+
func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
t.Parallel()
@@ -254,12 +676,12 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
want: false,
},
{
- name: "skip_on_first_turn",
+ name: "infer_on_first_turn_when_expected_previous_exists",
storeDisabled: true,
turn: 1,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
- want: false,
+ want: true,
},
{
name: "skip_without_function_call_output",
@@ -312,6 +734,145 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
}
}
+func TestShouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ storeDisabled bool
+ hasFunctionCallOutput bool
+ previousResponseID string
+ hasToolOutputContext bool
+ want bool
+ }{
+ {
+ name: "reject_when_store_disabled_and_missing_prev_without_context",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ hasToolOutputContext: false,
+ want: true,
+ },
+ {
+ name: "skip_when_store_enabled",
+ storeDisabled: false,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ hasToolOutputContext: false,
+ want: false,
+ },
+ {
+ name: "skip_when_previous_response_id_exists",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "resp_1",
+ hasToolOutputContext: false,
+ want: false,
+ },
+ {
+ name: "skip_when_has_tool_output_context",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ hasToolOutputContext: true,
+ want: false,
+ },
+ {
+ name: "skip_when_no_function_call_output",
+ storeDisabled: true,
+ hasFunctionCallOutput: false,
+ previousResponseID: "",
+ hasToolOutputContext: false,
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ tt.storeDisabled,
+ tt.hasFunctionCallOutput,
+ tt.previousResponseID,
+ tt.hasToolOutputContext,
+ )
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestOpenAIWSHasToolOutputContextInPayload(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ payload []byte
+ expectedCallIDs []string
+ wantHasToolCall bool
+ wantHasItemReference bool
+ }{
+ {
+ name: "empty_payload",
+ payload: nil,
+ expectedCallIDs: []string{"call_1"},
+ wantHasToolCall: false,
+ wantHasItemReference: false,
+ },
+ {
+ name: "has_tool_call_context",
+ payload: []byte(`{"input":[{"type":"tool_call","call_id":"call_1"},{"type":"function_call_output","call_id":"call_1"}]}`),
+ expectedCallIDs: []string{"call_1"},
+ wantHasToolCall: true,
+ wantHasItemReference: false,
+ },
+ {
+ name: "has_function_call_context",
+ payload: []byte(`{"input":[{"type":"function_call","call_id":"call_1"},{"type":"function_call_output","call_id":"call_1"}]}`),
+ expectedCallIDs: []string{"call_1"},
+ wantHasToolCall: true,
+ wantHasItemReference: false,
+ },
+ {
+ name: "tool_call_without_call_id_is_not_context",
+ payload: []byte(`{"input":[{"type":"tool_call"},{"type":"function_call_output","call_id":"call_1"}]}`),
+ expectedCallIDs: []string{"call_1"},
+ wantHasToolCall: false,
+ wantHasItemReference: false,
+ },
+ {
+ name: "has_item_reference_for_all_function_call_outputs",
+ payload: []byte(`{"input":[{"type":"item_reference","id":"call_1"},{"type":"item_reference","id":"call_2"},{"type":"function_call_output","call_id":"call_1"},{"type":"function_call_output","call_id":"call_2"}]}`),
+ expectedCallIDs: []string{"call_1", "call_2"},
+ wantHasToolCall: false,
+ wantHasItemReference: true,
+ },
+ {
+ name: "missing_item_reference_for_some_call_ids",
+ payload: []byte(`{"input":[{"type":"item_reference","id":"call_1"},{"type":"function_call_output","call_id":"call_1"},{"type":"function_call_output","call_id":"call_2"}]}`),
+ expectedCallIDs: []string{"call_1", "call_2"},
+ wantHasToolCall: false,
+ wantHasItemReference: false,
+ },
+ {
+ name: "ignores_non_array_input",
+ payload: []byte(`{"input":"bad"}`),
+ expectedCallIDs: []string{"call_1"},
+ wantHasToolCall: false,
+ wantHasItemReference: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.wantHasToolCall, openAIWSHasToolCallContextInPayload(tt.payload))
+ require.Equal(t, tt.wantHasItemReference, openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(tt.payload, tt.expectedCallIDs))
+ })
+ }
+}
+
func TestOpenAIWSInputIsPrefixExtended(t *testing.T) {
t.Parallel()
@@ -529,7 +1090,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
}`)
t.Run("strict_incremental_keep", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.True(t, keep)
require.Equal(t, "strict_incremental_ok", reason)
@@ -537,28 +1098,28 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
t.Run("missing_previous_response_id", func(t *testing.T) {
payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "missing_previous_response_id", reason)
})
t.Run("missing_last_turn_response_id", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "missing_last_turn_response_id", reason)
})
t.Run("previous_response_id_mismatch", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "previous_response_id_mismatch", reason)
})
t.Run("missing_previous_turn_payload", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "missing_previous_turn_payload", reason)
@@ -573,7 +1134,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
"previous_response_id":"resp_turn_1",
"input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "non_input_changed", reason)
@@ -588,7 +1149,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
"previous_response_id":"resp_turn_1",
"input":[{"type":"input_text","text":"different"}]
}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.True(t, keep)
require.Equal(t, "strict_incremental_ok", reason)
@@ -602,21 +1163,63 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
"previous_response_id":"resp_external",
"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true, nil, []string{"call_1"})
require.NoError(t, err)
require.True(t, keep)
require.Equal(t, "has_function_call_output", reason)
})
+ t.Run("function_call_output_pending_call_id_match_keeps_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "previous_response_id":"resp_turn_1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(
+ previousPayload,
+ payload,
+ "resp_turn_1",
+ true,
+ []string{"call_1"},
+ []string{"call_1"},
+ )
+ require.NoError(t, err)
+ require.True(t, keep)
+ require.Equal(t, "function_call_output_call_id_match", reason)
+ })
+
+ t.Run("function_call_output_pending_call_id_mismatch_drops_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "previous_response_id":"resp_turn_1",
+ "input":[{"type":"function_call_output","call_id":"call_other","output":"ok"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(
+ previousPayload,
+ payload,
+ "resp_turn_1",
+ true,
+ []string{"call_1"},
+ []string{"call_other"},
+ )
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "function_call_output_call_id_mismatch", reason)
+ })
+
t.Run("non_input_compare_error", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false, nil, nil)
require.Error(t, err)
require.False(t, keep)
require.Equal(t, "non_input_compare_error", reason)
})
t.Run("current_payload_compare_error", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false, nil, nil)
require.Error(t, err)
require.False(t, keep)
require.Equal(t, "non_input_compare_error", reason)
@@ -670,6 +1273,171 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
})
+
+ t.Run("replay_input_limited_by_bytes_keeps_newest_items", func(t *testing.T) {
+ makeItem := func(text string) json.RawMessage {
+ raw, err := json.Marshal(map[string]any{
+ "type": "input_text",
+ "text": text,
+ })
+ require.NoError(t, err)
+ return json.RawMessage(raw)
+ }
+ largeA := strings.Repeat("a", openAIWSIngressReplayInputMaxBytes/2)
+ largeB := strings.Repeat("b", openAIWSIngressReplayInputMaxBytes/2)
+ largeC := strings.Repeat("c", openAIWSIngressReplayInputMaxBytes/2)
+ previousLarge := []json.RawMessage{
+ makeItem(largeA),
+ makeItem(largeB),
+ }
+ currentPayload, err := json.Marshal(map[string]any{
+ "previous_response_id": "resp_1",
+ "input": []map[string]any{
+ {"type": "input_text", "text": largeC},
+ },
+ })
+ require.NoError(t, err)
+
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ previousLarge,
+ true,
+ currentPayload,
+ true,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.GreaterOrEqual(t, len(items), 1)
+ require.Equal(t, largeC, gjson.GetBytes(items[len(items)-1], "text").String(), "latest item should always be preserved")
+ require.Less(t, len(items), 3, "oversized replay input should be truncated")
+ })
+
+ t.Run("replay_input_limited_by_bytes_still_keeps_single_oversized_latest_item", func(t *testing.T) {
+ tooLargeText := strings.Repeat("z", openAIWSIngressReplayInputMaxBytes+1024)
+ currentPayload, err := json.Marshal(map[string]any{
+ "input": []map[string]any{
+ {"type": "input_text", "text": tooLargeText},
+ },
+ })
+ require.NoError(t, err)
+
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ nil,
+ false,
+ currentPayload,
+ false,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, tooLargeText, gjson.GetBytes(items[0], "text").String())
+ })
+}
+
+func TestOpenAIWSInputAppearsEditedFromPreviousFullInput(t *testing.T) {
+ t.Parallel()
+
+ makeItems := func(values ...string) []json.RawMessage {
+ items := make([]json.RawMessage, 0, len(values))
+ for _, v := range values {
+ raw, err := json.Marshal(map[string]any{
+ "type": "input_text",
+ "text": v,
+ })
+ require.NoError(t, err)
+ items = append(items, json.RawMessage(raw))
+ }
+ return items
+ }
+
+ previous := makeItems("hello", "world")
+
+ t.Run("skip_when_no_previous_response_id", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`),
+ false,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_previous_full_input_missing", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ nil,
+ false,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("error_when_current_payload_invalid", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[}`),
+ true,
+ )
+ require.Error(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_current_input_missing", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1"}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_previous_len_lt_2", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ makeItems("hello"),
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_current_shorter_than_previous", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_current_has_previous_prefix", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"},{"type":"input_text","text":"new"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("detect_when_current_is_full_snapshot_edit", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.True(t, edited)
+ })
}
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
@@ -712,3 +1480,147 @@ func TestCloneOpenAIWSRawMessages(t *testing.T) {
require.Len(t, cloned, 0)
})
}
+
+// ---------------------------------------------------------------------------
+// TestInjectPreviousResponseIDForFunctionCallOutput
+// 端到端测试:当客户端发送 function_call_output 但未携带 previous_response_id 时,
+// Gateway 应主动注入 lastTurnResponseID,避免上游返回 tool_output_not_found 错误。
+// ---------------------------------------------------------------------------
+
+func TestInjectPreviousResponseIDForFunctionCallOutput(t *testing.T) {
+ t.Parallel()
+
+ // 辅助函数:模拟 forwarder 中的注入逻辑
+ // 返回 (注入后的 payload, 注入后的 previousResponseID, 是否执行了注入)
+ simulateInject := func(
+ storeDisabled bool,
+ turn int,
+ payload []byte,
+ expectedPrev string,
+ ) ([]byte, string, bool) {
+ currentPreviousResponseID := ""
+ prev := gjson.GetBytes(payload, "previous_response_id")
+ if prev.Exists() {
+ currentPreviousResponseID = strings.TrimSpace(prev.String())
+ }
+ hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+
+ if shouldInferIngressFunctionCallOutputPreviousResponseID(
+ storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev,
+ ) {
+ injected, err := setPreviousResponseIDToRawPayload(payload, expectedPrev)
+ if err != nil {
+ return payload, currentPreviousResponseID, false
+ }
+ return injected, expectedPrev, true
+ }
+ return payload, currentPreviousResponseID, false
+ }
+
+ t.Run("inject_when_function_call_output_without_prev_id", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"function_call_output","call_id":"call_abc123","output":"result"}]}`)
+ updated, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn")
+
+ require.True(t, injected, "应该执行注入")
+ require.Equal(t, "resp_last_turn", prevID)
+ require.Equal(t, "resp_last_turn", gjson.GetBytes(updated, "previous_response_id").String())
+ // 验证原始 input 保持不变
+ require.Equal(t, "call_abc123", gjson.GetBytes(updated, `input.0.call_id`).String())
+ require.Equal(t, "function_call_output", gjson.GetBytes(updated, `input.0.type`).String())
+ require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String())
+ })
+
+ t.Run("skip_when_prev_id_already_present", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_client","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn")
+
+ require.False(t, injected, "客户端已携带 previous_response_id,不应注入")
+ require.Equal(t, "resp_client", prevID)
+ })
+
+ t.Run("skip_when_store_enabled", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(false, 2, payload, "resp_last_turn")
+
+ require.False(t, injected, "store 未禁用时不应注入")
+ })
+
+ t.Run("skip_when_no_function_call_output", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`)
+ _, _, injected := simulateInject(true, 2, payload, "resp_last_turn")
+
+ require.False(t, injected, "没有 function_call_output 时不应注入")
+ })
+
+ t.Run("skip_when_expected_prev_empty", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 2, payload, "")
+
+ require.False(t, injected, "没有 expectedPrev 时不应注入")
+ })
+
+ t.Run("inject_preserves_multiple_function_call_outputs", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"a"},{"type":"function_call_output","call_id":"call_2","output":"b"}]}`)
+ updated, prevID, injected := simulateInject(true, 5, payload, "resp_multi")
+
+ require.True(t, injected)
+ require.Equal(t, "resp_multi", prevID)
+ require.Equal(t, "resp_multi", gjson.GetBytes(updated, "previous_response_id").String())
+ outputs := gjson.GetBytes(updated, `input.#(type=="function_call_output")#.call_id`).Array()
+ require.Len(t, outputs, 2)
+ require.Equal(t, "call_1", outputs[0].String())
+ require.Equal(t, "call_2", outputs[1].String())
+ })
+
+ t.Run("inject_on_first_turn_with_expected_prev", func(t *testing.T) {
+ t.Parallel()
+ // turn=1 但有 expectedPrev(可能来自 session state store 恢复),应注入
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 1, payload, "resp_restored")
+
+ require.True(t, injected, "turn=1 且有 expectedPrev 时应注入")
+ })
+
+ t.Run("inject_updates_payload_bytes_correctly", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ updated, _, injected := simulateInject(true, 3, payload, "resp_check_size")
+
+ require.True(t, injected)
+ // 注入后 payload 长度应增加(包含了新的 previous_response_id 字段)
+ require.Greater(t, len(updated), len(payload))
+ // 验证 JSON 合法性
+ require.True(t, json.Valid(updated), "注入后的 payload 应为合法 JSON")
+ })
+
+ t.Run("skip_when_turn_zero", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 0, payload, "resp_1")
+
+ require.False(t, injected, "turn=0 时不应注入")
+ })
+
+ t.Run("inject_with_whitespace_expected_prev", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ // shouldInfer 内部会 trim,所以带空格的 expectedPrev 仍然有效
+ _, _, injected := simulateInject(true, 2, payload, " resp_trimmed ")
+
+ require.True(t, injected, "trim 后非空的 expectedPrev 应触发注入")
+ })
+
+ t.Run("skip_when_prev_id_is_whitespace_only", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 2, payload, " ")
+
+ require.False(t, injected, "纯空白的 expectedPrev 不应触发注入")
+ })
+}
diff --git a/backend/internal/service/openai_ws_forwarder_panic_test.go b/backend/internal/service/openai_ws_forwarder_panic_test.go
new file mode 100644
index 000000000..0ceacb187
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_panic_test.go
@@ -0,0 +1,107 @@
+package service
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type openAIWSPanicResolver struct{}
+
+func (openAIWSPanicResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
+ panic("resolver panic")
+}
+
+type openAIWSPanicStateStore struct {
+ OpenAIWSStateStore
+}
+
+func (openAIWSPanicStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
+ panic("state_store panic")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PanicRecovered(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ svc.openaiWSResolver = openAIWSPanicResolver{}
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusInternalError, closeErr.StatusCode())
+ require.Equal(t, "internal websocket proxy panic", closeErr.Reason())
+}
+
+func TestOpenAIGatewayService_ForwardOpenAIWSV2_PanicRecovered(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ openaiWSStateStore: openAIWSPanicStateStore{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ cache: &stubGatewayCache{},
+ }
+
+ account := &Account{
+ ID: 445,
+ Name: "openai-forwarder-panic",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ }
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ req.Header.Set("session_id", "sess-panic-check")
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ _, err := svc.forwardOpenAIWSV2(
+ context.Background(),
+ ginCtx,
+ account,
+ map[string]any{
+ "model": "gpt-5.1",
+ "input": []any{},
+ },
+ "sk-test",
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ true,
+ true,
+ "gpt-5.1",
+ "gpt-5.1",
+ time.Now(),
+ 1,
+ "",
+ )
+ require.Error(t, err)
+ require.ErrorContains(t, err, "panic recovered")
+}
diff --git a/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
new file mode 100644
index 000000000..ed8539f28
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
@@ -0,0 +1,818 @@
+package service
+
+import (
+ "encoding/json"
+ "errors"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// ---------------------------------------------------------------------------
+// 改动 1 测试:预防性检测 — store_disabled + function_call_output + 无 previous_response_id
+// 在 sendAndRelay 中提前返回可恢复错误,避免上游先写数据再报错导致无法恢复
+// ---------------------------------------------------------------------------
+
+// TestProactiveToolOutputNotFound_ErrorShape 验证预防性检测生成的错误格式正确:
+// stage = tool_output_not_found、wroteDownstream = false、可被恢复函数识别。
+func TestProactiveToolOutputNotFound_ErrorShape(t *testing.T) {
+ t.Parallel()
+
+ err := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+
+ // 1. 应被识别为 tool_output_not_found
+ require.True(t, isOpenAIWSIngressToolOutputNotFound(err),
+ "预防性检测错误应被 isOpenAIWSIngressToolOutputNotFound 识别")
+
+ // 2. wroteDownstream 应为 false(尚未发送任何数据)
+ require.False(t, openAIWSIngressTurnWroteDownstream(err),
+ "预防性检测在发送前触发,wroteDownstream 必须为 false")
+
+ // 3. 应被 classifyOpenAIWSIngressTurnAbortReason 归类为 ToolOutput
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason)
+ require.True(t, expected, "tool_output_not_found 是预期错误")
+
+ // 4. 处置方式应为 ContinueTurn(允许恢复重试)
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition,
+ "tool_output_not_found 应为 ContinueTurn 处置,允许恢复重试")
+
+ // 5. 部分结果应为 nil(因为尚未产生任何上游响应)
+ // partialResult=nil 时 OpenAIWSIngressTurnPartialResult 返回 (nil, false)
+ partial, ok := OpenAIWSIngressTurnPartialResult(err)
+ require.False(t, ok, "预防性检测传入 partialResult=nil,ok 应为 false")
+ require.Nil(t, partial, "预防性检测不应有部分结果")
+}
+
+// TestProactiveToolOutputNotFound_NotPreviousResponseNotFound
+// 验证预防性检测生成的错误不会被误识别为 previous_response_not_found。
+func TestProactiveToolOutputNotFound_NotPreviousResponseNotFound(t *testing.T) {
+ t.Parallel()
+
+ err := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found"),
+ false,
+ nil,
+ )
+
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err),
+ "tool_output_not_found 不应被误识别为 previous_response_not_found")
+}
+
+// TestProactiveDetection_ConditionMatrix 验证预防性检测的三个触发条件的所有组合。
+// 只有 store_disabled + function_call_output + 无 previous_response_id + 无可关联上下文 才触发。
+func TestProactiveDetection_ConditionMatrix(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ storeDisabled bool
+ hasFunctionCallOutput bool
+ previousResponseID string
+ hasToolOutputContext bool
+ shouldTrigger bool
+ }{
+ {
+ name: "all_conditions_met_should_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ hasToolOutputContext: false,
+ shouldTrigger: true,
+ },
+ {
+ name: "whitespace_only_previous_response_id_should_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: " ",
+ hasToolOutputContext: false,
+ shouldTrigger: true,
+ },
+ {
+ name: "store_enabled_should_not_trigger",
+ storeDisabled: false,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ hasToolOutputContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "no_function_call_output_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: false,
+ previousResponseID: "",
+ hasToolOutputContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "has_previous_response_id_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "resp_abc",
+ hasToolOutputContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "all_false_should_not_trigger",
+ storeDisabled: false,
+ hasFunctionCallOutput: false,
+ previousResponseID: "resp_abc",
+ hasToolOutputContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "store_disabled_no_fco_has_prev_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: false,
+ previousResponseID: "resp_abc",
+ hasToolOutputContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "store_enabled_has_fco_has_prev_should_not_trigger",
+ storeDisabled: false,
+ hasFunctionCallOutput: true,
+ previousResponseID: "resp_abc",
+ hasToolOutputContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "has_tool_output_context_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ hasToolOutputContext: true,
+ shouldTrigger: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ // 模拟 sendAndRelay 中的检测逻辑
+ triggered := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ tt.storeDisabled,
+ tt.hasFunctionCallOutput,
+ tt.previousResponseID,
+ tt.hasToolOutputContext,
+ )
+ require.Equal(t, tt.shouldTrigger, triggered)
+ })
+ }
+}
+
+// TestProactiveDetection_PayloadExtraction 验证从真实 payload 中提取的条件参数
+// 与预防性检测逻辑的配合。确保 payload 解析结果能正确触发或跳过检测。
+func TestProactiveDetection_PayloadExtraction(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ payload string
+ wantHasFCO bool
+ wantPrevID string
+ wantHasContext bool
+ shouldTrigger bool // 假设 storeDisabled=true
+ }{
+ {
+ name: "fco_without_previous_response_id",
+ payload: `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "",
+ wantHasContext: false,
+ shouldTrigger: true,
+ },
+ {
+ name: "fco_with_previous_response_id",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "resp_1",
+ wantHasContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "no_fco_without_previous_response_id",
+ payload: `{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`,
+ wantHasFCO: false,
+ wantPrevID: "",
+ wantHasContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "multiple_fco_without_previous_response_id",
+ payload: `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "",
+ wantHasContext: false,
+ shouldTrigger: true,
+ },
+ {
+ name: "fco_with_tool_call_context",
+ payload: `{"type":"response.create","input":[{"type":"tool_call","call_id":"call_1"},{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "",
+ wantHasContext: true,
+ shouldTrigger: false,
+ },
+ {
+ name: "fco_with_item_reference_context",
+ payload: `{"type":"response.create","input":[{"type":"item_reference","id":"call_1"},{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "",
+ wantHasContext: true,
+ shouldTrigger: false,
+ },
+ {
+ name: "empty_input_without_previous_response_id",
+ payload: `{"type":"response.create","input":[]}`,
+ wantHasFCO: false,
+ wantPrevID: "",
+ wantHasContext: false,
+ shouldTrigger: false,
+ },
+ {
+ name: "no_input_field",
+ payload: `{"type":"response.create","model":"gpt-5.1"}`,
+ wantHasFCO: false,
+ wantPrevID: "",
+ wantHasContext: false,
+ shouldTrigger: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(tt.payload)
+
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ hasFCO := len(callIDs) > 0
+ prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ hasContext := openAIWSHasToolCallContextInPayload(payload) ||
+ openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload, callIDs)
+
+ require.Equal(t, tt.wantHasFCO, hasFCO, "hasFunctionCallOutput 不匹配")
+ require.Equal(t, tt.wantPrevID, prevID, "previousResponseID 不匹配")
+ require.Equal(t, tt.wantHasContext, hasContext, "tool output context 检测结果不匹配")
+
+ // 模拟 storeDisabled=true 时的检测
+ triggered := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ true,
+ hasFCO,
+ prevID,
+ hasContext,
+ )
+ require.Equal(t, tt.shouldTrigger, triggered, "检测触发结果不匹配")
+ })
+ }
+}
+
+// TestProactiveDetection_RecoveryChainIntegration 验证预防性检测错误能被
+// recoverIngressPrevResponseNotFound 的恢复逻辑识别并处理。
+// 这是两个改动之间的集成测试。
+func TestProactiveDetection_RecoveryChainIntegration(t *testing.T) {
+ t.Parallel()
+
+ // 模拟预防性检测返回的错误
+ proactiveErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+
+ // 1. 恢复函数入口检测:isToolOutputMissing 应为 true
+ require.True(t, isOpenAIWSIngressToolOutputNotFound(proactiveErr),
+ "预防性检测错误应通过 isToolOutputMissing 检查")
+
+ // 2. isPrevNotFound 应为 false(确保不走 previous_response_not_found 分支)
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(proactiveErr),
+ "预防性检测错误不应走 previous_response_not_found 分支")
+
+ // 3. wroteDownstream=false 确保可以安全重试
+ require.False(t, openAIWSIngressTurnWroteDownstream(proactiveErr),
+ "预防性检测在发送前触发,wroteDownstream 必须为 false")
+}
+
+// ---------------------------------------------------------------------------
+// 改动 2 测试:恢复逻辑修复 — previous_response_id 已缺失时跳过 drop 步骤
+// ---------------------------------------------------------------------------
+
+// TestToolOutputRecovery_PreviousResponseIDAlreadyEmpty 验证核心修复:
+// 当 payload 中本就不存在 previous_response_id 时,跳过 drop 步骤,
+// 直接进入 setOpenAIWSPayloadInputSequence。
+func TestToolOutputRecovery_PreviousResponseIDAlreadyEmpty(t *testing.T) {
+ t.Parallel()
+
+ // 场景:预防性检测触发后,payload 中不存在 previous_response_id
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.Empty(t, currentPreviousResponseID, "payload 中不应存在 previous_response_id")
+
+ // 修复后的逻辑:currentPreviousResponseID 为空时跳过 drop
+ // 直接使用原始 payload 作为 updatedPayload
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ // 此分支不应执行
+ t.Fatal("不应进入 drop 分支")
+ }
+
+ // 验证跳过 drop 后,payload 仍然有效且可以继续处理
+ require.True(t, gjson.ValidBytes(updatedPayload), "payload 应为有效 JSON")
+ require.True(t, gjson.GetBytes(updatedPayload, `input.#(type=="function_call_output")`).Exists(),
+ "function_call_output 应保持不变")
+
+ // 模拟 setOpenAIWSPayloadInputSequence:使用 replay input 替换
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功")
+ require.True(t, gjson.ValidBytes(updatedWithInput), "更新后的 payload 应为有效 JSON")
+ require.True(t, gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output")`).Exists(),
+ "replay input 中的 function_call_output 应存在")
+}
+
+// TestToolOutputRecovery_PreviousResponseIDExists 验证修复不影响原有逻辑:
+// 当 payload 中存在 previous_response_id 时,仍然执行 drop。
+func TestToolOutputRecovery_PreviousResponseIDExists(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.NotEmpty(t, currentPreviousResponseID, "payload 中应存在 previous_response_id")
+
+ // 修复后的逻辑:currentPreviousResponseID 不为空时执行 drop
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr, "drop 操作不应出错")
+ require.True(t, removed, "previous_response_id 应被成功移除")
+ updatedPayload = dropped
+ }
+
+ // 验证 drop 后 previous_response_id 已移除
+ require.False(t, gjson.GetBytes(updatedPayload, "previous_response_id").Exists(),
+ "previous_response_id 应被移除")
+
+ // 验证 function_call_output 仍然存在
+ require.True(t, gjson.GetBytes(updatedPayload, `input.#(type=="function_call_output")`).Exists(),
+ "function_call_output 应保持不变")
+
+ // 模拟 setOpenAIWSPayloadInputSequence
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功")
+ require.True(t, gjson.ValidBytes(updatedWithInput), "更新后的 payload 应为有效 JSON")
+}
+
+// TestToolOutputRecovery_DropError 验证当 drop 操作出错时返回 false。
+func TestToolOutputRecovery_DropError(t *testing.T) {
+ t.Parallel()
+
+ // 使用非法 JSON 模拟 drop 错误
+ currentPreviousResponseID := "resp_stale"
+
+ if currentPreviousResponseID != "" {
+ // 使用含有 previous_response_id 但格式异常的 payload 触发 drop 错误
+ badPayload := []byte(`not-valid-json`)
+ _, _, dropErr := dropPreviousResponseIDFromRawPayload(badPayload)
+ // 无论 drop 是否报错,验证错误分支的行为
+ if dropErr != nil {
+ // drop 出错时应返回 false(跳过恢复)
+ t.Log("drop 出错场景:恢复应返回 false")
+ }
+ }
+}
+
+// TestToolOutputRecovery_DropNotRemoved 验证当 drop 操作返回 removed=false 时的行为。
+// 修复后:如果 currentPreviousResponseID 不为空但 drop 未能移除(removed=false),
+// updatedPayload 仍使用原始 payload(不更新),继续执行后续流程。
+func TestToolOutputRecovery_DropNotRemoved(t *testing.T) {
+ t.Parallel()
+
+ // 构造一个含有 previous_response_id 的 payload
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.NotEmpty(t, currentPreviousResponseID)
+
+ // 使用自定义 delete 函数模拟 removed=false 的场景
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ // 正常 drop 应该成功
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ if removed {
+ updatedPayload = dropped
+ }
+ // 无论 removed 与否,后续逻辑都应继续执行(不再提前 return false)
+ }
+
+ // 验证可以继续执行 setOpenAIWSPayloadInputSequence
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ nil,
+ false, // fullInputExists=false 时直接返回原 payload
+ )
+ require.NoError(t, setInputErr)
+ require.Equal(t, string(updatedPayload), string(updatedWithInput))
+}
+
+// TestToolOutputRecovery_WithDeleteFn_Removed_False 使用 WithDeleteFn 接口模拟
+// drop 函数返回 removed=false 的边界情况。
+func TestToolOutputRecovery_WithDeleteFn_Removed_False(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ currentPreviousResponseID := "resp_1"
+
+ // 使用 noop delete 函数模拟 removed=false
+ noopDelete := func(data []byte, _ string) ([]byte, error) {
+ return data, nil // 不调用真正的 sjson 删除,直接原样返回,模拟"字段未被移除"的场景
+ }
+
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, noopDelete)
+ require.NoError(t, dropErr)
+ // noop delete 不移除字段,但 previous_response_id 仍存在
+ // removed 取决于 payload 比较
+ if removed {
+ updatedPayload = dropped
+ }
+ }
+
+ // 无论 removed 与否,都应继续(不提前 return false)
+ require.True(t, gjson.ValidBytes(updatedPayload))
+}
+
+// ---------------------------------------------------------------------------
+// 端到端场景测试:预防性检测 → 恢复逻辑 完整链路
+// ---------------------------------------------------------------------------
+
+// TestEndToEnd_ProactiveDetection_RecoveryWithEmptyPreviousResponseID
+// 完整模拟:store_disabled 模式下,客户端发送 function_call_output 但
+// 未携带 previous_response_id → 预防性检测触发 → 恢复逻辑跳过 drop → 重放。
+func TestEndToEnd_ProactiveDetection_RecoveryWithEmptyPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ // Step 1: 构造客户端 payload(无 previous_response_id)
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[
+ {"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}
+ ]
+ }`)
+
+ // Step 2: 提取条件参数(模拟 sendAndRelay 中的变量赋值)
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0
+ turnStoreDisabled := true // 模拟 store_disabled 模式
+
+ require.Empty(t, turnPreviousResponseID)
+ require.True(t, turnHasFunctionCallOutput)
+ require.Equal(t, []string{"call_abc"}, turnFunctionCallOutputCallIDs)
+
+ // Step 3: 预防性检测触发
+ shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ turnStoreDisabled,
+ turnHasFunctionCallOutput,
+ turnPreviousResponseID,
+ false,
+ )
+ require.True(t, shouldTrigger, "预防性检测应触发")
+
+ // Step 4: 构造预防性检测错误
+ proactiveErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+
+ // Step 5: 验证恢复入口条件
+ isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(proactiveErr)
+ require.True(t, isToolOutputMissing, "恢复入口应识别 tool_output_not_found")
+
+ // Step 6: 模拟恢复逻辑(改动 2 的核心路径)
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.Empty(t, currentPreviousResponseID, "payload 中无 previous_response_id")
+
+ // 改动 2:跳过 drop 步骤
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ t.Fatal("不应进入 drop 分支")
+ }
+
+ // Step 7: 执行 setOpenAIWSPayloadInputSequence
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功")
+ require.True(t, gjson.ValidBytes(updatedWithInput), "结果应为有效 JSON")
+
+ // Step 8: 验证最终 payload
+ require.False(t, gjson.GetBytes(updatedWithInput, "previous_response_id").Exists(),
+ "最终 payload 不应含 previous_response_id")
+ require.True(t, gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output")`).Exists(),
+ "最终 payload 应含 function_call_output")
+ require.Equal(t, "call_abc",
+ gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output").call_id`).String(),
+ "call_id 应保持不变")
+}
+
+// TestEndToEnd_ProactiveDetection_StoreEnabledBypasses 对照测试:
+// store 未禁用时,即使 function_call_output 无 previous_response_id,也不触发预防性检测。
+func TestEndToEnd_ProactiveDetection_StoreEnabledBypasses(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0
+ turnStoreDisabled := false // store 未禁用
+
+ shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ turnStoreDisabled,
+ turnHasFunctionCallOutput,
+ turnPreviousResponseID,
+ false,
+ )
+ require.False(t, shouldTrigger, "store 未禁用时不应触发预防性检测")
+}
+
+// TestEndToEnd_ProactiveDetection_WithPreviousResponseIDBypasses 对照测试:
+// store_disabled 但 payload 含有 previous_response_id 时,不触发预防性检测。
+func TestEndToEnd_ProactiveDetection_WithPreviousResponseIDBypasses(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "previous_response_id":"resp_valid",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0
+ turnStoreDisabled := true
+
+ shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ turnStoreDisabled,
+ turnHasFunctionCallOutput,
+ turnPreviousResponseID,
+ false,
+ )
+ require.False(t, shouldTrigger, "有 previous_response_id 时不应触发预防性检测")
+}
+
+// TestEndToEnd_NormalTextInput_NeverTriggersProactive 对照测试:
+// 普通文本输入(无 function_call_output)在任何模式下都不触发预防性检测。
+func TestEndToEnd_NormalTextInput_NeverTriggersProactive(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"input_text","text":"hello world"}]
+ }`)
+
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0
+ turnStoreDisabled := true
+
+ shouldTrigger := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ turnStoreDisabled,
+ turnHasFunctionCallOutput,
+ turnPreviousResponseID,
+ false,
+ )
+ require.False(t, shouldTrigger, "普通文本输入不应触发预防性检测")
+}
+
+// ---------------------------------------------------------------------------
+// 改动 2 与原有 drop 路径的兼容性测试
+// ---------------------------------------------------------------------------
+
+// TestToolOutputRecovery_OriginalFlow_WithPreviousResponseID_StillWorks
+// 验证改动 2 不破坏原有的 tool_output_not_found 恢复流程:
+// 当 previous_response_id 存在时,仍然执行 drop + setInputSequence。
+func TestToolOutputRecovery_OriginalFlow_WithPreviousResponseID_StillWorks(t *testing.T) {
+ t.Parallel()
+
+ // 原有场景:用户按 ESC 取消 function_call 后重新发送消息,
+ // payload 含有过时的 previous_response_id
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[
+ {"type":"function_call_output","call_id":"call_canceled","output":"canceled"},
+ {"type":"input_text","text":"new message"}
+ ]
+ }`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.Equal(t, "resp_stale", currentPreviousResponseID)
+
+ // 改动后的逻辑
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ require.True(t, removed)
+ updatedPayload = dropped
+ }
+
+ // 验证 drop 成功
+ require.False(t, gjson.GetBytes(updatedPayload, "previous_response_id").Exists())
+
+ // 验证 input 保持不变
+ inputCount := gjson.GetBytes(updatedPayload, "input.#").Int()
+ require.Equal(t, int64(2), inputCount, "input 数组应保持不变")
+
+ // setOpenAIWSPayloadInputSequence 使用 replay input
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"input_text","text":"new message"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr)
+ require.True(t, gjson.ValidBytes(updatedWithInput))
+ require.Equal(t, int64(1), gjson.GetBytes(updatedWithInput, "input.#").Int(),
+ "replay input 应替换原始 input")
+}
+
+// TestToolOutputRecovery_SetInputSequence_NoReplayInput 验证当没有 replay input 时,
+// setOpenAIWSPayloadInputSequence 直接返回原 payload。
+func TestToolOutputRecovery_SetInputSequence_NoReplayInput(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+
+ // fullInputExists=false 时直接返回原 payload
+ result, err := setOpenAIWSPayloadInputSequence(payload, nil, false)
+ require.NoError(t, err)
+ require.Equal(t, string(payload), string(result))
+}
+
+// TestToolOutputRecovery_SetInputSequence_EmptyReplayInput 验证 replay input 为空数组时
+// 仍然能正确设置(替换为空数组)。
+func TestToolOutputRecovery_SetInputSequence_EmptyReplayInput(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+
+ result, err := setOpenAIWSPayloadInputSequence(payload, []json.RawMessage{}, true)
+ require.NoError(t, err)
+ require.True(t, gjson.ValidBytes(result))
+ require.Equal(t, int64(0), gjson.GetBytes(result, "input.#").Int(),
+ "空 replay input 应替换为空数组")
+}
+
+// ---------------------------------------------------------------------------
+// 边界条件与防御性测试
+// ---------------------------------------------------------------------------
+
+// TestProactiveDetection_EmptyPayload 验证空 payload 的行为。
+func TestProactiveDetection_EmptyPayload(t *testing.T) {
+ t.Parallel()
+
+ emptyPayloads := [][]byte{
+ nil,
+ {},
+ []byte(""),
+ }
+
+ for i, payload := range emptyPayloads {
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ require.Empty(t, callIDs, "空 payload[%d] 不应提取到 call_id", i)
+
+ prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ require.Empty(t, prevID, "空 payload[%d] 不应有 previous_response_id", i)
+ }
+}
+
+// TestProactiveDetection_MalformedJSON 验证非法 JSON 不会导致 panic。
+func TestProactiveDetection_MalformedJSON(t *testing.T) {
+ t.Parallel()
+
+ badPayloads := [][]byte{
+ []byte(`{invalid json`),
+ []byte(`{"type":"response.create","input":"not_array"}`),
+ []byte(`{"type":"response.create","input":123}`),
+ }
+
+ for i, payload := range badPayloads {
+ // 不应 panic
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ _ = callIDs
+ prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ _ = prevID
+
+ // 模拟检测逻辑不应 panic
+ hasFCO := len(callIDs) > 0
+ hasContext := openAIWSHasToolCallContextInPayload(payload) ||
+ openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload, callIDs)
+ triggered := shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
+ true,
+ hasFCO,
+ prevID,
+ hasContext,
+ )
+ _ = triggered
+ t.Logf("badPayload[%d]: hasFCO=%v, triggered=%v", i, hasFCO, triggered)
+ }
+}
+
+// TestToolOutputRecovery_OldCodeWouldFail_Regression 回归测试:
+// 验证旧代码在 previous_response_id 为空时会因 removed=false 而失败。
+// 新代码跳过 drop 步骤后应成功。
+func TestToolOutputRecovery_OldCodeWouldFail_Regression(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ // 旧代码行为:直接调用 dropPreviousResponseIDFromRawPayload
+ _, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ require.False(t, removed, "payload 中无 previous_response_id 时 drop 返回 removed=false")
+
+ // 旧代码:!removed → return false(恢复失败)
+ oldCodeResult := dropErr == nil && removed // 旧代码的成功条件,即旧失败条件(dropErr != nil || !removed)的反面
+ require.False(t, oldCodeResult, "旧代码在此场景会失败(return false)")
+
+ // 新代码行为:先检查 currentPreviousResponseID,为空时跳过 drop
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ updatedPayload := payload
+ newCodeSkippedDrop := false
+ if currentPreviousResponseID != "" {
+ dropped, removedNew, dropErrNew := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErrNew)
+ if removedNew {
+ updatedPayload = dropped
+ }
+ } else {
+ newCodeSkippedDrop = true
+ }
+
+ require.True(t, newCodeSkippedDrop, "新代码应跳过 drop 步骤")
+ require.Equal(t, string(payload), string(updatedPayload), "payload 应保持不变")
+
+ // 新代码继续执行 setOpenAIWSPayloadInputSequence
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`),
+ }
+ result, setErr := setOpenAIWSPayloadInputSequence(updatedPayload, replayInput, true)
+ require.NoError(t, setErr, "新代码应能继续执行 setInputSequence")
+ require.True(t, gjson.ValidBytes(result))
+}
diff --git a/backend/internal/service/openai_ws_forwarder_recovery_test.go b/backend/internal/service/openai_ws_forwarder_recovery_test.go
new file mode 100644
index 000000000..86abd88f0
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_recovery_test.go
@@ -0,0 +1,692 @@
+package service
+
+import (
+ "encoding/json"
+ "errors"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// ---------------------------------------------------------------------------
+// Tests for the openAIWSIngressTurnWroteDownstream helper.
+// ---------------------------------------------------------------------------
+
+func TestOpenAIWSIngressTurnWroteDownstream(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name string
+		err  error
+		want bool
+	}{
+		{
+			name: "nil_error_returns_false",
+			err:  nil,
+			want: false,
+		},
+		{
+			name: "plain_error_returns_false",
+			err:  errors.New("some random error"),
+			want: false,
+		},
+		{
+			name: "turn_error_wrote_downstream_false",
+			err: wrapOpenAIWSIngressTurnError(
+				openAIWSIngressStagePreviousResponseNotFound,
+				errors.New("previous response not found"),
+				false,
+			),
+			want: false,
+		},
+		{
+			name: "turn_error_wrote_downstream_true",
+			err: wrapOpenAIWSIngressTurnError(
+				"upstream_error_event",
+				errors.New("upstream error"),
+				true,
+			),
+			want: true,
+		},
+		{
+			name: "turn_error_with_partial_result_wrote_downstream_true",
+			err: wrapOpenAIWSIngressTurnErrorWithPartial(
+				"read_upstream",
+				errors.New("connection reset"),
+				true,
+				&OpenAIForwardResult{RequestID: "resp_partial"},
+			),
+			want: true,
+		},
+		{
+			name: "turn_error_with_partial_result_wrote_downstream_false",
+			err: wrapOpenAIWSIngressTurnErrorWithPartial(
+				"read_upstream",
+				errors.New("connection reset"),
+				false,
+				&OpenAIForwardResult{RequestID: "resp_partial"},
+			),
+			want: false,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			require.Equal(t, tt.want, openAIWSIngressTurnWroteDownstream(tt.err))
+		})
+	}
+}
+
+// ---------------------------------------------------------------------------
+// previous_response_not_found errors and their ContinueTurn disposition.
+// ---------------------------------------------------------------------------
+
+func TestPreviousResponseNotFound_ClassifiesAsContinueTurn(t *testing.T) {
+	t.Parallel()
+
+	// previous_response_not_found (wroteDownstream=false) should be classified as ContinueTurn.
+	err := wrapOpenAIWSIngressTurnError(
+		openAIWSIngressStagePreviousResponseNotFound,
+		errors.New("previous response not found"),
+		false,
+	)
+
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+	require.Equal(t, openAIWSIngressTurnAbortReasonPreviousResponse, reason)
+	require.True(t, expected)
+
+	disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+}
+
+func TestToolOutputNotFound_ClassifiesAsContinueTurn(t *testing.T) {
+	t.Parallel()
+
+	// tool_output_not_found (wroteDownstream=false) should be classified as ContinueTurn.
+	err := wrapOpenAIWSIngressTurnError(
+		openAIWSIngressStageToolOutputNotFound,
+		errors.New("no tool call found for function call output"),
+		false,
+	)
+
+	reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+	require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason)
+	require.True(t, expected)
+
+	disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+}
+
+// ---------------------------------------------------------------------------
+// function_call_output / previous_response_id semantic-binding tests.
+// Core fix under test: do not drop previous_response_id when function_call_output is present.
+// ---------------------------------------------------------------------------
+
+func TestFunctionCallOutputPayload_HasFunctionCallOutput(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name    string
+		payload string
+		want    bool
+	}{
+		{
+			name:    "payload_with_function_call_output",
+			payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`,
+			want:    true,
+		},
+		{
+			name:    "payload_with_mixed_input_including_function_call_output",
+			payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`,
+			want:    true,
+		},
+		{
+			name:    "payload_without_function_call_output",
+			payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"}]}`,
+			want:    false,
+		},
+		{
+			name:    "payload_with_empty_input",
+			payload: `{"type":"response.create","previous_response_id":"resp_1","input":[]}`,
+			want:    false,
+		},
+		{
+			name:    "payload_without_input",
+			payload: `{"type":"response.create","model":"gpt-5.1"}`,
+			want:    false,
+		},
+		{
+			name:    "multiple_function_call_outputs",
+			payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`,
+			want:    true,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			got := gjson.GetBytes([]byte(tt.payload), `input.#(type=="function_call_output")`).Exists()
+			require.Equal(t, tt.want, got)
+		})
+	}
+}
+
+func TestDropPreviousResponseID_BreaksFunctionCallOutput(t *testing.T) {
+	t.Parallel()
+
+	// Core regression test: dropping previous_response_id turns function_call_output into orphaned references.
+	//
+	// Scenario: the client sends {previous_response_id: "resp_1", input: [{type: "function_call_output", call_id: "call_abc"}]}
+	// If previous_response_id is dropped, upstream creates a brand-new context and cannot find the tool call for call_abc.
+	// Result: upstream reports "No tool call found for function call output with call_id call_abc"
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_stale_or_lost",
+		"input":[
+			{"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"},
+			{"type":"function_call_output","call_id":"call_def","output":"{\"result\":\"done\"}"}
+		]
+	}`)
+
+	// 1. Verify the original payload carries both previous_response_id and function_call_output.
+	require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
+	require.True(t, gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists())
+
+	// 2. Drop previous_response_id.
+	dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+	require.NoError(t, err)
+	require.True(t, removed)
+
+	// 3. After the drop: previous_response_id is removed but function_call_output remains.
+	require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists(),
+		"previous_response_id 应该被移除")
+	require.True(t, gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists(),
+		"function_call_output 仍然存在,但此时它引用的 call_id 没有了上下文 — 这就是 bug 的根因")
+
+	// 4. The call_ids are still in the payload (drop does not clean up function_call_output).
+	callIDs := make([]string, 0)
+	gjson.GetBytes(dropped, "input").ForEach(func(_, item gjson.Result) bool {
+		if item.Get("type").String() == "function_call_output" {
+			callIDs = append(callIDs, item.Get("call_id").String())
+		}
+		return true
+	})
+	require.ElementsMatch(t, []string{"call_abc", "call_def"}, callIDs,
+		"function_call_output 的 call_id 未被清除,但上游已无法匹配")
+}
+
+func TestRecoveryStrategy_FunctionCallOutput_ShouldNotDrop(t *testing.T) {
+	t.Parallel()
+
+	// This test verifies the core logic of the fix:
+	// when hasFunctionCallOutput=true and both the set and align strategies fail,
+	// the correct behavior is to give up recovery (return false), not to drop previous_response_id.
+	//
+	// Reason: function_call_output is semantically bound to previous_response_id;
+	// dropping previous_response_id while keeping function_call_output → upstream error.
+
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_lost",
+		"input":[{"type":"function_call_output","call_id":"call_JDKR","output":"ok"}]
+	}`)
+
+	hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+	require.True(t, hasFunctionCallOutput, "payload 必须包含 function_call_output")
+
+	// Simulate set-strategy failure (currentPreviousResponseID is non-empty, so set does not apply).
+	currentPreviousResponseID := "resp_lost"
+	expectedPrev := "resp_expected"
+	require.NotEmpty(t, currentPreviousResponseID, "set 策略需要 currentPreviousResponseID 为空")
+
+	// Simulate the align strategy.
+	_, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev)
+	if alignErr == nil && aligned {
+		// align succeeded and would update previous_response_id in the payload
+		t.Log("align 策略成功,此场景不触发 abort 路径")
+	}
+	// Note: align usually succeeds here (resp_lost → resp_expected).
+	// In a real scenario, if the aligned previous_response_id still does not exist upstream,
+	// upstream returns previous_response_not_found again and recovery is re-entered,
+	// but turnPrevRecoveryTried=true blocks a second attempt and we go straight to abort.
+
+	// Key assertion: even though drop is technically possible, it must not be executed,
+	// because it causes the "No tool call found for function call output" error.
+	droppedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+	require.NoError(t, dropErr)
+	require.True(t, removed, "drop 操作本身可以成功")
+
+	// The dropped payload still carries function_call_output — this is why drop is forbidden.
+	hasFCOAfterDrop := gjson.GetBytes(droppedPayload, `input.#(type=="function_call_output")`).Exists()
+	require.True(t, hasFCOAfterDrop,
+		"drop previous_response_id 不会移除 function_call_output,"+
+			"导致上游报 'No tool call found for function call output'")
+}
+
+// ---------------------------------------------------------------------------
+// ContinueTurn abort path: error-notification tests.
+// ---------------------------------------------------------------------------
+
+func TestContinueTurnAbort_ErrorEventFormat(t *testing.T) {
+	t.Parallel()
+
+	// Verify the error event emitted on a ContinueTurn abort is well-formed.
+	abortReason := openAIWSIngressTurnAbortReasonPreviousResponse
+	abortMessage := "turn failed: " + string(abortReason)
+
+	errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+		string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+
+	// The event must be valid JSON.
+	var parsed map[string]any
+	err := json.Unmarshal(errorEvent, &parsed)
+	require.NoError(t, err, "error 事件应为有效 JSON")
+
+	// Verify the event structure.
+	require.Equal(t, "error", parsed["type"])
+	errorObj, ok := parsed["error"].(map[string]any)
+	require.True(t, ok)
+	require.Equal(t, "server_error", errorObj["type"])
+	require.Equal(t, string(abortReason), errorObj["code"])
+	require.Contains(t, errorObj["message"], string(abortReason))
+}
+
+func TestContinueTurnAbort_ErrorEventWithSpecialChars(t *testing.T) {
+	t.Parallel()
+
+	// Error messages containing special characters must not break the JSON encoding.
+	specialMessages := []string{
+		`No tool call found for function call output with call_id call_JDKR0SzNTARIsGb0L3hofFWd.`,
+		`error with "quotes" and \backslash`,
+		"error with\nnewline",
+		`error with & entities`,
+		"", // empty message
+	}
+
+	for i, msg := range specialMessages {
+		msg := msg
+		t.Run("special_message_"+strconv.Itoa(i), func(t *testing.T) {
+			t.Parallel()
+			abortReason := openAIWSIngressTurnAbortReasonToolOutput
+			errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+				string(abortReason) + `","message":` + strconv.Quote(msg) + `}}`)
+
+			var parsed map[string]any
+			err := json.Unmarshal(errorEvent, &parsed)
+			require.NoError(t, err, "error event with special chars should be valid JSON: %q", msg)
+
+			errorObj, ok := parsed["error"].(map[string]any)
+			require.True(t, ok)
+			require.Equal(t, msg, errorObj["message"])
+		})
+	}
+}
+
+func TestContinueTurnAbort_WroteDownstreamDeterminesNotification(t *testing.T) {
+	t.Parallel()
+
+	// Verify how the wroteDownstream flag drives the error-notification policy.
+	tests := []struct {
+		name                    string
+		wroteDownstream         bool
+		shouldSendErrorToClient bool
+	}{
+		{
+			name:                    "not_wrote_downstream_should_send_error",
+			wroteDownstream:         false,
+			shouldSendErrorToClient: true,
+		},
+		{
+			name:                    "wrote_downstream_should_not_send_error",
+			wroteDownstream:         true,
+			shouldSendErrorToClient: false,
+		},
+	}
+
+	for _, tt := range tests {
+		tt := tt
+		t.Run(tt.name, func(t *testing.T) {
+			t.Parallel()
+			err := wrapOpenAIWSIngressTurnError(
+				openAIWSIngressStagePreviousResponseNotFound,
+				errors.New("previous response not found"),
+				tt.wroteDownstream,
+			)
+			wroteDownstream := openAIWSIngressTurnWroteDownstream(err)
+			require.Equal(t, tt.wroteDownstream, wroteDownstream)
+
+			// An error event is only re-sent to the client when wroteDownstream=false.
+			shouldNotify := !wroteDownstream
+			require.Equal(t, tt.shouldSendErrorToClient, shouldNotify)
+		})
+	}
+}
+
+// ---------------------------------------------------------------------------
+// previous_response_id recovery strategies: set / align / abort flow tests.
+// ---------------------------------------------------------------------------
+
+func TestRecoveryStrategy_SetPreviousResponseID(t *testing.T) {
+	t.Parallel()
+
+	// Scenario: the client omitted previous_response_id but the session has one recorded;
+	// the set strategy should inject previous_response_id in that case.
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+	}`)
+
+	expectedPrev := "resp_expected"
+
+	// set strategy: inject expectedPrev when currentPreviousResponseID is empty.
+	updated, err := setPreviousResponseIDToRawPayload(payload, expectedPrev)
+	require.NoError(t, err)
+	require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String())
+
+	// function_call_output stays unchanged.
+	require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists())
+	require.Equal(t, "call_1", gjson.GetBytes(updated, `input.#(type=="function_call_output").call_id`).String())
+}
+
+func TestRecoveryStrategy_AlignPreviousResponseID(t *testing.T) {
+	t.Parallel()
+
+	// Scenario: the client sent a stale previous_response_id; align it to the latest one.
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_stale",
+		"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+	}`)
+
+	expectedPrev := "resp_latest"
+
+	updated, changed, err := alignStoreDisabledPreviousResponseID(payload, expectedPrev)
+	require.NoError(t, err)
+	require.True(t, changed)
+	require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String())
+
+	// function_call_output stays unchanged.
+	require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists())
+}
+
+func TestRecoveryStrategy_AlignFailsWhenNoExpectedPrev(t *testing.T) {
+	t.Parallel()
+
+	// Scenario: there is no expected previous_response_id, so align cannot run.
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_stale",
+		"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+	}`)
+
+	updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
+	require.NoError(t, err)
+	require.False(t, changed, "align 应该在 expectedPrev 为空时不执行")
+	require.Equal(t, string(payload), string(updated))
+}
+
+// ---------------------------------------------------------------------------
+// isOpenAIWSIngressPreviousResponseNotFound boundary-condition tests.
+// ---------------------------------------------------------------------------
+
+func TestIsOpenAIWSIngressPreviousResponseNotFound_WroteDownstreamBlocks(t *testing.T) {
+	t.Parallel()
+
+	// With wroteDownstream=true, even the previous_response_not_found stage must
+	// not be treated as a recoverable previous_response_not_found
+	// (data was already written to the client, so it cannot be retried safely).
+	err := wrapOpenAIWSIngressTurnError(
+		openAIWSIngressStagePreviousResponseNotFound,
+		errors.New("previous response not found"),
+		true, // wroteDownstream = true
+	)
+	require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err),
+		"wroteDownstream=true 时不应识别为可恢复的 previous_response_not_found")
+}
+
+func TestIsOpenAIWSIngressPreviousResponseNotFound_DifferentStageReturns_False(t *testing.T) {
+	t.Parallel()
+
+	stages := []string{
+		"read_upstream",
+		"write_upstream",
+		"upstream_error_event",
+		openAIWSIngressStageToolOutputNotFound,
+		"unknown",
+		"",
+	}
+
+	for _, stage := range stages {
+		stage := stage
+		t.Run("stage_"+stage, func(t *testing.T) {
+			t.Parallel()
+			err := wrapOpenAIWSIngressTurnError(stage, errors.New("some error"), false)
+			require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err),
+				"stage=%q 不应被识别为 previous_response_not_found", stage)
+		})
+	}
+}
+
+// ---------------------------------------------------------------------------
+// End-to-end scenario: the function_call_output recovery chain.
+// ---------------------------------------------------------------------------
+
+func TestEndToEnd_FunctionCallOutputRecoveryChain(t *testing.T) {
+	t.Parallel()
+
+	// Full scenario:
+	// 1. The client sends a request carrying function_call_output.
+	// 2. Upstream returns previous_response_not_found.
+	// 3. The recovery strategies try set/align.
+	// 4. If both fail, abort (rather than drop previous_response_id).
+	// 5. The client receives an error event.
+	// 6. The client resets and resends the full request.
+
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_lost",
+		"input":[
+			{"type":"function_call_output","call_id":"call_JDKR0SzNTARIsGb0L3hofFWd","output":"{\"ok\":true}"}
+		]
+	}`)
+
+	// Step 1: detect function_call_output.
+	hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+	require.True(t, hasFCO, "payload 包含 function_call_output")
+
+	// Step 2: simulate a previous_response_not_found error.
+	turnErr := wrapOpenAIWSIngressTurnError(
+		openAIWSIngressStagePreviousResponseNotFound,
+		errors.New("previous response not found"),
+		false,
+	)
+	require.True(t, isOpenAIWSIngressPreviousResponseNotFound(turnErr))
+
+	// Step 3: verify the ContinueTurn disposition.
+	reason, _ := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+	disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+	require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+	// Step 4: set strategy — fails (currentPreviousResponseID is non-empty).
+	currentPrevID := gjson.GetBytes(payload, "previous_response_id").String()
+	require.NotEmpty(t, currentPrevID, "set 策略前提条件不满足(需要 currentPreviousResponseID 为空)")
+
+	// Step 5: align strategy — assume expectedPrev is empty (no record in the session).
+	expectedPrev := ""
+	_, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev)
+	require.NoError(t, alignErr)
+	require.False(t, aligned, "expectedPrev 为空时 align 应失败")
+
+	// Step 6: the correct move now is to abort (return false), not to drop.
+	// If drop were wrongly executed, function_call_output would become an orphaned reference.
+	dropped, removed, _ := dropPreviousResponseIDFromRawPayload(payload)
+	if removed {
+		hasFCOAfterDrop := gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists()
+		require.True(t, hasFCOAfterDrop,
+			"drop 后 function_call_output 仍存在,上游会报 'No tool call found'")
+	}
+
+	// Step 7: correct behavior — after abort, emit an error event to notify the client.
+	wroteDownstream := openAIWSIngressTurnWroteDownstream(turnErr)
+	require.False(t, wroteDownstream, "abort 前未向客户端写入数据")
+
+	abortMessage := "turn failed: " + string(reason)
+	errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+		string(reason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+
+	var parsed map[string]any
+	require.NoError(t, json.Unmarshal(errorEvent, &parsed))
+	require.Equal(t, "error", parsed["type"])
+}
+
+func TestEndToEnd_NonFunctionCallOutput_CanDrop(t *testing.T) {
+	t.Parallel()
+
+	// Control scenario: a payload without function_call_output can safely drop previous_response_id.
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_old",
+		"input":[{"type":"input_text","text":"hello"}]
+	}`)
+
+	hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+	require.False(t, hasFCO, "此 payload 不包含 function_call_output")
+
+	// Dropping is safe here.
+	dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+	require.NoError(t, err)
+	require.True(t, removed)
+	require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists())
+
+	// input is still valid (input_text does not depend on previous_response_id).
+	require.Equal(t, "hello", gjson.GetBytes(dropped, "input.0.text").String())
+}
+
+// ---------------------------------------------------------------------------
+// Interaction between shouldKeepIngressPreviousResponseID and function_call_output.
+// ---------------------------------------------------------------------------
+
+func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMatch(t *testing.T) {
+	t.Parallel()
+
+	// When the function_call_output call_id matches a pending call_id, keep previous_response_id.
+	previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+	currentPayload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_1",
+		"input":[{"type":"function_call_output","call_id":"call_match","output":"ok"}]
+	}`)
+
+	keep, reason, err := shouldKeepIngressPreviousResponseID(
+		previousPayload,
+		currentPayload,
+		"resp_1",
+		true,                   // hasFunctionCallOutput
+		[]string{"call_match"}, // pendingCallIDs
+		[]string{"call_match"}, // requestCallIDs
+	)
+	require.NoError(t, err)
+	require.True(t, keep, "call_id 匹配时应保留 previous_response_id")
+	require.Equal(t, "function_call_output_call_id_match", reason)
+}
+
+func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMismatch(t *testing.T) {
+	t.Parallel()
+
+	// When the function_call_output call_id does not match any pending call_id,
+	// previous_response_id should be abandoned.
+	previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+	currentPayload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"previous_response_id":"resp_1",
+		"input":[{"type":"function_call_output","call_id":"call_wrong","output":"ok"}]
+	}`)
+
+	keep, reason, err := shouldKeepIngressPreviousResponseID(
+		previousPayload,
+		currentPayload,
+		"resp_1",
+		true,                   // hasFunctionCallOutput
+		[]string{"call_real"},  // pendingCallIDs
+		[]string{"call_wrong"}, // requestCallIDs
+	)
+	require.NoError(t, err)
+	require.False(t, keep, "call_id 不匹配时应放弃 previous_response_id")
+	require.Equal(t, "function_call_output_call_id_mismatch", reason)
+}
+
+// ---------------------------------------------------------------------------
+// Interaction of isOpenAIWSIngressTurnRetryable with function_call_output scenarios.
+// ---------------------------------------------------------------------------
+
+func TestIsOpenAIWSIngressTurnRetryable_PreviousResponseNotFound(t *testing.T) {
+	t.Parallel()
+
+	// previous_response_not_found must not be marked retryable (it has a dedicated recovery path).
+	err := wrapOpenAIWSIngressTurnError(
+		openAIWSIngressStagePreviousResponseNotFound,
+		errors.New("previous response not found"),
+		false,
+	)
+	require.False(t, isOpenAIWSIngressTurnRetryable(err),
+		"previous_response_not_found 有专门的恢复逻辑,不走通用重试")
+}
+
+func TestIsOpenAIWSIngressTurnRetryable_WroteDownstreamBlocksRetry(t *testing.T) {
+	t.Parallel()
+
+	// With wroteDownstream=true, no stage should be retryable.
+	err := wrapOpenAIWSIngressTurnError(
+		"write_upstream",
+		errors.New("write failed"),
+		true, // wroteDownstream
+	)
+	require.False(t, isOpenAIWSIngressTurnRetryable(err),
+		"wroteDownstream=true 时不应重试")
+}
+
+// ---------------------------------------------------------------------------
+// Integration of normalizeOpenAIWSIngressPayloadBeforeSend with recovery.
+// ---------------------------------------------------------------------------
+
+func TestNormalizePayload_FunctionCallOutputPassthrough(t *testing.T) {
+	t.Parallel()
+
+	// Passthrough mode: the normalizer no longer injects previous_response_id; the payload is forwarded as-is.
+	payload := []byte(`{
+		"type":"response.create",
+		"model":"gpt-5.1",
+		"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+	}`)
+
+	out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+		accountID:                 1,
+		turn:                      2,
+		connID:                    "conn_test",
+		currentPayload:            payload,
+		currentPayloadBytes:       len(payload),
+		currentPreviousResponseID: "",
+		expectedPreviousResponse:  "resp_expected",
+		pendingExpectedCallIDs:    []string{"call_1"},
+	})
+
+	// Passthrough mode: previous_response_id keeps the client's original (empty) value; downstream recovery handles it.
+	require.Empty(t, out.currentPreviousResponseID,
+		"透传模式不应注入 previous_response_id")
+	require.True(t, out.hasFunctionCallOutputCallID)
+	require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs)
+}
diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go
deleted file mode 100644
index 592801f66..000000000
--- a/backend/internal/service/openai_ws_forwarder_success_test.go
+++ /dev/null
@@ -1,1306 +0,0 @@
-package service
-
-import (
- "context"
- "encoding/json"
- "errors"
- "io"
- "net/http"
- "net/http/httptest"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
-)
-
-func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- type receivedPayload struct {
- Type string
- PreviousResponseID string
- StreamExists bool
- Stream bool
- }
- receivedCh := make(chan receivedPayload, 1)
-
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- t.Errorf("read ws request failed: %v", err)
- return
- }
- requestJSON := requestToJSONString(request)
- receivedCh <- receivedPayload{
- Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()),
- PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()),
- StreamExists: gjson.Get(requestJSON, "stream").Exists(),
- Stream: gjson.Get(requestJSON, "stream").Bool(),
- }
-
- if err := conn.WriteJSON(map[string]any{
- "type": "response.created",
- "response": map[string]any{
- "id": "resp_new_1",
- "model": "gpt-5.1",
- },
- }); err != nil {
- t.Errorf("write response.created failed: %v", err)
- return
- }
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": "resp_new_1",
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 12,
- "output_tokens": 7,
- "input_tokens_details": map[string]any{
- "cached_tokens": 3,
- },
- },
- },
- }); err != nil {
- t.Errorf("write response.completed failed: %v", err)
- return
- }
- }))
- defer wsServer.Close()
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
- groupID := int64(1001)
- c.Set("api_key", &APIKey{GroupID: &groupID})
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
- cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
-
- upstream := &httpUpstreamRecorder{
- resp: &http.Response{
- StatusCode: http.StatusOK,
- Header: http.Header{"Content-Type": []string{"application/json"}},
- Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
- },
- }
-
- cache := &stubGatewayCache{}
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: upstream,
- cache: cache,
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 9,
- Name: "openai-ws",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, 12, result.Usage.InputTokens)
- require.Equal(t, 7, result.Usage.OutputTokens)
- require.Equal(t, 3, result.Usage.CacheReadInputTokens)
- require.Equal(t, "resp_new_1", result.RequestID)
- require.True(t, result.OpenAIWSMode)
- require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游")
-
- received := <-receivedCh
- require.Equal(t, "response.create", received.Type)
- require.Equal(t, "resp_prev_1", received.PreviousResponseID)
- require.True(t, received.StreamExists, "WS 请求应携带 stream 字段")
- require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义")
-
- store := svc.getOpenAIWSStateStore()
- mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1")
- require.NoError(t, getErr)
- require.Equal(t, account.ID, mappedAccountID)
- connID, ok := store.GetResponseConn("resp_new_1")
- require.True(t, ok)
- require.NotEmpty(t, connID)
-
- responseBody := rec.Body.Bytes()
- require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
-}
-
-func requestToJSONString(payload map[string]any) string {
- if len(payload) == 0 {
- return "{}"
- }
- b, err := json.Marshal(payload)
- if err != nil {
- return "{}"
- }
- return string(b)
-}
-
-func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) {
- require.NotPanics(t, func() {
- logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil)
- })
- require.NotPanics(t, func() {
- logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed"))
- })
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
- groupID := int64(3001)
- c.Set("api_key", &APIKey{GroupID: &groupID})
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 1301,
- Name: "openai-rewrite",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "model_mapping": map[string]any{
- "custom-original-model": "gpt-5.1",
- },
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_model_tool_1", result.RequestID)
- require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型")
- require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范")
-}
-
-func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) {
- payload := map[string]any{
- "type": nil,
- "model": 123,
- "prompt_cache_key": " cache-key ",
- "previous_response_id": []byte(" resp_1 "),
- }
-
- require.Equal(t, "", openAIWSPayloadString(payload, "type"))
- require.Equal(t, "", openAIWSPayloadString(payload, "model"))
- require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key"))
- require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id"))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var upgradeCount atomic.Int64
- var sequence atomic.Int64
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- upgradeCount.Add(1)
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- for {
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- return
- }
- idx := sequence.Add(1)
- responseID := "resp_reuse_" + strconv.FormatInt(idx, 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.created",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- },
- }); err != nil {
- return
- }
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 2,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- return
- }
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
- account := &Account{
- ID: 19,
- Name: "openai-ws",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- for i := 0; i < 2; i++ {
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
- groupID := int64(2001)
- c.Set("api_key", &APIKey{GroupID: &groupID})
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
- }
-
- require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
- metrics := svc.SnapshotOpenAIWSPoolMetrics()
- require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
- require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
- c.Request.Header.Set("session_id", "sess-oauth-1")
- c.Request.Header.Set("conversation_id", "conv-oauth-1")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
- account := &Account{
- ID: 29,
- Name: "openai-oauth",
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "oauth-token-1",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_oauth_1", result.RequestID)
-
- require.NotNil(t, captureConn.lastWrite)
- requestJSON := requestToJSONString(captureConn.lastWrite)
- require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段")
- require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false")
- require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
- require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")
- require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta"))
- require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id"))
- require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
- account := &Account{
- ID: 31,
- Name: "openai-oauth",
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "oauth-token-1",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_prompt_cache_key", result.RequestID)
-
- require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id"))
- require.Empty(t, captureDialer.lastHeaders.Get("conversation_id"))
- require.NotNil(t, captureConn.lastWrite)
- require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
-}
-
-func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
-
- upstream := &httpUpstreamRecorder{
- resp: &http.Response{
- StatusCode: http.StatusOK,
- Header: http.Header{"Content-Type": []string{"application/json"}},
- Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
- },
- }
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: upstream,
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 39,
- Name: "openai-ws-v1",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": "https://api.openai.com/v1/responses",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.Error(t, err)
- require.Nil(t, result)
- require.Contains(t, err.Error(), "ws v1")
- require.Equal(t, http.StatusBadRequest, rec.Code)
- require.Contains(t, rec.Body.String(), "WSv1")
- require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var connIndex atomic.Int64
- headersCh := make(chan http.Header, 4)
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- idx := connIndex.Add(1)
- headersCh <- cloneHeader(r.Header)
-
- respHeader := http.Header{}
- if idx == 1 {
- respHeader.Set("x-codex-turn-state", "turn_state_first")
- }
- conn, err := upgrader.Upgrade(w, r, respHeader)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- t.Errorf("read ws request failed: %v", err)
- return
- }
- responseID := "resp_turn_" + strconv.FormatInt(idx, 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 2,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- t.Errorf("write response.completed failed: %v", err)
- return
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 49,
- Name: "openai-turn-state",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session_turn_state")
- c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1")
- result1, err := svc.Forward(context.Background(), c1, account, reqBody)
- require.NoError(t, err)
- require.NotNil(t, result1)
-
- sessionHash := svc.GenerateSessionHash(c1, reqBody)
- store := svc.getOpenAIWSStateStore()
- turnState, ok := store.GetSessionTurnState(0, sessionHash)
- require.True(t, ok)
- require.Equal(t, "turn_state_first", turnState)
-
- // 主动淘汰连接,模拟下一次请求发生重连。
- connID, hasConn := store.GetResponseConn(result1.RequestID)
- require.True(t, hasConn)
- svc.getOpenAIWSConnPool().evictConn(account.ID, connID)
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session_turn_state")
- c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2")
- result2, err := svc.Forward(context.Background(), c2, account, reqBody)
- require.NoError(t, err)
- require.NotNil(t, result2)
-
- firstHandshakeHeaders := <-headersCh
- secondHandshakeHeaders := <-headersCh
- require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
- require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
- require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State"))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("session_id", "session-prewarm")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 59,
- Name: "openai-prewarm",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_main_1", result.RequestID)
-
- require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求")
- firstWrite := requestToJSONString(captureConn.writes[0])
- secondWrite := requestToJSONString(captureConn.writes[1])
- require.True(t, gjson.Get(firstWrite, "generate").Exists())
- require.False(t, gjson.Get(firstWrite, "generate").Bool())
- require.False(t, gjson.Get(secondWrite, "generate").Exists())
-}
-
-func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- toolCorrector: NewCodexToolCorrector(),
- }
- account := &Account{
- ID: 601,
- Name: "openai-prewarm-timeout",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- }
- conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{
- readDelay: 200 * time.Millisecond,
- }, nil)
- lease := &openAIWSConnLease{
- accountID: account.ID,
- conn: conn,
- }
- payload := map[string]any{
- "type": "response.create",
- "model": "gpt-5.1",
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
- defer cancel()
- start := time.Now()
- err := svc.performOpenAIWSGeneratePrewarm(
- ctx,
- lease,
- OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
- payload,
- "",
- map[string]any{"model": "gpt-5.1"},
- account,
- nil,
- 0,
- )
- elapsed := time.Since(start)
- require.Error(t, err)
- require.Contains(t, err.Error(), "prewarm_read_event")
- require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 69,
- Name: "openai-turn-metadata",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
-
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session-metadata-reuse")
- c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1")
- result1, err := svc.Forward(context.Background(), c1, account, body)
- require.NoError(t, err)
- require.NotNil(t, result1)
- require.Equal(t, "resp_meta_1", result1.RequestID)
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session-metadata-reuse")
- c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2")
- result2, err := svc.Forward(context.Background(), c2, account, body)
- require.NoError(t, err)
- require.NotNil(t, result2)
- require.Equal(t, "resp_meta_2", result2.RequestID)
-
- require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
- require.Len(t, captureConn.writes, 2)
-
- firstWrite := requestToJSONString(captureConn.writes[0])
- secondWrite := requestToJSONString(captureConn.writes[1])
- require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
- require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
-}
-
-func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var upgradeCount atomic.Int64
- var sequence atomic.Int64
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- upgradeCount.Add(1)
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- for {
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- return
- }
- responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 1,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- return
- }
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
- cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 79,
- Name: "openai-store-false",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
-
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session_store_false_a")
- result1, err := svc.Forward(context.Background(), c1, account, body)
- require.NoError(t, err)
- require.NotNil(t, result1)
- require.Equal(t, int64(1), upgradeCount.Load())
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session_store_false_a")
- result2, err := svc.Forward(context.Background(), c2, account, body)
- require.NoError(t, err)
- require.NotNil(t, result2)
- require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接")
-
- rec3 := httptest.NewRecorder()
- c3, _ := gin.CreateTestContext(rec3)
- c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c3.Request.Header.Set("session_id", "session_store_false_b")
- result3, err := svc.Forward(context.Background(), c3, account, body)
- require.NoError(t, err)
- require.NotNil(t, result3)
- require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var upgradeCount atomic.Int64
- var sequence atomic.Int64
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- upgradeCount.Add(1)
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- for {
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- return
- }
- responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 1,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- return
- }
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 80,
- Name: "openai-store-false-reuse",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
-
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session_store_false_reuse_a")
- result1, err := svc.Forward(context.Background(), c1, account, body)
- require.NoError(t, err)
- require.NotNil(t, result1)
- require.Equal(t, int64(1), upgradeCount.Load())
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session_store_false_reuse_b")
- result2, err := svc.Forward(context.Background(), c2, account, body)
- require.NoError(t, err)
- require.NotNil(t, result2)
- require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- readDelays: []time.Duration{
- 700 * time.Millisecond,
- 700 * time.Millisecond,
- },
- events: [][]byte{
- []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- upstream := &httpUpstreamRecorder{
- resp: &http.Response{
- StatusCode: http.StatusOK,
- Header: http.Header{"Content-Type": []string{"application/json"}},
- Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
- },
- }
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: upstream,
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 81,
- Name: "openai-read-timeout",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_timeout_ok", result.RequestID)
- require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP")
-}
-
-type openAIWSCaptureDialer struct {
- mu sync.Mutex
- conn *openAIWSCaptureConn
- lastHeaders http.Header
- handshake http.Header
- dialCount int
-}
-
-func (d *openAIWSCaptureDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = proxyURL
- d.mu.Lock()
- d.lastHeaders = cloneHeader(headers)
- d.dialCount++
- respHeaders := cloneHeader(d.handshake)
- d.mu.Unlock()
- return d.conn, 0, respHeaders, nil
-}
-
-func (d *openAIWSCaptureDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-type openAIWSCaptureConn struct {
- mu sync.Mutex
- readDelays []time.Duration
- events [][]byte
- lastWrite map[string]any
- writes []map[string]any
- closed bool
-}
-
-func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
- _ = ctx
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return errOpenAIWSConnClosed
- }
- switch payload := value.(type) {
- case map[string]any:
- c.lastWrite = cloneMapStringAny(payload)
- c.writes = append(c.writes, cloneMapStringAny(payload))
- case json.RawMessage:
- var parsed map[string]any
- if err := json.Unmarshal(payload, &parsed); err == nil {
- c.lastWrite = cloneMapStringAny(parsed)
- c.writes = append(c.writes, cloneMapStringAny(parsed))
- }
- case []byte:
- var parsed map[string]any
- if err := json.Unmarshal(payload, &parsed); err == nil {
- c.lastWrite = cloneMapStringAny(parsed)
- c.writes = append(c.writes, cloneMapStringAny(parsed))
- }
- }
- return nil
-}
-
-func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
- if ctx == nil {
- ctx = context.Background()
- }
- c.mu.Lock()
- if c.closed {
- c.mu.Unlock()
- return nil, errOpenAIWSConnClosed
- }
- if len(c.events) == 0 {
- c.mu.Unlock()
- return nil, io.EOF
- }
- delay := time.Duration(0)
- if len(c.readDelays) > 0 {
- delay = c.readDelays[0]
- c.readDelays = c.readDelays[1:]
- }
- event := c.events[0]
- c.events = c.events[1:]
- c.mu.Unlock()
- if delay > 0 {
- timer := time.NewTimer(delay)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-timer.C:
- }
- }
- return event, nil
-}
-
-func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
- _ = ctx
- return nil
-}
-
-func (c *openAIWSCaptureConn) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.closed = true
- return nil
-}
-
-func cloneMapStringAny(src map[string]any) map[string]any {
- if src == nil {
- return nil
- }
- dst := make(map[string]any, len(src))
- for k, v := range src {
- dst[k] = v
- }
- return dst
-}
diff --git a/backend/internal/service/openai_ws_forwarder_turn_error_test.go b/backend/internal/service/openai_ws_forwarder_turn_error_test.go
new file mode 100644
index 000000000..ea7ea8ccb
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_turn_error_test.go
@@ -0,0 +1,62 @@
+package service
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSIngressTurnPartialResult_NotTurnError(t *testing.T) {
+ result, ok := OpenAIWSIngressTurnPartialResult(errors.New("plain error"))
+ require.False(t, ok)
+ require.Nil(t, result)
+}
+
+func TestOpenAIWSIngressTurnPartialResult_DeepCopy(t *testing.T) {
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_partial",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ OutputTokens: 34,
+ },
+ PendingFunctionCallIDs: []string{"call_1", "call_2"},
+ }
+ err := wrapOpenAIWSIngressTurnErrorWithPartial("read_upstream", errors.New("boom"), false, partial)
+
+ got, ok := OpenAIWSIngressTurnPartialResult(err)
+ require.True(t, ok)
+ require.NotNil(t, got)
+ require.Equal(t, partial.RequestID, got.RequestID)
+ require.Equal(t, partial.Usage, got.Usage)
+ require.Equal(t, partial.PendingFunctionCallIDs, got.PendingFunctionCallIDs)
+
+ // mutate returned copy should not affect stored partial result
+ got.PendingFunctionCallIDs[0] = "changed"
+ again, ok := OpenAIWSIngressTurnPartialResult(err)
+ require.True(t, ok)
+ require.Equal(t, "call_1", again.PendingFunctionCallIDs[0])
+}
+
+func TestOpenAIWSClientReadIdleTimeout_DefaultAndConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout())
+
+ svc.cfg = &config.Config{}
+ svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 1800
+ require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout())
+
+ svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 120
+ require.Equal(t, 120*time.Second, svc.openAIWSClientReadIdleTimeout())
+}
+
+func TestOpenAIWSPassthroughIdleTimeout_DefaultAndConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ require.Equal(t, time.Hour, svc.openAIWSPassthroughIdleTimeout())
+
+ svc.cfg = &config.Config{}
+ svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 120
+ require.Equal(t, 120*time.Second, svc.openAIWSPassthroughIdleTimeout())
+}
diff --git a/backend/internal/service/openai_ws_hotpath_perf_test.go b/backend/internal/service/openai_ws_hotpath_perf_test.go
new file mode 100644
index 000000000..6339fc3d9
--- /dev/null
+++ b/backend/internal/service/openai_ws_hotpath_perf_test.go
@@ -0,0 +1,938 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+ "unsafe"
+
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Mock conn for hotpath performance tests
+// ---------------------------------------------------------------------------
+
+type openAIWSNoopConn struct{}
+
+func (c *openAIWSNoopConn) WriteJSON(context.Context, any) error { return nil }
+func (c *openAIWSNoopConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil }
+func (c *openAIWSNoopConn) Ping(context.Context) error { return nil }
+func (c *openAIWSNoopConn) Close() error { return nil }
+
+// openAIWSIdentityConn is a distinct conn instance used to verify pointer identity.
+type openAIWSIdentityConn struct{}
+
+func (c *openAIWSIdentityConn) WriteJSON(context.Context, any) error { return nil }
+func (c *openAIWSIdentityConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil }
+func (c *openAIWSIdentityConn) Ping(context.Context) error { return nil }
+func (c *openAIWSIdentityConn) Close() error { return nil }
+
+func mustDefaultOpenAIWSStateStore(t *testing.T, raw OpenAIWSStateStore) *defaultOpenAIWSStateStore {
+ t.Helper()
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+ return store
+}
+
+// ===================================================================
+// 1. maybeTouchLease throttle
+// ===================================================================
+
+func TestMaybeTouchLease_NilReceiverDoesNotPanic(t *testing.T) {
+ var c *openAIWSIngressContext
+ require.NotPanics(t, func() {
+ c.maybeTouchLease(time.Minute)
+ })
+}
+
+func TestMaybeTouchLease_FirstCallAlwaysTouches(t *testing.T) {
+ c := &openAIWSIngressContext{}
+ require.Zero(t, c.lastTouchUnixNano.Load(), "precondition: lastTouchUnixNano should be zero")
+
+ c.maybeTouchLease(5 * time.Minute)
+
+ require.NotZero(t, c.lastTouchUnixNano.Load(), "first maybeTouchLease must update lastTouchUnixNano")
+ require.False(t, c.expiresAt().IsZero(), "first maybeTouchLease must set expiresAt")
+}
+
+func TestMaybeTouchLease_SecondCallWithin1sIsSkipped(t *testing.T) {
+ c := &openAIWSIngressContext{}
+
+ // First touch
+ c.maybeTouchLease(5 * time.Minute)
+ firstExpiry := c.expiresAt()
+ firstTouch := c.lastTouchUnixNano.Load()
+ require.NotZero(t, firstTouch)
+
+ // Second touch immediately -- within 1s, should be skipped
+ c.maybeTouchLease(10 * time.Minute)
+ secondExpiry := c.expiresAt()
+ secondTouch := c.lastTouchUnixNano.Load()
+
+ require.Equal(t, firstTouch, secondTouch, "lastTouchUnixNano should NOT change within 1s")
+ require.Equal(t, firstExpiry, secondExpiry, "expiresAt should NOT change within 1s")
+}
+
+func TestMaybeTouchLease_CallAfter1sActuallyTouches(t *testing.T) {
+ c := &openAIWSIngressContext{}
+
+ c.maybeTouchLease(5 * time.Minute)
+ firstExpiry := c.expiresAt()
+
+ // Simulate 1s+ passing by backdating the lastTouchUnixNano
+ backdated := time.Now().Add(-2 * time.Second).UnixNano()
+ c.lastTouchUnixNano.Store(backdated)
+ // Also backdate expiresAt so we can observe the change
+ c.setExpiresAt(time.Now().Add(-time.Minute))
+ expiryAfterBackdate := c.expiresAt()
+ require.True(t, expiryAfterBackdate.Before(firstExpiry), "precondition: expiresAt should be backdated")
+
+ c.maybeTouchLease(5 * time.Minute)
+ touchAfter := c.lastTouchUnixNano.Load()
+ secondExpiry := c.expiresAt()
+
+ require.Greater(t, touchAfter, backdated, "lastTouchUnixNano should advance past the backdated value")
+ require.True(t, secondExpiry.After(expiryAfterBackdate), "expiresAt should advance after 1s+ gap")
+}
+
+func TestTouchLease_NilReceiverDoesNotPanic(t *testing.T) {
+ var c *openAIWSIngressContext
+ require.NotPanics(t, func() {
+ c.touchLease(time.Now(), 5*time.Minute)
+ })
+}
+
+func TestTouchLease_AlwaysUpdatesLastTouchUnixNano(t *testing.T) {
+ c := &openAIWSIngressContext{}
+
+ now := time.Now()
+ c.touchLease(now, 5*time.Minute)
+ first := c.lastTouchUnixNano.Load()
+ require.NotZero(t, first)
+
+ // touchLease (non-throttled) always updates, even if called again immediately.
+ time.Sleep(time.Millisecond) // ensure clock moves forward
+ now2 := time.Now()
+ c.touchLease(now2, 5*time.Minute)
+ second := c.lastTouchUnixNano.Load()
+ require.Greater(t, second, first, "touchLease must always update lastTouchUnixNano")
+}
+
+// ===================================================================
+// 2. activeConn cached connection
+// ===================================================================
+
+func TestActiveConn_NilLeaseReturnsError(t *testing.T) {
+ var lease *openAIWSIngressContextLease
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_NilContextReturnsError(t *testing.T) {
+ lease := &openAIWSIngressContextLease{context: nil}
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_ReleasedLeaseReturnsError(t *testing.T) {
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner",
+ upstream: &openAIWSNoopConn{},
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner",
+ }
+ lease.released.Store(true)
+
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_FirstCallPopulatesCachedConn(t *testing.T) {
+ upstream := &openAIWSIdentityConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_1",
+ upstream: upstream,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_1",
+ }
+
+ require.Nil(t, lease.cachedConn, "precondition: cachedConn should be nil")
+
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream, conn, "should return the upstream conn")
+ require.Equal(t, upstream, lease.cachedConn, "should populate cachedConn")
+}
+
+func TestActiveConn_SecondCallReturnsCachedDirectly(t *testing.T) {
+ upstream1 := &openAIWSIdentityConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_cache",
+ upstream: upstream1,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_cache",
+ }
+
+ // First call populates cache
+ conn1, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream1, conn1)
+
+ // Swap the upstream -- cached path should NOT see the swap
+ upstream2 := &openAIWSIdentityConn{}
+ ctx.mu.Lock()
+ ctx.upstream = upstream2
+ ctx.mu.Unlock()
+
+ conn2, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream1, conn2, "second call should return cachedConn, not the swapped upstream")
+}
+
+func TestActiveConn_OwnerMismatchReturnsError(t *testing.T) {
+ ctx := &openAIWSIngressContext{
+ ownerID: "other_owner",
+ upstream: &openAIWSNoopConn{},
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "my_owner",
+ }
+
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_NilUpstreamReturnsError(t *testing.T) {
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner",
+ upstream: nil,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner",
+ }
+
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_MarkBrokenClearsCachedConn(t *testing.T) {
+ upstream := &openAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_mb",
+ upstream: upstream,
+ }
+ pool := &openAIWSIngressContextPool{
+ idleTTL: 10 * time.Minute,
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctx,
+ ownerID: "owner_mb",
+ }
+
+ // Populate cache
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.NotNil(t, conn)
+ require.NotNil(t, lease.cachedConn)
+
+ lease.MarkBroken()
+ require.Nil(t, lease.cachedConn, "MarkBroken must clear cachedConn")
+}
+
+func TestActiveConn_ReleaseClearsCachedConn(t *testing.T) {
+ upstream := &openAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_rel",
+ upstream: upstream,
+ }
+ pool := &openAIWSIngressContextPool{
+ idleTTL: 10 * time.Minute,
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctx,
+ ownerID: "owner_rel",
+ }
+
+ // Populate cache
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.NotNil(t, conn)
+ require.NotNil(t, lease.cachedConn)
+
+ lease.Release()
+ require.Nil(t, lease.cachedConn, "Release must clear cachedConn")
+}
+
+func TestActiveConn_AfterClearCachedConn_ReacquiresViaMutex(t *testing.T) {
+ upstream1 := &openAIWSIdentityConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_reacq",
+ upstream: upstream1,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_reacq",
+ }
+
+ // Populate cache with upstream1
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream1, conn)
+
+ // Simulate a cleared cache (e.g., after recovery)
+ lease.cachedConn = nil
+
+ // Swap upstream
+ upstream2 := &openAIWSIdentityConn{}
+ ctx.mu.Lock()
+ ctx.upstream = upstream2
+ ctx.mu.Unlock()
+
+ // Should now re-acquire via mutex and return upstream2
+ conn2, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream2, conn2, "after clearing cachedConn, next call must re-acquire via mutex")
+ require.Equal(t, upstream2, lease.cachedConn, "should re-populate cachedConn with new upstream")
+}
+
+// ===================================================================
+// 3. Event type TrimSpace-free functions
+// ===================================================================
+
+func TestIsOpenAIWSTerminalEvent(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ {"response.completed", true},
+ {"response.done", true},
+ {"response.failed", true},
+ {"response.incomplete", true},
+ {"response.cancelled", true},
+ {"response.canceled", true},
+ {"response.created", false},
+ {"response.in_progress", false},
+ {"response.output_text.delta", false},
+ {"", false},
+ {"unknown_event", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, isOpenAIWSTerminalEvent(tt.eventType))
+ })
+ }
+}
+
+func TestShouldPersistOpenAIWSLastResponseID_HotpathPerf(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ {"response.completed", true},
+ {"response.done", true},
+ {"response.failed", false},
+ {"response.incomplete", false},
+ {"response.cancelled", false},
+ {"", false},
+ {"unknown_event", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, shouldPersistOpenAIWSLastResponseID(tt.eventType))
+ })
+ }
+}
+
+func TestIsOpenAIWSTokenEvent(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ // Known false: structural events
+ {"response.created", false},
+ {"response.in_progress", false},
+ {"response.output_item.added", false},
+ {"response.output_item.done", false},
+ // Delta events
+ {"response.output_text.delta", true},
+ {"response.content_part.delta", true},
+ {"response.audio.delta", true},
+ {"response.function_call_arguments.delta", true},
+ // output_text prefix
+ {"response.output_text.done", true},
+ {"response.output_text.annotation.added", true},
+ // output prefix (but not output_item)
+ {"response.output.done", true},
+ // Terminal events that are also token events
+ {"response.completed", true},
+ {"response.done", true},
+ // Empty and unknown
+ {"", false},
+ {"unknown_event", false},
+ {"session.created", false},
+ {"session.updated", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, isOpenAIWSTokenEvent(tt.eventType))
+ })
+ }
+}
+
+func TestOpenAIWSEventShouldParseUsage(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ {"response.completed", true},
+ {"response.done", true},
+ {"response.failed", true},
+ {"", false},
+ {"unknown", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, openAIWSEventShouldParseUsage(tt.eventType))
+ })
+ }
+}
+
+func TestOpenAIWSEventMayContainToolCalls(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ // Explicit function_call / tool_call in name
+ {"response.function_call_arguments.delta", true},
+ {"response.function_call_arguments.done", true},
+ {"response.tool_call.delta", true},
+ // Structural events that may contain tool output items
+ {"response.output_item.added", true},
+ {"response.output_item.done", true},
+ {"response.completed", true},
+ {"response.done", true},
+ // Non-tool events
+ {"response.output_text.delta", false},
+ {"response.created", false},
+ {"response.in_progress", false},
+ {"", false},
+ {"unknown", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, openAIWSEventMayContainToolCalls(tt.eventType))
+ })
+ }
+}
+
+// ===================================================================
+// 4. parseOpenAIWSEventType (lightweight version)
+// ===================================================================
+
+func TestParseOpenAIWSEventType_EmptyMessage(t *testing.T) {
+ eventType, responseID := parseOpenAIWSEventType(nil)
+ require.Empty(t, eventType)
+ require.Empty(t, responseID)
+
+ eventType, responseID = parseOpenAIWSEventType([]byte{})
+ require.Empty(t, eventType)
+ require.Empty(t, responseID)
+}
+
+func TestParseOpenAIWSEventType_ResponseIDExtracted(t *testing.T) {
+ msg := []byte(`{"type":"response.completed","response":{"id":"resp_abc123"}}`)
+ eventType, responseID := parseOpenAIWSEventType(msg)
+ require.Equal(t, "response.completed", eventType)
+ require.Equal(t, "resp_abc123", responseID)
+}
+
+func TestParseOpenAIWSEventType_FallbackToID(t *testing.T) {
+ msg := []byte(`{"type":"response.output_text.delta","id":"evt_fallback_id"}`)
+ eventType, responseID := parseOpenAIWSEventType(msg)
+ require.Equal(t, "response.output_text.delta", eventType)
+ require.Equal(t, "evt_fallback_id", responseID)
+}
+
+func TestParseOpenAIWSEventType_ConsistentWithEnvelope(t *testing.T) {
+ testMessages := [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`),
+ []byte(`{"type":"response.output_text.delta","id":"evt_2"}`),
+ []byte(`{"type":"response.created","response":{"id":"resp_3"}}`),
+ []byte(`{"type":"error","error":{"message":"bad request"}}`),
+ []byte(`{"type":"response.done","id":"resp_4","response":{"id":"resp_4_inner"}}`),
+ []byte(`{}`),
+ []byte(`{"type":"session.created"}`),
+ }
+ for i, msg := range testMessages {
+ t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
+ typeLight, idLight := parseOpenAIWSEventType(msg)
+ typeEnv, idEnv, _ := parseOpenAIWSEventEnvelope(msg)
+ require.Equal(t, typeEnv, typeLight, "eventType must match parseOpenAIWSEventEnvelope")
+ require.Equal(t, idEnv, idLight, "responseID must match parseOpenAIWSEventEnvelope")
+ })
+ }
+}
+
+// ===================================================================
+// 6. openAIWSResponseAccountCacheKey (xxhash, v2 prefix)
+// ===================================================================
+
+func TestOpenAIWSResponseAccountCacheKey_Deterministic(t *testing.T) {
+ key1 := openAIWSResponseAccountCacheKey("resp_deterministic_test")
+ key2 := openAIWSResponseAccountCacheKey("resp_deterministic_test")
+ require.Equal(t, key1, key2, "same responseID must produce the same key")
+}
+
+func TestOpenAIWSResponseAccountCacheKey_DifferentIDsDifferentKeys(t *testing.T) {
+ key1 := openAIWSResponseAccountCacheKey("resp_alpha")
+ key2 := openAIWSResponseAccountCacheKey("resp_beta")
+ require.NotEqual(t, key1, key2, "different responseIDs must produce different keys")
+}
+
+func TestOpenAIWSResponseAccountCacheKey_V2Prefix(t *testing.T) {
+ key := openAIWSResponseAccountCacheKey("resp_v2_check")
+ require.True(t, strings.Contains(key, "v2:"), "key must contain v2: prefix for version compatibility")
+}
+
+func TestOpenAIWSResponseAccountCacheKey_StartsWithCachePrefix(t *testing.T) {
+ key := openAIWSResponseAccountCacheKey("resp_prefix_check")
+ require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix),
+ "key must start with the standard cache prefix %q, got %q", openAIWSResponseAccountCachePrefix, key)
+}
+
+func TestOpenAIWSResponseAccountCacheKey_HexLength(t *testing.T) {
+ key := openAIWSResponseAccountCacheKey("resp_hex_length")
+ // Expected format: "openai:response:v2:<16 hex chars>"
+ prefix := openAIWSResponseAccountCachePrefix + "v2:"
+ require.True(t, strings.HasPrefix(key, prefix))
+ hexPart := strings.TrimPrefix(key, prefix)
+ require.Len(t, hexPart, 16, "xxhash hex digest should be zero-padded to 16 chars, got %q", hexPart)
+}
+
+func TestOpenAIWSResponseAccountCacheKey_ManyInputs_AllPaddedTo16(t *testing.T) {
+ // Verify that all inputs produce exactly 16-char hex, testing many variations.
+ prefix := openAIWSResponseAccountCachePrefix + "v2:"
+ for i := 0; i < 1000; i++ {
+ responseID := fmt.Sprintf("resp_%d", i)
+ key := openAIWSResponseAccountCacheKey(responseID)
+ hexPart := strings.TrimPrefix(key, prefix)
+ require.Len(t, hexPart, 16, "responseID=%q produced hex %q (len %d)", responseID, hexPart, len(hexPart))
+ }
+}
+
+// ===================================================================
+// 7. openAIWSSessionTurnStateKey uses strconv
+// ===================================================================
+
+func TestOpenAIWSSessionTurnStateKey_NormalCase(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(123, "abc_hash")
+ require.Equal(t, "123:abc_hash", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_EmptySessionHash(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(123, "")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_WhitespaceOnlySessionHash(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(123, " ")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_NegativeGroupID(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(-1, "hash")
+ require.Equal(t, "-1:hash", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_ZeroGroupID(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(0, "hash")
+ require.Equal(t, "0:hash", key)
+}
+
+// ===================================================================
+// 8. openAIWSIngressContextSessionKey uses strconv
+// ===================================================================
+
+func TestOpenAIWSIngressContextSessionKey_NormalCase(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(456, "session_xyz")
+ require.Equal(t, "456:session_xyz", key)
+}
+
+func TestOpenAIWSIngressContextSessionKey_EmptySessionHash(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(456, "")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSIngressContextSessionKey_WhitespaceOnlySessionHash(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(456, " \t ")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSIngressContextSessionKey_LargeGroupID(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(9223372036854775807, "h")
+ require.Equal(t, "9223372036854775807:h", key)
+}
+
+// ===================================================================
+// 9. deriveOpenAISessionHash and deriveOpenAILegacySessionHash
+// ===================================================================
+
+func TestDeriveOpenAISessionHash_EmptyReturnsEmpty(t *testing.T) {
+ require.Equal(t, "", deriveOpenAISessionHash(""))
+ require.Equal(t, "", deriveOpenAISessionHash(" "))
+}
+
+func TestDeriveOpenAISessionHash_ProducesXXHash16Chars(t *testing.T) {
+ hash := deriveOpenAISessionHash("test_session_id")
+ require.Len(t, hash, 16, "xxhash hex should be exactly 16 chars, got %q", hash)
+ // Verify it's valid hex
+ for _, ch := range hash {
+ require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'),
+ "hash should be lowercase hex, got char %c in %q", ch, hash)
+ }
+}
+
+func TestDeriveOpenAISessionHash_Deterministic(t *testing.T) {
+ h1 := deriveOpenAISessionHash("session_abc")
+ h2 := deriveOpenAISessionHash("session_abc")
+ require.Equal(t, h1, h2)
+}
+
+func TestDeriveOpenAILegacySessionHash_EmptyReturnsEmpty(t *testing.T) {
+ require.Equal(t, "", deriveOpenAILegacySessionHash(""))
+ require.Equal(t, "", deriveOpenAILegacySessionHash(" "))
+}
+
+func TestDeriveOpenAILegacySessionHash_ProducesSHA256_64Chars(t *testing.T) {
+ hash := deriveOpenAILegacySessionHash("test_session_id")
+ require.Len(t, hash, 64, "SHA-256 hex should be exactly 64 chars, got %q", hash)
+ for _, ch := range hash {
+ require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'),
+ "hash should be lowercase hex, got char %c in %q", ch, hash)
+ }
+}
+
+func TestDeriveOpenAILegacySessionHash_Deterministic(t *testing.T) {
+ h1 := deriveOpenAILegacySessionHash("session_xyz")
+ h2 := deriveOpenAILegacySessionHash("session_xyz")
+ require.Equal(t, h1, h2)
+}
+
+func TestDeriveOpenAISessionHashes_MatchesIndividualFunctions(t *testing.T) {
+ sessionID := "test_combined_session"
+ currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
+
+ require.Equal(t, deriveOpenAISessionHash(sessionID), currentHash)
+ require.Equal(t, deriveOpenAILegacySessionHash(sessionID), legacyHash)
+}
+
+func TestDeriveOpenAISessionHashes_EmptyReturnsEmpty(t *testing.T) {
+ currentHash, legacyHash := deriveOpenAISessionHashes("")
+ require.Equal(t, "", currentHash)
+ require.Equal(t, "", legacyHash)
+}
+
+func TestDeriveOpenAISessionHashes_DifferentInputsDifferentOutputs(t *testing.T) {
+ h1Current, h1Legacy := deriveOpenAISessionHashes("session_A")
+ h2Current, h2Legacy := deriveOpenAISessionHashes("session_B")
+ require.NotEqual(t, h1Current, h2Current)
+ require.NotEqual(t, h1Legacy, h2Legacy)
+}
+
+func TestDeriveOpenAISessionHash_DifferentFromLegacy(t *testing.T) {
+ // xxhash and SHA-256 produce completely different outputs for the same input
+ currentHash := deriveOpenAISessionHash("same_input")
+ legacyHash := deriveOpenAILegacySessionHash("same_input")
+ require.NotEqual(t, currentHash, legacyHash, "xxhash and SHA-256 should produce different results")
+ require.Len(t, currentHash, 16)
+ require.Len(t, legacyHash, 64)
+}
+
+// ===================================================================
+// 10. State store sharded lock (responseToConn)
+// ===================================================================
+
+func TestConnShard_DistributesAcrossShards(t *testing.T) {
+ store := mustDefaultOpenAIWSStateStore(t, NewOpenAIWSStateStore(nil))
+
+ shardHits := make(map[int]int)
+ for i := 0; i < 256; i++ {
+ key := fmt.Sprintf("resp_%d", i)
+ shard := store.connShard(key)
+ // Find which shard index this is
+ for j := 0; j < openAIWSStateStoreConnShards; j++ {
+ if shard == &store.responseToConnShards[j] {
+ shardHits[j]++
+ break
+ }
+ }
+ }
+
+ // With 256 keys and 16 shards, each shard should get some keys.
+ // We don't require perfect uniformity, just that keys aren't all in one shard.
+ require.Greater(t, len(shardHits), 1, "keys must be distributed across multiple shards, got %d shards used", len(shardHits))
+ require.GreaterOrEqual(t, len(shardHits), openAIWSStateStoreConnShards/2,
+ "keys should hit at least half the shards for reasonable distribution")
+}
+
+func TestStateStore_ShardedBindGetDelete(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+
+ store.BindResponseConn("resp_shard_1", "conn_a", time.Minute)
+ store.BindResponseConn("resp_shard_2", "conn_b", time.Minute)
+
+ conn1, ok1 := store.GetResponseConn("resp_shard_1")
+ require.True(t, ok1)
+ require.Equal(t, "conn_a", conn1)
+
+ conn2, ok2 := store.GetResponseConn("resp_shard_2")
+ require.True(t, ok2)
+ require.Equal(t, "conn_b", conn2)
+
+ store.DeleteResponseConn("resp_shard_1")
+ _, ok1After := store.GetResponseConn("resp_shard_1")
+ require.False(t, ok1After)
+
+ // resp_shard_2 should still be accessible
+ conn2After, ok2After := store.GetResponseConn("resp_shard_2")
+ require.True(t, ok2After)
+ require.Equal(t, "conn_b", conn2After)
+}
+
+func TestStateStore_ShardedConcurrentAccessNoRace(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ const goroutines = 32
+ const opsPerGoroutine = 200
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for g := 0; g < goroutines; g++ {
+ g := g
+ go func() {
+ defer wg.Done()
+ for i := 0; i < opsPerGoroutine; i++ {
+ key := fmt.Sprintf("resp_conc_%d_%d", g, i)
+ connID := fmt.Sprintf("conn_%d_%d", g, i)
+
+ store.BindResponseConn(key, connID, time.Minute)
+ got, ok := store.GetResponseConn(key)
+ if ok {
+ _ = got
+ }
+ store.DeleteResponseConn(key)
+ }
+ }()
+ }
+
+ wg.Wait()
+}
+
+// ===================================================================
+// 11. State store: Get paths don't call maybeCleanup
+// ===================================================================
+
+func TestStateStore_GetPaths_DoNotTriggerCleanup(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store := mustDefaultOpenAIWSStateStore(t, raw)
+
+ // Seed some data so Get paths have something to read
+ store.BindResponseConn("resp_get_noclean", "conn_1", time.Minute)
+ store.BindResponsePendingToolCalls(0, "resp_get_noclean", []string{"call_1"}, time.Minute)
+ store.BindSessionTurnState(1, "session_get_noclean", "state_1", time.Minute)
+ store.BindSessionConn(1, "session_get_noclean", "conn_1", time.Minute)
+
+ // Record the lastCleanupUnixNano after the Bind calls
+ cleanupBefore := store.lastCleanupUnixNano.Load()
+
+ // Set lastCleanup to the future to ensure no cleanup triggers from Binds
+ store.lastCleanupUnixNano.Store(time.Now().Add(time.Hour).UnixNano())
+ cleanupFrozen := store.lastCleanupUnixNano.Load()
+
+ // Perform many Get calls
+ for i := 0; i < 100; i++ {
+ store.GetResponseConn("resp_get_noclean")
+ store.GetResponsePendingToolCalls(0, "resp_get_noclean")
+ store.GetSessionTurnState(1, "session_get_noclean")
+ store.GetSessionConn(1, "session_get_noclean")
+ }
+
+ cleanupAfterGets := store.lastCleanupUnixNano.Load()
+ require.Equal(t, cleanupFrozen, cleanupAfterGets,
+ "Get paths must NOT change lastCleanupUnixNano (was %d before, %d after)", cleanupBefore, cleanupAfterGets)
+}
+
+func TestStateStore_MaybeCleanup_NilReceiverDoesNotPanic(t *testing.T) {
+ var store *defaultOpenAIWSStateStore
+ require.NotPanics(t, func() {
+ store.maybeCleanup()
+ })
+}
+
+func TestStateStore_BindPaths_MayTriggerCleanup(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store := mustDefaultOpenAIWSStateStore(t, raw)
+
+ // Set lastCleanup to long ago to ensure cleanup triggers on next Bind
+ pastNano := time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()
+ store.lastCleanupUnixNano.Store(pastNano)
+
+ store.BindResponseConn("resp_bind_trigger", "conn_trigger", time.Minute)
+
+ cleanupAfterBind := store.lastCleanupUnixNano.Load()
+ require.NotEqual(t, pastNano, cleanupAfterBind,
+ "Bind paths should trigger maybeCleanup when interval has elapsed")
+}
+
+// ===================================================================
+// 12. GetResponsePendingToolCalls returns internal slice directly
+// ===================================================================
+
+func TestGetResponsePendingToolCalls_ReturnsInternalSlice(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store := mustDefaultOpenAIWSStateStore(t, raw)
+
+ store.BindResponsePendingToolCalls(0, "resp_slice_identity", []string{"call_x", "call_y"}, time.Minute)
+
+ callIDs, ok := store.GetResponsePendingToolCalls(0, "resp_slice_identity")
+ require.True(t, ok)
+ require.Equal(t, []string{"call_x", "call_y"}, callIDs)
+
+ // Verify it's the same underlying slice as stored in the binding (pointer equality).
+ // The binding stores callIDs as a copied slice at bind time, but Get returns it directly.
+ id := openAIWSResponsePendingToolCallsBindingKey(0, "resp_slice_identity")
+ store.responsePendingToolMu.RLock()
+ binding, exists := store.responsePendingTool[id]
+ store.responsePendingToolMu.RUnlock()
+ require.True(t, exists)
+
+ // Check pointer equality of the underlying array via unsafe
+ gotHeader := (*[3]uintptr)(unsafe.Pointer(&callIDs))
+ internalHeader := (*[3]uintptr)(unsafe.Pointer(&binding.callIDs))
+ require.Equal(t, gotHeader[0], internalHeader[0],
+ "returned slice should share the same underlying array pointer as the internal binding (zero-copy)")
+}
+
+// ===================================================================
+// Additional edge-case tests for completeness
+// ===================================================================
+
+func TestParseOpenAIWSEventType_MalformedJSON(t *testing.T) {
+ // Should not panic on malformed JSON
+ eventType, responseID := parseOpenAIWSEventType([]byte(`{not valid json`))
+ // gjson returns empty for invalid JSON
+ _ = eventType
+ _ = responseID
+}
+
+func TestOpenAIWSResponseAccountCacheKey_EmptyInput(t *testing.T) {
+ // Even empty string should produce a valid key
+ key := openAIWSResponseAccountCacheKey("")
+ require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix+"v2:"))
+ hexPart := strings.TrimPrefix(key, openAIWSResponseAccountCachePrefix+"v2:")
+ require.Len(t, hexPart, 16)
+}
+
+func TestConnShard_SameKeyAlwaysSameShard(t *testing.T) {
+ store := mustDefaultOpenAIWSStateStore(t, NewOpenAIWSStateStore(nil))
+ shard1 := store.connShard("resp_stable_key")
+ shard2 := store.connShard("resp_stable_key")
+ require.Equal(t, shard1, shard2, "same key must always map to the same shard")
+}
+
+func TestMaybeTouchLease_ConcurrentSafe(t *testing.T) {
+ c := &openAIWSIngressContext{}
+ var wg sync.WaitGroup
+ const goroutines = 16
+
+ wg.Add(goroutines)
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 100; j++ {
+ c.maybeTouchLease(5 * time.Minute)
+ }
+ }()
+ }
+ wg.Wait()
+
+ require.NotZero(t, c.lastTouchUnixNano.Load())
+ require.False(t, c.expiresAt().IsZero())
+}
+
+func TestActiveConn_SingleOwnerSequentialAccess(t *testing.T) {
+ // activeConn uses a non-synchronized cachedConn field by design.
+ // A lease is only used by a single goroutine (the forwarding loop).
+ // This test verifies sequential repeated calls from the same goroutine
+ // always return the same cached conn without error.
+ upstream := &openAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_seq",
+ upstream: upstream,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_seq",
+ }
+
+ for i := 0; i < 1000; i++ {
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream, conn)
+ }
+}
+
+func TestOpenAIWSIngressContextSessionKey_ConsistentWithTurnStateKey(t *testing.T) {
+ // Both functions use the same pattern: strconv.FormatInt(groupID, 10) + ":" + hash
+ groupID := int64(42)
+ sessionHash := "test_hash"
+
+ sessionKey := openAIWSIngressContextSessionKey(groupID, sessionHash)
+ turnStateKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
+
+ require.Equal(t, sessionKey, turnStateKey,
+ "openAIWSIngressContextSessionKey and openAIWSSessionTurnStateKey should produce identical keys for the same inputs")
+}
+
+func TestStateStore_ShardedBindOverwrite(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+
+ store.BindResponseConn("resp_overwrite", "conn_old", time.Minute)
+ store.BindResponseConn("resp_overwrite", "conn_new", time.Minute)
+
+ conn, ok := store.GetResponseConn("resp_overwrite")
+ require.True(t, ok)
+ require.Equal(t, "conn_new", conn, "later bind should overwrite earlier bind")
+}
+
+func TestStateStore_ShardedTTLExpiry(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+
+ store.BindResponseConn("resp_ttl_shard", "conn_ttl", 30*time.Millisecond)
+ conn, ok := store.GetResponseConn("resp_ttl_shard")
+ require.True(t, ok)
+ require.Equal(t, "conn_ttl", conn)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetResponseConn("resp_ttl_shard")
+ require.False(t, ok, "entry should be expired after TTL")
+}
diff --git a/backend/internal/service/openai_ws_ingress_context_pool.go b/backend/internal/service/openai_ws_ingress_context_pool.go
new file mode 100644
index 000000000..daa0ebb02
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_context_pool.go
@@ -0,0 +1,1594 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+var (
+	// errOpenAIWSIngressContextBusy is returned by Acquire when the session's
+	// context stays owned by another request past the wait-retry/queue-wait
+	// limits, or the caller's context is cancelled while waiting.
+	errOpenAIWSIngressContextBusy = errors.New("openai ws ingress context is busy")
+)
+
+const (
+	// Scheduling layers reported on a lease: exact session match, freshly
+	// created context, or a context recycled (migrated) from another session.
+	openAIWSIngressScheduleLayerExact = "l0_exact"
+	openAIWSIngressScheduleLayerNew = "l1_new_context"
+	openAIWSIngressScheduleLayerMigration = "l2_migration"
+	// Bounds on how long/often Acquire waits for a busy context to be released.
+	openAIWSIngressAcquireMaxWaitRetries = 4096
+	openAIWSIngressAcquireMaxQueueWait = 30 * time.Minute
+
+	// openAIWSUpstreamConnMaxAge is the default maximum lifetime of an upstream
+	// WebSocket connection. OpenAI force-closes connections after 60 minutes,
+	// so we default to 55 minutes and rotate proactively to avoid mid-turn
+	// disconnects.
+	openAIWSUpstreamConnMaxAge = 55 * time.Minute
+
+	// openAIWSIngressDelayedPingAfterYield is how long after a yield the
+	// delayed ping probe fires. It surfaces dead connections while a session
+	// is briefly idle, instead of discovering them on the next Acquire.
+	openAIWSIngressDelayedPingAfterYield = 5 * time.Second
+
+	// openAIWSIngressPingTimeout bounds each background ping probe.
+	openAIWSIngressPingTimeout = 5 * time.Second
+)
+
+// Session stickiness levels, ordered weak < balanced < strong. Stronger
+// levels make it harder to migrate a session away from its current context
+// (see openAIWSIngressMigrationPolicyByStickiness).
+const (
+	openAIWSIngressStickinessWeak = "weak"
+	openAIWSIngressStickinessBalanced = "balanced"
+	openAIWSIngressStickinessStrong = "strong"
+)
+
+// openAIWSIngressContextAcquireRequest carries everything Acquire needs to
+// bind a request to an account-scoped ingress context: identity
+// (Account/GroupID/SessionHash/OwnerID), upstream dial parameters
+// (WSURL/Headers/ProxyURL), and affinity hints that drive the stickiness
+// level (Turn, HasPreviousResponseID, StrictAffinity, StoreDisabled).
+type openAIWSIngressContextAcquireRequest struct {
+	Account *Account
+	GroupID int64
+	SessionHash string
+	OwnerID string
+	WSURL string
+	Headers http.Header
+	ProxyURL string
+	Turn int
+
+	HasPreviousResponseID bool
+	StrictAffinity bool
+	StoreDisabled bool
+}
+
+// openAIWSIngressContextPool manages per-account pools of session-sticky
+// ingress contexts, each wrapping one upstream WebSocket connection. A
+// background worker sweeps expired idle contexts until Close is called.
+type openAIWSIngressContextPool struct {
+	cfg *config.Config
+	dialer openAIWSClientDialer
+
+	idleTTL time.Duration // lease TTL for idle contexts
+	sweepInterval time.Duration // cadence of the background sweeper
+	upstreamMaxAge time.Duration // proactive rotation age; <= 0 disables
+
+	seq atomic.Uint64 // monotonic source for ctx_/ctxws_ IDs
+
+	// schedulerStats provides load-aware signals (error rate, circuit breaker
+	// state) for migration scoring. When nil, scoring falls back to the
+	// existing idle-time + failure-streak heuristic.
+	schedulerStats *openAIAccountRuntimeStats
+
+	mu sync.Mutex // guards accounts
+	accounts map[int64]*openAIWSIngressAccountPool
+
+	stopCh chan struct{}
+	stopOnce sync.Once
+	workerWg sync.WaitGroup
+	closeOnce sync.Once
+}
+
+// openAIWSIngressAccountPool holds one account's contexts, indexed both by
+// context ID and by session key.
+type openAIWSIngressAccountPool struct {
+	mu sync.Mutex
+
+	refs atomic.Int64 // in-flight Acquire calls referencing this pool
+
+	// dynamicCap is the adaptive capacity: it starts at 1, grows on demand
+	// (+1 when the L1 layer needs a new context), and shrinks after idling.
+	// The effective capacity is min(dynamicCap, effectiveContextCapacity).
+	dynamicCap atomic.Int32
+
+	contexts map[string]*openAIWSIngressContext // by context ID
+	bySession map[string]string // session key -> context ID
+}
+
+// openAIWSIngressContext is one session-sticky slot: at most one owner at a
+// time, at most one upstream WebSocket, plus health/migration bookkeeping.
+type openAIWSIngressContext struct {
+	id string
+	groupID int64
+	accountID int64
+	sessionHash string
+	sessionKey string
+
+	mu sync.Mutex
+	dialing bool
+	dialDone chan struct{}
+	releaseDone chan struct{} // single-signal channel: wakes one waiter when ownerID is released
+	ownerID string
+	lastUsedAtUnix atomic.Int64
+	expiresAtUnix atomic.Int64
+	lastTouchUnixNano atomic.Int64 // throttle: skip touchLease if < 1s since last
+	broken bool
+	failureStreak int
+	lastFailureAt time.Time
+	migrationCount int
+	lastMigrationAt time.Time
+	upstream openAIWSClientConn
+	upstreamConnID string
+	upstreamConnCreatedAt atomic.Int64 // UnixNano; 0 means unset
+	handshakeHeaders http.Header
+	prewarmed atomic.Bool
+	pendingPingTimer *time.Timer // dedup for delayed pings: at most one pending ping per context
+}
+
+// openAIWSIngressContextLease is what Acquire hands back: the bound context
+// plus scheduling telemetry (queue wait, conn-pick time, schedule layer,
+// stickiness, whether migration was used). cachedConn lets activeConn skip
+// the context mutex on the hot path.
+type openAIWSIngressContextLease struct {
+	pool *openAIWSIngressContextPool
+	context *openAIWSIngressContext
+	ownerID string
+	queueWait time.Duration
+	connPick time.Duration
+	reused bool
+	scheduleLayer string
+	stickiness string
+	migrationUsed bool
+	released atomic.Bool
+	cachedConnMu sync.RWMutex
+	cachedConn openAIWSClientConn // fast path: avoid mutex on every activeConn() call
+}
+
+// openAIWSTimeToUnixNano converts ts to UnixNano, mapping the zero time to 0.
+func openAIWSTimeToUnixNano(ts time.Time) int64 {
+	if ts.IsZero() {
+		return 0
+	}
+	return ts.UnixNano()
+}
+
+// openAIWSUnixNanoToTime is the inverse of openAIWSTimeToUnixNano: values
+// <= 0 map back to the zero time.
+func openAIWSUnixNanoToTime(ns int64) time.Time {
+	if ns <= 0 {
+		return time.Time{}
+	}
+	return time.Unix(0, ns)
+}
+
+// setLastUsedAt atomically records the last-used timestamp; nil-safe.
+func (c *openAIWSIngressContext) setLastUsedAt(ts time.Time) {
+	if c == nil {
+		return
+	}
+	c.lastUsedAtUnix.Store(openAIWSTimeToUnixNano(ts))
+}
+
+// lastUsedAt atomically reads the last-used timestamp; nil-safe.
+func (c *openAIWSIngressContext) lastUsedAt() time.Time {
+	if c == nil {
+		return time.Time{}
+	}
+	return openAIWSUnixNanoToTime(c.lastUsedAtUnix.Load())
+}
+
+// setExpiresAt atomically records the idle-lease expiry; nil-safe.
+func (c *openAIWSIngressContext) setExpiresAt(ts time.Time) {
+	if c == nil {
+		return
+	}
+	c.expiresAtUnix.Store(openAIWSTimeToUnixNano(ts))
+}
+
+// expiresAt atomically reads the idle-lease expiry; nil-safe.
+func (c *openAIWSIngressContext) expiresAt() time.Time {
+	if c == nil {
+		return time.Time{}
+	}
+	return openAIWSUnixNanoToTime(c.expiresAtUnix.Load())
+}
+
+// upstreamConnAge returns how long the upstream connection has been alive.
+// Returns 0 when createdAt is unset (zero) or now precedes createdAt (clock
+// went backwards).
+func (c *openAIWSIngressContext) upstreamConnAge(now time.Time) time.Duration {
+	if c == nil {
+		return 0
+	}
+	ns := c.upstreamConnCreatedAt.Load()
+	if ns <= 0 {
+		return 0
+	}
+	age := now.Sub(time.Unix(0, ns))
+	if age < 0 {
+		return 0
+	}
+	return age
+}
+
+// touchLease refreshes lastUsedAt to now and pushes the idle-lease expiry to
+// now+ttl, also recording the touch time for maybeTouchLease throttling.
+func (c *openAIWSIngressContext) touchLease(now time.Time, ttl time.Duration) {
+	if c == nil {
+		return
+	}
+	nowUnix := openAIWSTimeToUnixNano(now)
+	c.lastUsedAtUnix.Store(nowUnix)
+	c.expiresAtUnix.Store(openAIWSTimeToUnixNano(now.Add(ttl)))
+	c.lastTouchUnixNano.Store(nowUnix)
+}
+
+// maybeTouchLease is a throttled version of touchLease.
+// It skips the update if less than 1 second has passed since the last touch,
+// avoiding redundant time.Now() + atomic stores on every hot-path message.
+func (c *openAIWSIngressContext) maybeTouchLease(ttl time.Duration) {
+	if c == nil {
+		return
+	}
+	now := time.Now()
+	nowNano := now.UnixNano()
+	lastNano := c.lastTouchUnixNano.Load()
+	if lastNano > 0 && nowNano-lastNano < int64(time.Second) {
+		return
+	}
+	c.touchLease(now, ttl)
+}
+
+// newOpenAIWSIngressContextPool builds a pool with defaults (10m idle TTL,
+// 30s sweep interval, 55m upstream max age), applies config overrides, and
+// starts the background sweeper.
+func newOpenAIWSIngressContextPool(cfg *config.Config) *openAIWSIngressContextPool {
+	pool := &openAIWSIngressContextPool{
+		cfg: cfg,
+		dialer: newDefaultOpenAIWSClientDialer(),
+		idleTTL: 10 * time.Minute,
+		sweepInterval: 30 * time.Second,
+		upstreamMaxAge: openAIWSUpstreamConnMaxAge,
+		accounts: make(map[int64]*openAIWSIngressAccountPool),
+		stopCh: make(chan struct{}),
+	}
+	if cfg != nil && cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
+		pool.idleTTL = time.Duration(cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
+	}
+	if cfg != nil && cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds >= 0 {
+		// Config semantics: 0 disables max-age rotation.
+		pool.upstreamMaxAge = time.Duration(cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds) * time.Second
+	}
+	pool.startWorker()
+	return pool
+}
+
+// setClientDialerForTest swaps the upstream dialer; test-only hook. No-ops on
+// a nil pool or nil dialer.
+func (p *openAIWSIngressContextPool) setClientDialerForTest(dialer openAIWSClientDialer) {
+	if p == nil || dialer == nil {
+		return
+	}
+	p.dialer = dialer
+}
+
+// SnapshotTransportMetrics returns the dialer's transport metrics when the
+// dialer supports them, otherwise a zero snapshot.
+func (p *openAIWSIngressContextPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
+	if p == nil {
+		return OpenAIWSTransportMetricsSnapshot{}
+	}
+	if dialer, ok := p.dialer.(openAIWSTransportMetricsDialer); ok {
+		return dialer.SnapshotTransportMetrics()
+	}
+	return OpenAIWSTransportMetricsSnapshot{}
+}
+
+// maxConnsHardCap returns the configured per-account connection ceiling,
+// defaulting to 8 when unconfigured.
+func (p *openAIWSIngressContextPool) maxConnsHardCap() int {
+	if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 {
+		return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount
+	}
+	return 8
+}
+
+// effectiveContextCapacity is the account's context budget: its concurrency
+// clamped by maxConnsHardCap. Returns 0 (no capacity) for a nil account or
+// non-positive concurrency.
+func (p *openAIWSIngressContextPool) effectiveContextCapacity(account *Account) int {
+	if account == nil || account.Concurrency <= 0 {
+		return 0
+	}
+	capacity := account.Concurrency
+	hardCap := p.maxConnsHardCap()
+	if hardCap > 0 && capacity > hardCap {
+		return hardCap
+	}
+	return capacity
+}
+
+// Close stops the background sweeper, detaches every account pool, marks all
+// contexts broken, and closes their upstream connections. Upstreams are
+// collected while holding the locks and closed afterwards so Close never
+// performs network I/O under a mutex. Safe to call multiple times.
+func (p *openAIWSIngressContextPool) Close() {
+	if p == nil {
+		return
+	}
+	p.closeOnce.Do(func() {
+		p.stopOnce.Do(func() {
+			close(p.stopCh)
+		})
+		p.workerWg.Wait()
+
+		var toClose []openAIWSClientConn
+		p.mu.Lock()
+		accountPools := make([]*openAIWSIngressAccountPool, 0, len(p.accounts))
+		for _, ap := range p.accounts {
+			if ap != nil {
+				accountPools = append(accountPools, ap)
+			}
+		}
+		p.accounts = make(map[int64]*openAIWSIngressAccountPool)
+		p.mu.Unlock()
+
+		for _, ap := range accountPools {
+			ap.mu.Lock()
+			for _, ctx := range ap.contexts {
+				if ctx == nil {
+					continue
+				}
+				ctx.mu.Lock()
+				if ctx.upstream != nil {
+					toClose = append(toClose, ctx.upstream)
+				}
+				ctx.upstream = nil
+				ctx.upstreamConnCreatedAt.Store(0)
+				ctx.broken = true
+				ctx.ownerID = ""
+				ctx.handshakeHeaders = nil
+				ctx.mu.Unlock()
+			}
+			ap.contexts = make(map[string]*openAIWSIngressContext)
+			ap.bySession = make(map[string]string)
+			ap.mu.Unlock()
+		}
+
+		for _, conn := range toClose {
+			if conn != nil {
+				_ = conn.Close()
+			}
+		}
+	})
+}
+
+// startWorker launches the background goroutine that periodically sweeps
+// expired idle contexts until stopCh is closed (see Close).
+func (p *openAIWSIngressContextPool) startWorker() {
+	if p == nil {
+		return
+	}
+	interval := p.sweepInterval
+	if interval <= 0 {
+		interval = 30 * time.Second
+	}
+	p.workerWg.Add(1)
+	go func() {
+		defer p.workerWg.Done()
+		ticker := time.NewTicker(interval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-p.stopCh:
+				return
+			case <-ticker.C:
+				p.sweepExpiredIdleContexts()
+			}
+		}
+	}()
+}
+
+// Acquire binds req.OwnerID to an account-scoped ingress context and returns
+// a lease for it. Scheduling is layered:
+//   - l0_exact: the session already maps to a context (reused directly, or
+//     after waiting for the current owner to release it);
+//   - l1_new_context: capacity allows creating a fresh context;
+//   - l2_migration: an idle context from another session is recycled, subject
+//     to the stickiness-derived migration policy.
+//
+// The outer loop retries after owner-release signals, bounded by
+// openAIWSIngressAcquireMaxWaitRetries and openAIWSIngressAcquireMaxQueueWait.
+func (p *openAIWSIngressContextPool) Acquire(
+	ctx context.Context,
+	req openAIWSIngressContextAcquireRequest,
+) (*openAIWSIngressContextLease, error) {
+	if p == nil {
+		return nil, errors.New("openai ws ingress context pool is nil")
+	}
+	if req.Account == nil || req.Account.ID <= 0 {
+		return nil, errors.New("invalid account in ingress context acquire request")
+	}
+	ownerID := strings.TrimSpace(req.OwnerID)
+	if ownerID == "" {
+		return nil, errors.New("owner id is empty")
+	}
+	if strings.TrimSpace(req.WSURL) == "" {
+		return nil, errors.New("ws url is empty")
+	}
+	capacity := p.effectiveContextCapacity(req.Account)
+	if capacity <= 0 {
+		return nil, errOpenAIWSConnQueueFull
+	}
+
+	sessionHash := strings.TrimSpace(req.SessionHash)
+	if sessionHash == "" {
+		// No session signal: degrade to a per-connection context so sessions
+		// cannot leak across connections.
+		sessionHash = "conn:" + ownerID
+	}
+	sessionKey := openAIWSIngressContextSessionKey(req.GroupID, sessionHash)
+	accountID := req.Account.ID
+
+	start := time.Now()
+	queueWait := time.Duration(0)
+	waitRetries := 0
+
+	p.mu.Lock()
+	ap := p.getOrCreateAccountPoolLocked(accountID)
+	ap.refs.Add(1)
+	p.mu.Unlock()
+	defer ap.refs.Add(-1)
+
+	// connPick = total elapsed minus time spent blocked waiting for owners.
+	calcConnPick := func() time.Duration {
+		connPick := time.Since(start) - queueWait
+		if connPick < 0 {
+			return 0
+		}
+		return connPick
+	}
+
+	for {
+		now := time.Now()
+		var (
+			selected *openAIWSIngressContext
+			reusedContext bool
+			newlyCreated bool
+			ownerAssigned bool
+			migrationUsed bool
+			scheduleLayer string
+			oldUpstream openAIWSClientConn
+			deferredClose []openAIWSClientConn
+		)
+
+		ap.mu.Lock()
+
+		stickiness := p.resolveStickinessLevelLocked(ap, sessionKey, req, now)
+		allowMigration, minMigrationScore := openAIWSIngressMigrationPolicyByStickiness(stickiness)
+
+		// L0: exact session match.
+		if existingID, ok := ap.bySession[sessionKey]; ok {
+			if existing := ap.contexts[existingID]; existing != nil {
+				existing.mu.Lock()
+				switch existing.ownerID {
+				case "":
+					// Idle: claim it, draining any stale release signal first.
+					if existing.releaseDone != nil {
+						select {
+						case <-existing.releaseDone:
+						default:
+						}
+					}
+					existing.ownerID = ownerID
+					ownerAssigned = true
+					existing.touchLease(now, p.idleTTL)
+					selected = existing
+					reusedContext = true
+					scheduleLayer = openAIWSIngressScheduleLayerExact
+				case ownerID:
+					// Re-entrant acquire by the same owner.
+					existing.touchLease(now, p.idleTTL)
+					selected = existing
+					reusedContext = true
+					scheduleLayer = openAIWSIngressScheduleLayerExact
+				default:
+					// The context is held by another owner: wait for its
+					// release and retry (loop instead of recursion).
+					blockedByOwner := existing.ownerID
+					if existing.releaseDone == nil {
+						existing.releaseDone = make(chan struct{}, 1)
+					}
+					releaseDone := existing.releaseDone
+					existing.mu.Unlock()
+					ap.mu.Unlock()
+					closeOpenAIWSClientConns(deferredClose)
+
+					logOpenAIWSModeInfo(
+						"ctx_pool_owner_wait_begin account_id=%d ctx_id=%s owner_id=%s blocked_by=%s retry=%d",
+						accountID, existing.id, ownerID,
+						truncateOpenAIWSLogValue(blockedByOwner, openAIWSIDValueMaxLen),
+						waitRetries,
+					)
+					waitStart := time.Now()
+					select {
+					case <-releaseDone:
+						queueWait += time.Since(waitStart)
+						waitRetries++
+						if waitRetries >= openAIWSIngressAcquireMaxWaitRetries || queueWait >= openAIWSIngressAcquireMaxQueueWait {
+							logOpenAIWSModeInfo(
+								"ctx_pool_owner_wait_exhausted account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d",
+								accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(),
+							)
+							return nil, errOpenAIWSIngressContextBusy
+						}
+						continue
+					// NOTE(review): caller cancellation is reported as "busy"
+					// rather than ctx.Err() — confirm callers never need to
+					// distinguish cancellation from contention.
+					case <-ctx.Done():
+						queueWait += time.Since(waitStart)
+						logOpenAIWSModeInfo(
+							"ctx_pool_owner_wait_canceled account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d",
+							accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(),
+						)
+						return nil, errOpenAIWSIngressContextBusy
+					}
+				}
+				existing.mu.Unlock()
+			}
+		}
+
+		if selected == nil {
+			// L1/L2: no exact match — create, grow, or migrate.
+			dynCap := p.effectiveDynamicCapacity(ap, capacity)
+			if len(ap.contexts) >= dynCap {
+				deferredClose = append(deferredClose, p.evictExpiredIdleLocked(ap, now)...)
+			}
+			if len(ap.contexts) >= dynCap {
+				if dynCap < capacity {
+					// Dynamic growth: hard cap not reached yet — grow by 1 and
+					// fall through to create a new context.
+					ap.dynamicCap.Add(1)
+				} else if !allowMigration {
+					ap.mu.Unlock()
+					closeOpenAIWSClientConns(deferredClose)
+					logOpenAIWSModeInfo(
+						"ctx_pool_full_no_migration account_id=%d capacity=%d stickiness=%s",
+						accountID, capacity, stickiness,
+					)
+					return nil, errOpenAIWSConnQueueFull
+				} else {
+					recycle := p.pickMigrationCandidateLocked(ap, minMigrationScore, now)
+					if recycle == nil {
+						ap.mu.Unlock()
+						closeOpenAIWSClientConns(deferredClose)
+						logOpenAIWSModeInfo(
+							"ctx_pool_no_migration_candidate account_id=%d capacity=%d min_score=%.1f",
+							accountID, capacity, minMigrationScore,
+						)
+						return nil, errOpenAIWSConnQueueFull
+					}
+					recycle.mu.Lock()
+					oldSessionKey := recycle.sessionKey
+					oldUpstream = recycle.upstream
+					recycle.sessionHash = sessionHash
+					recycle.sessionKey = sessionKey
+					recycle.groupID = req.GroupID
+					if recycle.releaseDone != nil {
+						select {
+						case <-recycle.releaseDone:
+						default:
+						}
+					}
+					recycle.ownerID = ownerID
+					recycle.touchLease(now, p.idleTTL)
+					// Close the old upstream when recycling for a new session
+					// to avoid cross-session contamination.
+					recycle.upstream = nil
+					recycle.upstreamConnID = ""
+					recycle.upstreamConnCreatedAt.Store(0)
+					recycle.handshakeHeaders = nil
+					recycle.broken = false
+					recycle.migrationCount++
+					recycle.lastMigrationAt = now
+					recycle.mu.Unlock()
+
+					if oldSessionKey != "" {
+						if mapped, ok := ap.bySession[oldSessionKey]; ok && mapped == recycle.id {
+							delete(ap.bySession, oldSessionKey)
+						}
+					}
+					ap.bySession[sessionKey] = recycle.id
+					selected = recycle
+					reusedContext = true
+					migrationUsed = true
+					scheduleLayer = openAIWSIngressScheduleLayerMigration
+					ap.mu.Unlock()
+					closeOpenAIWSClientConns(deferredClose)
+					if oldUpstream != nil {
+						_ = oldUpstream.Close()
+					}
+					reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req)
+					if ensureErr != nil {
+						p.releaseContext(selected, ownerID)
+						return nil, ensureErr
+					}
+					logOpenAIWSModeInfo(
+						"ctx_pool_migration account_id=%d ctx_id=%s old_session=%s new_session=%s group_id=%d session_hash=%s migration_count=%d",
+						accountID, selected.id, truncateOpenAIWSLogValue(oldSessionKey, openAIWSIDValueMaxLen),
+						truncateOpenAIWSLogValue(sessionKey, openAIWSIDValueMaxLen), selected.groupID,
+						truncateOpenAIWSLogValue(selected.sessionHash, openAIWSIDValueMaxLen), selected.migrationCount,
+					)
+					return &openAIWSIngressContextLease{
+						pool: p,
+						context: selected,
+						ownerID: ownerID,
+						queueWait: queueWait,
+						connPick: calcConnPick(),
+						reused: reusedContext && reusedConn,
+						scheduleLayer: scheduleLayer,
+						stickiness: stickiness,
+						migrationUsed: migrationUsed,
+					}, nil
+				}
+			}
+
+			// L1: create a fresh context bound to this session and owner.
+			ctxID := fmt.Sprintf("ctx_%d_%d", accountID, p.seq.Add(1))
+			created := &openAIWSIngressContext{
+				id: ctxID,
+				groupID: req.GroupID,
+				accountID: accountID,
+				sessionHash: sessionHash,
+				sessionKey: sessionKey,
+				ownerID: ownerID,
+			}
+			created.touchLease(now, p.idleTTL)
+			ap.contexts[ctxID] = created
+			ap.bySession[sessionKey] = ctxID
+			selected = created
+			newlyCreated = true
+			ownerAssigned = true
+			scheduleLayer = openAIWSIngressScheduleLayerNew
+		}
+		ap.mu.Unlock()
+		closeOpenAIWSClientConns(deferredClose)
+
+		reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req)
+		if ensureErr != nil {
+			// Roll back: a brand-new context is removed outright; a claimed
+			// existing one is released back to idle.
+			if newlyCreated {
+				ap.mu.Lock()
+				delete(ap.contexts, selected.id)
+				if mapped, ok := ap.bySession[sessionKey]; ok && mapped == selected.id {
+					delete(ap.bySession, sessionKey)
+				}
+				ap.mu.Unlock()
+			} else if ownerAssigned {
+				p.releaseContext(selected, ownerID)
+			}
+			return nil, ensureErr
+		}
+
+		return &openAIWSIngressContextLease{
+			pool: p,
+			context: selected,
+			ownerID: ownerID,
+			queueWait: queueWait,
+			connPick: calcConnPick(),
+			reused: reusedContext && reusedConn,
+			scheduleLayer: scheduleLayer,
+			stickiness: stickiness,
+			migrationUsed: migrationUsed,
+		}, nil
+	}
+}
+
+// resolveStickinessLevelLocked derives the stickiness level for this acquire.
+// StrictAffinity forces strong; previous_response_id implies strong; a
+// disabled store or a multi-turn request implies balanced; otherwise weak.
+// The base level is then downgraded when the session's bound context is
+// broken or failed recently (within 2 minutes), or upgraded when it is
+// healthy and was used within the last 20 seconds. Caller must hold ap.mu.
+func (p *openAIWSIngressContextPool) resolveStickinessLevelLocked(
+	ap *openAIWSIngressAccountPool,
+	sessionKey string,
+	req openAIWSIngressContextAcquireRequest,
+	now time.Time,
+) string {
+	if req.StrictAffinity {
+		return openAIWSIngressStickinessStrong
+	}
+
+	level := openAIWSIngressStickinessWeak
+	switch {
+	case req.HasPreviousResponseID:
+		level = openAIWSIngressStickinessStrong
+	case req.StoreDisabled || req.Turn > 1:
+		level = openAIWSIngressStickinessBalanced
+	}
+
+	if ap == nil {
+		return level
+	}
+	ctxID, ok := ap.bySession[sessionKey]
+	if !ok {
+		return level
+	}
+	existing := ap.contexts[ctxID]
+	if existing == nil {
+		return level
+	}
+
+	existing.mu.Lock()
+	broken := existing.broken
+	failureStreak := existing.failureStreak
+	lastFailureAt := existing.lastFailureAt
+	lastUsedAt := existing.lastUsedAt()
+	existing.mu.Unlock()
+
+	recentFailure := failureStreak > 0 && !lastFailureAt.IsZero() && now.Sub(lastFailureAt) <= 2*time.Minute
+	if broken || recentFailure {
+		return openAIWSIngressStickinessDowngrade(level)
+	}
+	if failureStreak == 0 && !lastUsedAt.IsZero() && now.Sub(lastUsedAt) <= 20*time.Second {
+		return openAIWSIngressStickinessUpgrade(level)
+	}
+	return level
+}
+
+// openAIWSIngressMigrationPolicyByStickiness maps a stickiness level to a
+// migration policy: whether recycling another session's idle context is
+// allowed at all, and the minimum candidate score required when it is.
+// Strong stickiness disallows migration entirely (the score is unused there).
+func openAIWSIngressMigrationPolicyByStickiness(stickiness string) (bool, float64) {
+	switch stickiness {
+	case openAIWSIngressStickinessStrong:
+		return false, 80 // was 85; lowered to allow migration away from degraded accounts
+	case openAIWSIngressStickinessBalanced:
+		return true, 65 // was 68; lowered to allow more aggressive migration to healthier accounts
+	default:
+		return true, 40 // was 45; lowered for weak stickiness
+	}
+}
+
+// openAIWSIngressStickinessDowngrade steps a stickiness level one notch
+// weaker: strong -> balanced -> weak (weak and unknown map to weak).
+func openAIWSIngressStickinessDowngrade(level string) string {
+	switch level {
+	case openAIWSIngressStickinessStrong:
+		return openAIWSIngressStickinessBalanced
+	case openAIWSIngressStickinessBalanced:
+		return openAIWSIngressStickinessWeak
+	default:
+		return openAIWSIngressStickinessWeak
+	}
+}
+
+// openAIWSIngressStickinessUpgrade steps a stickiness level one notch
+// stronger: weak -> balanced -> strong (strong and unknown map to strong).
+func openAIWSIngressStickinessUpgrade(level string) string {
+	switch level {
+	case openAIWSIngressStickinessWeak:
+		return openAIWSIngressStickinessBalanced
+	case openAIWSIngressStickinessBalanced:
+		return openAIWSIngressStickinessStrong
+	default:
+		return openAIWSIngressStickinessStrong
+	}
+}
+
+// pickMigrationCandidateLocked returns the idle context with the highest
+// migration score >= minScore, breaking ties by least-recently-used. Returns
+// nil when no context qualifies. Caller must hold ap.mu.
+func (p *openAIWSIngressContextPool) pickMigrationCandidateLocked(
+	ap *openAIWSIngressAccountPool,
+	minScore float64,
+	now time.Time,
+) *openAIWSIngressContext {
+	if ap == nil {
+		return nil
+	}
+	var (
+		selected *openAIWSIngressContext
+		selectedScore float64
+		selectedAt time.Time
+		hasSelected bool
+	)
+
+	for _, ctx := range ap.contexts {
+		if ctx == nil {
+			continue
+		}
+		score, lastUsed, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, p.schedulerStats)
+		if !ok || score < minScore {
+			continue
+		}
+		if !hasSelected || score > selectedScore || (score == selectedScore && lastUsed.Before(selectedAt)) {
+			selected = ctx
+			selectedScore = score
+			selectedAt = lastUsed
+			hasSelected = true
+		}
+	}
+	return selected
+}
+
+// scoreOpenAIWSIngressMigrationCandidate rates how suitable c is for being
+// recycled to serve a new session (higher is better, starting from 100).
+// Owned contexts are never candidates (ok=false). Also returns the context's
+// lastUsedAt so the caller can break score ties by least-recently-used.
+func scoreOpenAIWSIngressMigrationCandidate(c *openAIWSIngressContext, now time.Time, stats *openAIAccountRuntimeStats) (float64, time.Time, bool) {
+	if c == nil {
+		return 0, time.Time{}, false
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if strings.TrimSpace(c.ownerID) != "" {
+		return 0, time.Time{}, false
+	}
+
+	// Health penalties: broken links, failure streaks, recent failures, and
+	// recent/frequent migrations all make a context less attractive.
+	score := 100.0
+	if c.broken {
+		score -= 30
+	}
+	if c.failureStreak > 0 {
+		score -= float64(minInt(c.failureStreak*12, 40))
+	}
+	if !c.lastFailureAt.IsZero() && now.Sub(c.lastFailureAt) <= 2*time.Minute {
+		score -= 18
+	}
+	if !c.lastMigrationAt.IsZero() && now.Sub(c.lastMigrationAt) <= time.Minute {
+		score -= 10
+	}
+	if c.migrationCount > 0 {
+		score -= float64(minInt(c.migrationCount*4, 20))
+	}
+
+	// Idle-time shaping: very recently used contexts are penalized, long-idle
+	// ones rewarded, with a linear ramp in between.
+	lastUsedAt := c.lastUsedAt()
+	idleDuration := now.Sub(lastUsedAt)
+	switch {
+	case idleDuration <= 15*time.Second:
+		score -= 15
+	case idleDuration >= 3*time.Minute:
+		score += 16
+	default:
+		score += idleDuration.Seconds() / 12.0
+	}
+
+	// Load-aware factors: penalize contexts bound to accounts that the
+	// scheduler has flagged as degraded or circuit-open. When stats is nil
+	// (e.g. during tests or before scheduler init), these adjustments are
+	// silently skipped so existing behaviour is preserved.
+	if stats != nil && c.accountID > 0 {
+		errorRate, _, _ := stats.snapshot(c.accountID)
+		// errorRate is in [0,1]; a fully-erroring account subtracts up to 30
+		// points, making it significantly harder for a migration to land on
+		// an unhealthy account.
+		score -= errorRate * 30
+
+		// Circuit-open accounts receive a harsh penalty (-50) that in
+		// practice drops the score below any reasonable minimum threshold,
+		// effectively blocking migration to that account.
+		if stats.isCircuitOpen(c.accountID) {
+			score -= 50
+		}
+	}
+
+	return score, lastUsedAt, true
+}
+
+// minInt returns the smaller of a and b.
+// NOTE(review): the built-in min (Go 1.21+) could replace this helper if the
+// module's Go version permits.
+func minInt(a, b int) int {
+	if a <= b {
+		return a
+	}
+	return b
+}
+
+// closeOpenAIWSClientConns closes every non-nil connection in conns,
+// ignoring close errors (best-effort cleanup).
+func closeOpenAIWSClientConns(conns []openAIWSClientConn) {
+	for _, conn := range conns {
+		if conn != nil {
+			_ = conn.Close()
+		}
+	}
+}
+
+// ensureContextUpstream guarantees c has a live upstream WebSocket, dialing a
+// new one when the current one is missing, broken, or past upstreamMaxAge.
+// It returns (true, nil) when a healthy existing connection was kept and
+// (false, nil) after a fresh dial. Concurrent callers coordinate through
+// c.dialing/c.dialDone so at most one dial is in flight per context.
+func (p *openAIWSIngressContextPool) ensureContextUpstream(
+	ctx context.Context,
+	c *openAIWSIngressContext,
+	req openAIWSIngressContextAcquireRequest,
+) (bool, error) {
+	if p == nil || c == nil {
+		return false, errOpenAIWSConnClosed
+	}
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	for {
+		c.mu.Lock()
+		if c.upstream != nil && !c.broken {
+			now := time.Now()
+			connAge := c.upstreamConnAge(now)
+			if p.upstreamMaxAge > 0 && connAge > 0 && connAge >= p.upstreamMaxAge {
+				// Proactive rotation: close the aged connection without
+				// setting broken or bumping failureStreak.
+				oldUpstream, oldConnID := c.upstream, c.upstreamConnID
+				c.upstream = nil
+				c.upstreamConnID = ""
+				c.upstreamConnCreatedAt.Store(0)
+				c.handshakeHeaders = nil
+				c.prewarmed.Store(false)
+				c.mu.Unlock()
+				_ = oldUpstream.Close()
+				logOpenAIWSModeInfo(
+					"ctx_pool_upstream_max_age_rotate account_id=%d ctx_id=%s conn_id=%s conn_age_min=%.1f max_age_min=%.1f",
+					c.accountID, c.id, oldConnID,
+					connAge.Minutes(), p.upstreamMaxAge.Minutes(),
+				)
+				continue // loop back and take the dialing path
+			}
+			c.touchLease(now, p.idleTTL)
+			c.mu.Unlock()
+			return true, nil
+		}
+		if c.dialing {
+			// Another caller is dialing: wait on its dialDone, then re-check.
+			dialDone := c.dialDone
+			c.mu.Unlock()
+			if dialDone == nil {
+				// NOTE(review): dialing with a nil dialDone retries
+				// immediately — a tight spin is possible if this state ever
+				// persists; confirm it is only transient.
+				if err := ctx.Err(); err != nil {
+					return false, err
+				}
+				continue
+			}
+			select {
+			case <-dialDone:
+				continue
+			case <-ctx.Done():
+				return false, ctx.Err()
+			}
+		}
+		// Become the dialer: clear any stale upstream state first.
+		oldUpstream := c.upstream
+		c.upstream = nil
+		c.upstreamConnCreatedAt.Store(0)
+		c.handshakeHeaders = nil
+		c.upstreamConnID = ""
+		c.prewarmed.Store(false)
+		c.broken = false
+		c.dialing = true
+		dialDone := make(chan struct{})
+		c.dialDone = dialDone
+		c.mu.Unlock()
+
+		if oldUpstream != nil {
+			_ = oldUpstream.Close()
+		}
+
+		dialer := p.dialer
+		if dialer == nil {
+			c.mu.Lock()
+			c.broken = true
+			c.failureStreak++
+			c.lastFailureAt = time.Now()
+			c.dialing = false
+			if c.dialDone == dialDone {
+				c.dialDone = nil
+			}
+			close(dialDone)
+			c.mu.Unlock()
+			return false, errors.New("openai ws ingress context dialer is nil")
+		}
+		conn, statusCode, handshakeHeaders, err := dialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
+		if err != nil {
+			// Normalize to *openAIWSDialError so callers can inspect the
+			// handshake status and headers uniformly.
+			wrappedErr := err
+			var dialErr *openAIWSDialError
+			if !errors.As(err, &dialErr) {
+				wrappedErr = &openAIWSDialError{
+					StatusCode: statusCode,
+					ResponseHeaders: cloneHeader(handshakeHeaders),
+					Err: err,
+				}
+			}
+			c.mu.Lock()
+			c.broken = true
+			c.failureStreak++
+			c.lastFailureAt = time.Now()
+			c.dialing = false
+			if c.dialDone == dialDone {
+				c.dialDone = nil
+			}
+			close(dialDone)
+			failureStreak := c.failureStreak
+			c.mu.Unlock()
+			logOpenAIWSModeInfo(
+				"ctx_pool_dial_fail account_id=%d ctx_id=%s status_code=%d failure_streak=%d cause=%s",
+				c.accountID, c.id, statusCode, failureStreak, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+			)
+			return false, wrappedErr
+		}
+
+		// Dial succeeded: install the connection and reset failure state.
+		c.mu.Lock()
+		now := time.Now()
+		c.upstream = conn
+		c.upstreamConnID = fmt.Sprintf("ctxws_%d_%d", c.accountID, p.seq.Add(1))
+		c.upstreamConnCreatedAt.Store(now.UnixNano())
+		c.handshakeHeaders = cloneHeader(handshakeHeaders)
+		c.prewarmed.Store(false)
+		c.touchLease(now, p.idleTTL)
+		c.broken = false
+		c.failureStreak = 0
+		c.lastFailureAt = time.Time{}
+		c.dialing = false
+		if c.dialDone == dialDone {
+			c.dialDone = nil
+		}
+		close(dialDone)
+		connID := c.upstreamConnID
+		c.mu.Unlock()
+		logOpenAIWSModeInfo(
+			"ctx_pool_dial_ok account_id=%d ctx_id=%s conn_id=%s",
+			c.accountID, c.id, connID,
+		)
+		return false, nil
+	}
+}
+
+// yieldContext releases ownership but keeps the upstream connection alive for
+// reuse, then schedules a delayed ping to surface a dead connection early.
+func (p *openAIWSIngressContextPool) yieldContext(c *openAIWSIngressContext, ownerID string) {
+	p.releaseContextWithPolicy(c, ownerID, false)
+	// Delayed ping after yield: detect dead connections before the next Acquire.
+	p.scheduleDelayedPing(c, openAIWSIngressDelayedPingAfterYield)
+}
+
+// releaseContext releases ownership and closes the upstream connection.
+func (p *openAIWSIngressContextPool) releaseContext(c *openAIWSIngressContext, ownerID string) {
+	p.releaseContextWithPolicy(c, ownerID, true)
+}
+
+// releaseContextWithPolicy releases ownerID's hold on c (a no-op if the owner
+// does not match), optionally destroying the upstream connection, refreshing
+// the idle lease, and waking exactly one waiter. The connection is closed
+// outside the mutex.
+func (p *openAIWSIngressContextPool) releaseContextWithPolicy(
+	c *openAIWSIngressContext,
+	ownerID string,
+	closeUpstream bool,
+) {
+	if p == nil || c == nil {
+		return
+	}
+	var upstream openAIWSClientConn
+	c.mu.Lock()
+	if c.ownerID == ownerID {
+		if closeUpstream {
+			// Session ended or the link is broken: destroy the upstream so it
+			// cannot pollute later requests.
+			upstream = c.upstream
+			c.upstream = nil
+			c.upstreamConnCreatedAt.Store(0)
+			c.handshakeHeaders = nil
+			c.upstreamConnID = ""
+			c.prewarmed.Store(false)
+		}
+		c.ownerID = ""
+		// Wake exactly one waiting Acquire; a close() broadcast would cause a
+		// thundering herd.
+		if c.releaseDone != nil {
+			select {
+			case c.releaseDone <- struct{}{}:
+			default:
+			}
+		}
+		now := time.Now()
+		c.setLastUsedAt(now)
+		c.setExpiresAt(now.Add(p.idleTTL))
+		c.broken = false
+	}
+	c.mu.Unlock()
+	if upstream != nil {
+		_ = upstream.Close()
+	}
+}
+
+// markContextBroken flags c as broken, discards its upstream connection, and
+// bumps the failure streak. Ownership is NOT released here.
+func (p *openAIWSIngressContextPool) markContextBroken(c *openAIWSIngressContext) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	upstream := c.upstream
+	c.upstream = nil
+	c.upstreamConnCreatedAt.Store(0)
+	c.handshakeHeaders = nil
+	c.upstreamConnID = ""
+	c.prewarmed.Store(false)
+	c.broken = true
+	c.failureStreak++
+	c.lastFailureAt = time.Now()
+	// Note: no releaseDone signal here. ownerID is still held, so a woken
+	// waiter would only find the owner unchanged and re-block, wasting the
+	// signal. The actual release happens in releaseContext.
+	failureStreak := c.failureStreak
+	c.mu.Unlock()
+	logOpenAIWSModeInfo(
+		"ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d",
+		c.accountID, c.id, failureStreak,
+	)
+	if upstream != nil {
+		_ = upstream.Close()
+	}
+}
+
+// markContextBrokenIfConnMatch marks the context broken only when the
+// connection generation (connID) still matches. Background pings execute
+// outside the lock, during which the connection may have been rebuilt; a
+// changed connID means the old connection is already gone, so we skip the
+// mark rather than kill the replacement.
+func (p *openAIWSIngressContextPool) markContextBrokenIfConnMatch(c *openAIWSIngressContext, expectedConnID string) {
+	if c == nil {
+		return
+	}
+	c.mu.Lock()
+	actualConnID := c.upstreamConnID
+	if actualConnID != expectedConnID {
+		// The connection was rebuilt (connID changed); skip marking.
+		c.mu.Unlock()
+		logOpenAIWSModeInfo(
+			"ctx_pool_bg_ping_skip_stale account_id=%d ctx_id=%s expected_conn_id=%s actual_conn_id=%s",
+			c.accountID, c.id, expectedConnID, actualConnID,
+		)
+		return
+	}
+	ownerID := c.ownerID
+	dialing := c.dialing
+	if ownerID != "" || dialing {
+		// The context may have been re-acquired or started dialing while the
+		// ping ran; the background probe must not kill an active connection.
+		c.mu.Unlock()
+		logOpenAIWSModeInfo(
+			"ctx_pool_bg_ping_skip_busy account_id=%d ctx_id=%s conn_id=%s owner_id=%s dialing=%v",
+			c.accountID,
+			c.id,
+			actualConnID,
+			truncateOpenAIWSLogValue(ownerID, openAIWSIDValueMaxLen),
+			dialing,
+		)
+		return
+	}
+	upstream := c.upstream
+	c.upstream = nil
+	c.upstreamConnCreatedAt.Store(0)
+	c.handshakeHeaders = nil
+	c.upstreamConnID = ""
+	c.prewarmed.Store(false)
+	c.broken = true
+	c.failureStreak++
+	c.lastFailureAt = time.Now()
+	failureStreak := c.failureStreak
+	c.mu.Unlock()
+	logOpenAIWSModeInfo(
+		"ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d",
+		c.accountID, c.id, failureStreak,
+	)
+	if upstream != nil {
+		_ = upstream.Close()
+	}
+}
+
+// getOrCreateAccountPoolLocked returns the per-account pool, creating it with
+// dynamicCap=1 on first use. Caller must hold p.mu.
+func (p *openAIWSIngressContextPool) getOrCreateAccountPoolLocked(accountID int64) *openAIWSIngressAccountPool {
+	if ap, ok := p.accounts[accountID]; ok && ap != nil {
+		return ap
+	}
+	ap := &openAIWSIngressAccountPool{
+		contexts: make(map[string]*openAIWSIngressContext),
+		bySession: make(map[string]string),
+	}
+	ap.dynamicCap.Store(1)
+	p.accounts[accountID] = ap
+	return ap
+}
+
+// effectiveDynamicCapacity returns min(dynamicCap, hardCap), repairing a
+// below-1 dynamicCap back to 1. dynamicCap starts at 1, grows on demand, and
+// shrinks when idle; hardCap derives from the account concurrency and the
+// global ceiling.
+func (p *openAIWSIngressContextPool) effectiveDynamicCapacity(ap *openAIWSIngressAccountPool, hardCap int) int {
+	if ap == nil || hardCap <= 0 {
+		return hardCap
+	}
+	dynCap := int(ap.dynamicCap.Load())
+	if dynCap < 1 {
+		dynCap = 1
+		ap.dynamicCap.Store(1)
+	}
+	if dynCap > hardCap {
+		return hardCap
+	}
+	return dynCap
+}
+
+// evictExpiredIdleLocked removes idle contexts whose lease has expired,
+// unmapping them from contexts/bySession and returning their upstreams so the
+// caller can close them outside the lock. Caller must hold ap.mu.
+func (p *openAIWSIngressContextPool) evictExpiredIdleLocked(
+	ap *openAIWSIngressAccountPool,
+	now time.Time,
+) []openAIWSClientConn {
+	if ap == nil {
+		return nil
+	}
+	var toClose []openAIWSClientConn
+	for id, ctx := range ap.contexts {
+		if ctx == nil {
+			delete(ap.contexts, id)
+			continue
+		}
+		ctx.mu.Lock()
+		expiresAt := ctx.expiresAt()
+		expired := ctx.ownerID == "" && !expiresAt.IsZero() && now.After(expiresAt)
+		upstream := ctx.upstream
+		if expired {
+			ctx.upstream = nil
+			ctx.upstreamConnCreatedAt.Store(0)
+			ctx.handshakeHeaders = nil
+			ctx.upstreamConnID = ""
+		}
+		ctx.mu.Unlock()
+		if !expired {
+			continue
+		}
+		delete(ap.contexts, id)
+		if mappedID, ok := ap.bySession[ctx.sessionKey]; ok && mappedID == id {
+			delete(ap.bySession, ctx.sessionKey)
+		}
+		if upstream != nil {
+			toClose = append(toClose, upstream)
+		}
+	}
+	return toClose
+}
+
+// pickOldestIdleContextLocked returns the least-recently-used unowned
+// context, or nil when every context is owned. Caller must hold ap.mu.
+func (p *openAIWSIngressContextPool) pickOldestIdleContextLocked(ap *openAIWSIngressAccountPool) *openAIWSIngressContext {
+	if ap == nil {
+		return nil
+	}
+	var (
+		selected *openAIWSIngressContext
+		selectedAt time.Time
+	)
+	for _, ctx := range ap.contexts {
+		if ctx == nil {
+			continue
+		}
+		ctx.mu.Lock()
+		idle := strings.TrimSpace(ctx.ownerID) == ""
+		lastUsed := ctx.lastUsedAt()
+		ctx.mu.Unlock()
+		if !idle {
+			continue
+		}
+		if selected == nil || lastUsed.Before(selectedAt) {
+			selected = ctx
+			selectedAt = lastUsed
+		}
+	}
+	return selected
+}
+
+// closeAgedIdleUpstreamsLocked collects the upstream connections of idle
+// contexts that exceeded upstreamMaxAge. Only the upstream is cleared — the
+// context slot stays (no delete from contexts/bySession), and neither broken
+// nor failureStreak is touched. Returned connections must be closed by the
+// caller outside the lock. Caller must hold ap.mu.
+func (p *openAIWSIngressContextPool) closeAgedIdleUpstreamsLocked(
+	ap *openAIWSIngressAccountPool,
+	now time.Time,
+) []openAIWSClientConn {
+	if ap == nil || p.upstreamMaxAge <= 0 {
+		return nil
+	}
+	var toClose []openAIWSClientConn
+	for _, ctx := range ap.contexts {
+		if ctx == nil {
+			continue
+		}
+		ctx.mu.Lock()
+		idle := ctx.ownerID == ""
+		hasUpstream := ctx.upstream != nil
+		connAge := ctx.upstreamConnAge(now)
+		aged := connAge > 0 && connAge >= p.upstreamMaxAge
+		if idle && hasUpstream && aged {
+			toClose = append(toClose, ctx.upstream)
+			ctx.upstream = nil
+			ctx.upstreamConnCreatedAt.Store(0)
+			ctx.upstreamConnID = ""
+			ctx.handshakeHeaders = nil
+			ctx.prewarmed.Store(false)
+		}
+		ctx.mu.Unlock()
+	}
+	return toClose
+}
+
+// pingContextUpstream sends a ping probe on an idle context's upstream
+// connection. On failure the context is marked broken so the next Acquire
+// takes the rebuild path. The caller needs no locks.
+//
+// A connID generation guard protects the rebuild race: the connection may be
+// replaced while the ping runs, so broken is only set when the connID is
+// unchanged — a stale ping failure cannot kill a new connection.
+func (p *openAIWSIngressContextPool) pingContextUpstream(c *openAIWSIngressContext) {
+	if p == nil || c == nil {
+		return
+	}
+	c.mu.Lock()
+	idle := c.ownerID == ""
+	hasUpstream := c.upstream != nil
+	broken := c.broken
+	dialing := c.dialing
+	upstream := c.upstream
+	connID := c.upstreamConnID // snapshot the connection generation
+	c.mu.Unlock()
+	if !idle || !hasUpstream || broken || dialing || upstream == nil {
+		return
+	}
+
+	pingCtx, cancel := context.WithTimeout(context.Background(), openAIWSIngressPingTimeout)
+	defer cancel()
+	if err := upstream.Ping(pingCtx); err != nil {
+		p.markContextBrokenIfConnMatch(c, connID)
+		logOpenAIWSModeInfo(
+			"ctx_pool_bg_ping_fail account_id=%d ctx_id=%s conn_id=%s cause=%s",
+			c.accountID, c.id, connID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+		)
+	}
+}
+
+// pingIdleUpstreams probes every idle context in the account pool that still
+// holds an upstream connection. Candidates are gathered while holding ap.mu
+// and pinged only after the lock is dropped, so probing never blocks other
+// pool operations.
+func (p *openAIWSIngressContextPool) pingIdleUpstreams(ap *openAIWSIngressAccountPool) {
+ if p == nil || ap == nil {
+ return
+ }
+ ap.mu.Lock()
+ candidates := make([]*openAIWSIngressContext, 0, len(ap.contexts))
+ for _, c := range ap.contexts {
+ if c == nil {
+ continue
+ }
+ c.mu.Lock()
+ skip := c.ownerID != "" || c.upstream == nil || c.broken || c.dialing
+ c.mu.Unlock()
+ if skip {
+ continue
+ }
+ candidates = append(candidates, c)
+ }
+ ap.mu.Unlock()
+
+ for _, c := range candidates {
+ p.pingContextUpstream(c)
+ }
+}
+
+// scheduleDelayedPing sends a Ping probe to the context after a delay
+// following a yield. Deduplication goes through pendingPingTimer: at most one
+// delayed ping exists per context at a time, and consecutive yields only
+// Reset the timer instead of spawning new goroutines, avoiding goroutine
+// buildup under high concurrency.
+//
+// NOTE(review): if Reset races with a timer that has already fired (the probe
+// goroutine is past the select and about to clear the slot), the rescheduled
+// ping can be dropped — presumably acceptable for a best-effort probe, but
+// confirm.
+func (p *openAIWSIngressContextPool) scheduleDelayedPing(c *openAIWSIngressContext, delay time.Duration) {
+ if p == nil || c == nil || delay <= 0 {
+ return
+ }
+ c.mu.Lock()
+ if c.pendingPingTimer != nil {
+ // A ping is already pending; just push the timer window out.
+ c.pendingPingTimer.Reset(delay)
+ c.mu.Unlock()
+ return
+ }
+ timer := time.NewTimer(delay)
+ c.pendingPingTimer = timer
+ c.mu.Unlock()
+
+ go func() {
+ select {
+ case <-p.stopCh:
+ timer.Stop()
+ case <-timer.C:
+ p.pingContextUpstream(c)
+ }
+ // Clear the dedup slot only if it still refers to our timer; a newer
+ // scheduleDelayedPing may have installed a different one.
+ c.mu.Lock()
+ if c.pendingPingTimer == timer {
+ c.pendingPingTimer = nil
+ }
+ c.mu.Unlock()
+ }()
+}
+
+// sweepExpiredIdleContexts is the background sweep pass: it evicts expired
+// idle contexts, closes over-age idle upstream connections, shrinks each
+// account pool's dynamic capacity, pings surviving idle upstreams, and
+// finally removes account pools that are empty and unreferenced.
+// Locking order is p.mu (snapshot) -> per-account ap.mu -> p.mu (removal);
+// the two p.mu critical sections are deliberately separate so per-account
+// work never runs under the pool-wide lock.
+func (p *openAIWSIngressContextPool) sweepExpiredIdleContexts() {
+ if p == nil {
+ return
+ }
+ now := time.Now()
+
+ type accountSnapshot struct {
+ accountID int64
+ ap *openAIWSIngressAccountPool
+ }
+
+ // Snapshot the account map under p.mu; nil entries are dropped eagerly.
+ snapshots := make([]accountSnapshot, 0, len(p.accounts))
+ p.mu.Lock()
+ for accountID, ap := range p.accounts {
+ if ap == nil {
+ delete(p.accounts, accountID)
+ continue
+ }
+ snapshots = append(snapshots, accountSnapshot{accountID: accountID, ap: ap})
+ }
+ p.mu.Unlock()
+
+ removable := make([]accountSnapshot, 0)
+ for _, item := range snapshots {
+ ap := item.ap
+ ap.mu.Lock()
+ toClose := p.evictExpiredIdleLocked(ap, now)
+ agedClose := p.closeAgedIdleUpstreamsLocked(ap, now)
+ empty := len(ap.contexts) == 0
+ // Dynamic shrink: clamp dynamicCap down to max(1, current context count).
+ shrinkTarget := int32(len(ap.contexts))
+ if shrinkTarget < 1 {
+ shrinkTarget = 1
+ }
+ if ap.dynamicCap.Load() > shrinkTarget {
+ ap.dynamicCap.Store(shrinkTarget)
+ }
+ ap.mu.Unlock()
+ // Close evicted connections outside ap.mu.
+ closeOpenAIWSClientConns(toClose)
+ closeOpenAIWSClientConns(agedClose)
+ // Background ping probe: ping the remaining idle connections so dead
+ // ones are culled promptly.
+ if !empty {
+ p.pingIdleUpstreams(ap)
+ }
+ if empty && ap.refs.Load() == 0 {
+ removable = append(removable, item)
+ }
+ }
+
+ if len(removable) == 0 {
+ return
+ }
+
+ p.mu.Lock()
+ for _, item := range removable {
+ // Re-check under p.mu: the account entry may have been replaced or
+ // re-referenced since the unlocked phase above.
+ existing := p.accounts[item.accountID]
+ if existing != item.ap || existing == nil {
+ continue
+ }
+ if existing.refs.Load() != 0 {
+ continue
+ }
+ delete(p.accounts, item.accountID)
+ }
+ p.mu.Unlock()
+}
+
+// openAIWSIngressContextSessionKey builds the "<groupID>:<sessionHash>" index
+// key used by the per-account session map. An empty (or all-whitespace)
+// session hash yields an empty key, meaning "no session affinity".
+func openAIWSIngressContextSessionKey(groupID int64, sessionHash string) string {
+ trimmed := strings.TrimSpace(sessionHash)
+ if trimmed == "" {
+ return ""
+ }
+ return strconv.FormatInt(groupID, 10) + ":" + trimmed
+}
+
+// ConnID reports the identifier of the context's current upstream connection,
+// or "" when the lease has no context attached.
+func (l *openAIWSIngressContextLease) ConnID() string {
+ if l == nil || l.context == nil {
+ return ""
+ }
+ c := l.context
+ c.mu.Lock()
+ id := strings.TrimSpace(c.upstreamConnID)
+ c.mu.Unlock()
+ return id
+}
+
+// QueueWaitDuration reports how long the acquire call waited in queue.
+func (l *openAIWSIngressContextLease) QueueWaitDuration() time.Duration {
+ if l != nil {
+ return l.queueWait
+ }
+ return 0
+}
+
+// ConnPickDuration reports how long connection selection took during acquire.
+func (l *openAIWSIngressContextLease) ConnPickDuration() time.Duration {
+ if l != nil {
+ return l.connPick
+ }
+ return 0
+}
+
+// Reused reports whether an existing upstream connection was reused for this
+// lease instead of dialing a new one.
+func (l *openAIWSIngressContextLease) Reused() bool {
+ if l != nil {
+ return l.reused
+ }
+ return false
+}
+
+// ScheduleLayer reports which scheduling layer produced this lease
+// (e.g. exact match vs migration), trimmed of surrounding whitespace.
+func (l *openAIWSIngressContextLease) ScheduleLayer() string {
+ if l != nil {
+ return strings.TrimSpace(l.scheduleLayer)
+ }
+ return ""
+}
+
+// StickinessLevel reports the session-stickiness level in effect for this
+// lease, trimmed of surrounding whitespace.
+func (l *openAIWSIngressContextLease) StickinessLevel() string {
+ if l != nil {
+ return strings.TrimSpace(l.stickiness)
+ }
+ return ""
+}
+
+// MigrationUsed reports whether this lease was satisfied by migrating a
+// context away from another session.
+func (l *openAIWSIngressContextLease) MigrationUsed() bool {
+ if l != nil {
+ return l.migrationUsed
+ }
+ return false
+}
+
+// HandshakeHeader returns the trimmed value of the named header captured
+// during the upstream WebSocket handshake, or "" when unavailable.
+func (l *openAIWSIngressContextLease) HandshakeHeader(name string) string {
+ if l == nil || l.context == nil {
+ return ""
+ }
+ c := l.context
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ headers := c.handshakeHeaders
+ if headers == nil {
+ return ""
+ }
+ return strings.TrimSpace(headers.Get(strings.TrimSpace(name)))
+}
+
+// IsPrewarmed reports whether the leased context's upstream connection has
+// already been prewarmed.
+func (l *openAIWSIngressContextLease) IsPrewarmed() bool {
+ if l != nil && l.context != nil {
+ return l.context.prewarmed.Load()
+ }
+ return false
+}
+
+// MarkPrewarmed flags the leased context's upstream connection as prewarmed.
+// No-op when the lease carries no context.
+func (l *openAIWSIngressContextLease) MarkPrewarmed() {
+ if l != nil && l.context != nil {
+ l.context.prewarmed.Store(true)
+ }
+}
+
+// activeConn returns the upstream connection backing this lease, or
+// errOpenAIWSConnClosed when the lease is nil, released, no longer owns its
+// context, or the context has no upstream.
+//
+// The cached connection is a lock-free fast path; it is populated under
+// context.mu after re-validating ownership and cleared on Release/Yield/
+// MarkBroken and on I/O errors (invalidateCachedConnOnIOError).
+// NOTE(review): the fast path trusts the cache without re-checking ownerID;
+// this relies on every ownership transition clearing the cache first —
+// confirm all transitions do so.
+func (l *openAIWSIngressContextLease) activeConn() (openAIWSClientConn, error) {
+ if l == nil || l.context == nil || l.released.Load() {
+ return nil, errOpenAIWSConnClosed
+ }
+ // Fast path: return cached conn without mutex if lease is still valid.
+ l.cachedConnMu.RLock()
+ cc := l.cachedConn
+ l.cachedConnMu.RUnlock()
+ if cc != nil {
+ return cc, nil
+ }
+ // Slow path: acquire mutex, validate ownership, cache result.
+ l.context.mu.Lock()
+ defer l.context.mu.Unlock()
+ if l.context.ownerID != l.ownerID {
+ return nil, errOpenAIWSConnClosed
+ }
+ if l.context.upstream == nil {
+ return nil, errOpenAIWSConnClosed
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = l.context.upstream
+ l.cachedConnMu.Unlock()
+ return l.context.upstream, nil
+}
+
+// invalidateCachedConnOnIOError drops the cached upstream connection after an
+// I/O failure and, when the error indicates a client disconnect, marks the
+// underlying context broken so the pool rebuilds it on the next acquire.
+func (l *openAIWSIngressContextLease) invalidateCachedConnOnIOError(err error) {
+ if l == nil || err == nil {
+ return
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ if l.pool == nil || l.context == nil {
+ return
+ }
+ if isOpenAIWSClientDisconnectError(err) {
+ l.pool.markContextBroken(l.context)
+ }
+}
+
+// WriteJSONWithContextTimeout sends value as JSON over the leased upstream
+// connection, bounding the write by timeout when timeout > 0. A nil ctx
+// falls back to context.Background(). On I/O failure the cached connection
+// is invalidated; on success the lease idle deadline is refreshed.
+func (l *openAIWSIngressContextLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if timeout > 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, timeout)
+ defer cancel()
+ }
+ if writeErr := conn.WriteJSON(ctx, value); writeErr != nil {
+ l.invalidateCachedConnOnIOError(writeErr)
+ return writeErr
+ }
+ l.context.maybeTouchLease(l.pool.idleTTL)
+ return nil
+}
+
+// ReadMessageWithContextTimeout reads one message from the leased upstream
+// connection, bounding the read by timeout when timeout > 0. A nil ctx falls
+// back to context.Background(). On I/O failure the cached connection is
+// invalidated; on success the lease idle deadline is refreshed.
+func (l *openAIWSIngressContextLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
+ conn, err := l.activeConn()
+ if err != nil {
+ return nil, err
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if timeout > 0 {
+ var cancel context.CancelFunc
+ ctx, cancel = context.WithTimeout(ctx, timeout)
+ defer cancel()
+ }
+ payload, readErr := conn.ReadMessage(ctx)
+ if readErr != nil {
+ l.invalidateCachedConnOnIOError(readErr)
+ return nil, readErr
+ }
+ l.context.maybeTouchLease(l.pool.idleTTL)
+ return payload, nil
+}
+
+// PingWithTimeout probes the leased upstream connection. A non-positive
+// timeout falls back to the default health-check timeout. On failure the
+// cached connection is invalidated; on success the lease idle deadline is
+// refreshed.
+func (l *openAIWSIngressContextLease) PingWithTimeout(timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ if timeout <= 0 {
+ timeout = openAIWSConnHealthCheckTO
+ }
+ pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
+ defer cancel()
+ if pingErr := conn.Ping(pingCtx); pingErr != nil {
+ l.invalidateCachedConnOnIOError(pingErr)
+ return pingErr
+ }
+ l.context.maybeTouchLease(l.pool.idleTTL)
+ return nil
+}
+
+// MarkBroken flags the leased context as broken so the pool rebuilds its
+// upstream on the next acquire. No-op on a nil or already released lease.
+func (l *openAIWSIngressContextLease) MarkBroken() {
+ if l == nil || l.pool == nil || l.context == nil {
+ return
+ }
+ if l.released.Load() {
+ return
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ l.pool.markContextBroken(l.context)
+}
+
+// Release ends the session and returns the context to the pool. Idempotent:
+// only the first of Release/Yield takes effect.
+func (l *openAIWSIngressContextLease) Release() {
+ if l == nil || l.context == nil || l.pool == nil {
+ return
+ }
+ if !l.released.CompareAndSwap(false, true) {
+ return // already released or yielded
+ }
+ ctx := l.context
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ l.pool.releaseContext(ctx, l.ownerID)
+}
+
+// Yield releases ownership of the context back to the pool while keeping its
+// upstream connection alive for reuse. Idempotent: only the first of
+// Release/Yield takes effect.
+func (l *openAIWSIngressContextLease) Yield() {
+ if l == nil || l.context == nil || l.pool == nil {
+ return
+ }
+ if !l.released.CompareAndSwap(false, true) {
+ return // already released or yielded
+ }
+ ctx := l.context
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ l.pool.yieldContext(ctx, l.ownerID)
+}
diff --git a/backend/internal/service/openai_ws_ingress_context_pool_test.go b/backend/internal/service/openai_ws_ingress_context_pool_test.go
new file mode 100644
index 000000000..f5a6802c5
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_context_pool_test.go
@@ -0,0 +1,2496 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// Verifies that account concurrency is a hard cap: with Concurrency=1 a
+// second concurrent owner is rejected with errOpenAIWSConnQueueFull, and
+// after release the context migrates to the new session with a fresh dial
+// (2 dials total, avoiding cross-session reuse of the upstream).
+func TestOpenAIWSIngressContextPool_Acquire_HardCapacityEqualsConcurrency(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 801, Concurrency: 1}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 2,
+ SessionHash: "session_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 2,
+ SessionHash: "session_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "并发=1 时第二个并发 owner 不应获取到 context")
+
+ lease1.Release()
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 2,
+ SessionHash: "session_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, openAIWSIngressScheduleLayerMigration, lease2.ScheduleLayer())
+ require.Equal(t, openAIWSIngressStickinessWeak, lease2.StickinessLevel())
+ require.True(t, lease2.MigrationUsed())
+ lease2.Release()
+
+ require.Equal(t, 2, dialer.DialCount(), "会话回收复用 context 后应重建上游连接,避免跨会话污染")
+}
+
+// Verifies that MaxConnsPerAccount caps contexts even when the account's own
+// Concurrency is higher: with cap=2, the third concurrent acquire fails.
+func TestOpenAIWSIngressContextPool_Acquire_RespectsGlobalHardCap(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 802, Concurrency: 10}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_c",
+ OwnerID: "owner_c",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "账号并发高于全局硬上限时,context pool 仍应被硬上限约束")
+
+ lease1.Release()
+ lease2.Release()
+ require.Equal(t, 2, dialer.DialCount())
+}
+
+// Verifies account isolation: the same session hash acquired under two
+// different accounts must use two distinct contexts (two dials).
+func TestOpenAIWSIngressContextPool_Acquire_DoesNotCrossAccount(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ accountA := &Account{ID: 901, Concurrency: 1}
+ accountB := &Account{ID: 902, Concurrency: 1}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ leaseA, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: accountA,
+ GroupID: 5,
+ SessionHash: "same_session_hash",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, leaseA)
+ leaseA.Release()
+
+ leaseB, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: accountB,
+ GroupID: 5,
+ SessionHash: "same_session_hash",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, leaseB)
+ leaseB.Release()
+
+ require.Equal(t, 2, dialer.DialCount(), "相同 session_hash 在不同账号下必须使用不同 context,不允许跨账号复用")
+}
+
+// Verifies that strong stickiness (HasPreviousResponseID on a different
+// session) refuses to migrate an idle context belonging to another session:
+// the acquire fails with errOpenAIWSConnQueueFull instead.
+func TestOpenAIWSIngressContextPool_Acquire_StrongStickinessDisablesMigration(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1001, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 9,
+ SessionHash: "session_keep_strong_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 9,
+ SessionHash: "session_keep_strong_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "strong 粘连不应迁移其它 session 的 context")
+}
+
+// Verifies adaptive stickiness: after a MarkBroken on the first lease, the
+// next acquire for the same session still hits the exact layer but downgrades
+// stickiness from strong to balanced, and redials the upstream.
+func TestOpenAIWSIngressContextPool_Acquire_AdaptiveStickinessDowngradesAfterFailure(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1002, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 11,
+ SessionHash: "session_adaptive",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.MarkBroken()
+ lease1.Release()
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 11,
+ SessionHash: "session_adaptive",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, openAIWSIngressScheduleLayerExact, lease2.ScheduleLayer())
+ require.Equal(t, openAIWSIngressStickinessBalanced, lease2.StickinessLevel(), "失败后应从 strong 自适应降级到 balanced")
+ lease2.Release()
+ require.Equal(t, 2, dialer.DialCount(), "故障后重连同一 context 应重新建立上游连接")
+}
+
+// Verifies that a failed upstream ensure rolls back the owner: after a dial
+// failure the context must not stay busy, and a later owner with a working
+// dialer can acquire the same context again.
+func TestOpenAIWSIngressContextPool_Acquire_EnsureFailureReleasesOwner(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ initialDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(initialDialer)
+
+ account := &Account{ID: 1101, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 12,
+ SessionHash: "session_owner_release",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ failDialer := &openAIWSAlwaysFailDialer{}
+ pool.setClientDialerForTest(failDialer)
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 12,
+ SessionHash: "session_owner_release",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.Error(t, err)
+ require.NotErrorIs(t, err, errOpenAIWSIngressContextBusy, "ensure 上游失败后不应遗留 owner 导致 context 长时间 busy")
+
+ recoverDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(recoverDialer)
+
+ lease3, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 12,
+ SessionHash: "session_owner_release",
+ OwnerID: "owner_c",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err, "owner 回滚后应允许后续会话重新获取同一 context")
+ require.NotNil(t, lease3)
+ lease3.Release()
+ require.Equal(t, 1, failDialer.DialCount())
+ require.Equal(t, 1, recoverDialer.DialCount())
+}
+
+// Verifies that Release closes the upstream connection and that the next
+// acquire for the same session gets a new connection (different ConnID).
+func TestOpenAIWSIngressContextPool_Release_ClosesUpstreamAndForcesRedial(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstreamConn1 := &openAIWSCaptureConn{}
+ upstreamConn2 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ upstreamConn1,
+ upstreamConn2,
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1102, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 13,
+ SessionHash: "session_same",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ connID1 := lease1.ConnID()
+ require.NotEmpty(t, connID1)
+ lease1.Release()
+
+ upstreamConn1.mu.Lock()
+ closed1 := upstreamConn1.closed
+ upstreamConn1.mu.Unlock()
+ require.True(t, closed1, "客户端会话结束后应关闭对应上游连接,防止复用污染")
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 13,
+ SessionHash: "session_same",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ connID2 := lease2.ConnID()
+ require.NotEmpty(t, connID2)
+ require.NotEqual(t, connID1, connID2, "下一次会话必须重新建立上游连接")
+ lease2.Release()
+
+ upstreamConn2.mu.Lock()
+ closed2 := upstreamConn2.closed
+ upstreamConn2.mu.Unlock()
+ require.True(t, closed2)
+ require.Equal(t, 2, dialer.DialCount())
+}
+
+// Verifies that Yield frees the owner without closing the upstream: a second
+// acquire reuses the same connection with no redial, while a final Release
+// still closes it.
+func TestOpenAIWSIngressContextPool_Yield_ReleasesOwnerKeepsUpstream(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstreamConn := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{upstreamConn},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1103, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 14,
+ SessionHash: "session_yield",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ connID1 := lease1.ConnID()
+ require.NotEmpty(t, connID1)
+
+ lease1.Yield()
+ upstreamConn.mu.Lock()
+ closedAfterYield := upstreamConn.closed
+ upstreamConn.mu.Unlock()
+ require.False(t, closedAfterYield, "yield 只应释放 owner,不应关闭上游连接")
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 14,
+ SessionHash: "session_yield",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, connID1, lease2.ConnID(), "yield 后应复用同一上游连接")
+ require.Equal(t, 1, dialer.DialCount(), "yield 后重新获取不应触发重拨号")
+
+ lease2.Release()
+ upstreamConn.mu.Lock()
+ closedAfterRelease := upstreamConn.closed
+ upstreamConn.mu.Unlock()
+ require.True(t, closedAfterRelease, "release 仍需关闭上游连接")
+}
+
+// Verifies that evictExpiredIdleLocked removes an expired idle context, drops
+// its session index entry, and returns its upstream for closing.
+func TestOpenAIWSIngressContextPool_EvictExpiredIdleLocked_ClosesUpstream(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstreamConn := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ expiredCtx := &openAIWSIngressContext{
+ id: "ctx_expired_1",
+ groupID: 21,
+ accountID: 1201,
+ sessionHash: "session_expired",
+ sessionKey: openAIWSIngressContextSessionKey(21, "session_expired"),
+ upstream: upstreamConn,
+ upstreamConnID: "ctxws_1201_1",
+ handshakeHeaders: map[string][]string{"x-test": {"ok"}},
+ }
+ expiredCtx.setExpiresAt(time.Now().Add(-2 * time.Second))
+ ap.contexts[expiredCtx.id] = expiredCtx
+ ap.bySession[expiredCtx.sessionKey] = expiredCtx.id
+
+ ap.mu.Lock()
+ toClose := pool.evictExpiredIdleLocked(ap, time.Now())
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(toClose)
+
+ require.Empty(t, ap.contexts, "过期 idle context 应被清理")
+ require.Empty(t, ap.bySession, "过期 context 的 session 索引应同步清理")
+ upstreamConn.mu.Lock()
+ closed := upstreamConn.closed
+ upstreamConn.mu.Unlock()
+ require.True(t, closed, "清理过期 context 时应关闭残留上游连接,避免泄漏")
+}
+
+// Exercises the pure helpers: minInt, stickiness upgrade/downgrade ladders,
+// the per-stickiness migration policy thresholds, and the migration-candidate
+// scoring function (busy contexts excluded; older idle scores higher; recent
+// failures and frequent migrations lower the score).
+func TestOpenAIWSIngressContextPool_ScoreAndStickinessHelpers(t *testing.T) {
+ now := time.Now()
+
+ require.Equal(t, 1, minInt(1, 2))
+ require.Equal(t, 2, minInt(3, 2))
+
+ require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessStrong))
+ require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessBalanced))
+ require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade("unknown"))
+
+ require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessWeak))
+ require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessBalanced))
+ require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade("unknown"))
+
+ allowStrong, scoreStrong := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessStrong)
+ require.False(t, allowStrong)
+ require.Equal(t, 80.0, scoreStrong)
+ allowBalanced, scoreBalanced := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessBalanced)
+ require.True(t, allowBalanced)
+ require.Equal(t, 65.0, scoreBalanced)
+ allowWeak, scoreWeak := openAIWSIngressMigrationPolicyByStickiness("weak_or_unknown")
+ require.True(t, allowWeak)
+ require.Equal(t, 40.0, scoreWeak)
+
+ busyCtx := &openAIWSIngressContext{ownerID: "owner_busy"}
+ _, _, ok := scoreOpenAIWSIngressMigrationCandidate(busyCtx, now, nil)
+ require.False(t, ok, "owner 占用中的 context 不应作为迁移候选")
+
+ oldIdle := &openAIWSIngressContext{}
+ oldIdle.setLastUsedAt(now.Add(-5 * time.Minute))
+ recentIdle := &openAIWSIngressContext{}
+ recentIdle.setLastUsedAt(now.Add(-10 * time.Second))
+ scoreOld, _, okOld := scoreOpenAIWSIngressMigrationCandidate(oldIdle, now, nil)
+ scoreRecent, _, okRecent := scoreOpenAIWSIngressMigrationCandidate(recentIdle, now, nil)
+ require.True(t, okOld)
+ require.True(t, okRecent)
+ require.Greater(t, scoreOld, scoreRecent, "更久未使用的空闲 context 应该更易被迁移")
+
+ penalized := &openAIWSIngressContext{
+ broken: true,
+ failureStreak: 2,
+ lastFailureAt: now.Add(-30 * time.Second),
+ migrationCount: 2,
+ lastMigrationAt: now.Add(-10 * time.Second),
+ }
+ penalized.setLastUsedAt(now.Add(-5 * time.Minute))
+ scorePenalized, _, okPenalized := scoreOpenAIWSIngressMigrationCandidate(penalized, now, nil)
+ require.True(t, okPenalized)
+ require.Less(t, scorePenalized, scoreOld, "近期失败和频繁迁移应降低迁移分数")
+}
+
+// End-to-end check of the sweep helpers on a hand-built account pool:
+// pickOldestIdleContextLocked selects the oldest idle context, eviction
+// removes only expired idle contexts (busy ones survive), and
+// sweepExpiredIdleContexts drops emptied accounts while keeping non-empty
+// ones and closing swept upstreams.
+func TestOpenAIWSIngressContextPool_EvictPickAndSweep(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ now := time.Now()
+ expiredConn := &openAIWSCaptureConn{}
+ expiredCtx := &openAIWSIngressContext{
+ id: "ctx_expired",
+ sessionKey: "1:expired",
+ upstream: expiredConn,
+ upstreamConnID: "ctxws_expired",
+ }
+ expiredCtx.setLastUsedAt(now.Add(-20 * time.Minute))
+ expiredCtx.setExpiresAt(now.Add(-time.Minute))
+
+ idleNewCtx := &openAIWSIngressContext{
+ id: "ctx_idle_new",
+ sessionKey: "1:idle_new",
+ }
+ idleNewCtx.setLastUsedAt(now.Add(-30 * time.Second))
+ idleNewCtx.setExpiresAt(now.Add(time.Minute))
+
+ busyCtx := &openAIWSIngressContext{
+ id: "ctx_busy",
+ sessionKey: "1:busy",
+ ownerID: "active_owner",
+ }
+ busyCtx.setLastUsedAt(now.Add(-40 * time.Minute))
+ busyCtx.setExpiresAt(now.Add(-time.Minute))
+
+ ap := &openAIWSIngressAccountPool{
+ contexts: map[string]*openAIWSIngressContext{
+ "ctx_expired": expiredCtx,
+ "ctx_idle_new": idleNewCtx,
+ "ctx_busy": busyCtx,
+ },
+ bySession: map[string]string{
+ "1:expired": "ctx_expired",
+ "1:idle_new": "ctx_idle_new",
+ "1:busy": "ctx_busy",
+ },
+ }
+
+ ap.mu.Lock()
+ oldestIdle := pool.pickOldestIdleContextLocked(ap)
+ ap.mu.Unlock()
+ require.NotNil(t, oldestIdle)
+ require.Equal(t, "ctx_expired", oldestIdle.id, "应选择最旧的空闲 context")
+
+ ap.mu.Lock()
+ toClose := pool.evictExpiredIdleLocked(ap, now)
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(toClose)
+ require.NotContains(t, ap.contexts, "ctx_expired")
+ require.NotContains(t, ap.bySession, "1:expired")
+ require.Contains(t, ap.contexts, "ctx_idle_new", "未过期空闲 context 应保留")
+ require.Contains(t, ap.contexts, "ctx_busy", "有 owner 的 context 不应被 idle 过期清理")
+ expiredConn.mu.Lock()
+ expiredClosed := expiredConn.closed
+ expiredConn.mu.Unlock()
+ require.True(t, expiredClosed, "清理过期 idle context 时应关闭上游连接")
+
+ expiredInPoolConn := &openAIWSCaptureConn{}
+ pool.mu.Lock()
+ pool.accounts[5001] = ap
+ poolExpiredCtx := &openAIWSIngressContext{
+ id: "ctx_pool_expired",
+ sessionKey: "2:expired",
+ upstream: expiredInPoolConn,
+ }
+ poolExpiredCtx.setExpiresAt(now.Add(-time.Minute))
+ pool.accounts[5002] = &openAIWSIngressAccountPool{
+ contexts: map[string]*openAIWSIngressContext{
+ "ctx_pool_expired": poolExpiredCtx,
+ },
+ bySession: map[string]string{
+ "2:expired": "ctx_pool_expired",
+ },
+ }
+ pool.mu.Unlock()
+
+ pool.sweepExpiredIdleContexts()
+
+ pool.mu.Lock()
+ _, account2Exists := pool.accounts[5002]
+ account1 := pool.accounts[5001]
+ pool.mu.Unlock()
+ require.False(t, account2Exists, "sweep 后空账号应被移除")
+ require.NotNil(t, account1, "非空账号应保留")
+ expiredInPoolConn.mu.Lock()
+ sweptClosed := expiredInPoolConn.closed
+ expiredInPoolConn.mu.Unlock()
+ require.True(t, sweptClosed)
+}
+
+// Covers lease accessors on a nil receiver (all return zero values), the
+// accessor passthroughs on a populated lease, and the PingWithTimeout guard
+// chain: released lease, owner mismatch, upstream ping failure, and
+// idempotent double Release.
+func TestOpenAIWSIngressContextLease_AccessorsAndPingGuards(t *testing.T) {
+ var nilLease *openAIWSIngressContextLease
+ require.Equal(t, "", nilLease.ConnID())
+ require.Zero(t, nilLease.QueueWaitDuration())
+ require.Zero(t, nilLease.ConnPickDuration())
+ require.False(t, nilLease.Reused())
+ require.Equal(t, "", nilLease.ScheduleLayer())
+ require.Equal(t, "", nilLease.StickinessLevel())
+ require.False(t, nilLease.MigrationUsed())
+ require.Equal(t, "", nilLease.HandshakeHeader("x-test"))
+ require.ErrorIs(t, nilLease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed)
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_lease",
+ ownerID: "owner_ok",
+ upstream: &openAIWSFakeConn{},
+ handshakeHeaders: http.Header{"X-Test": []string{"ok"}},
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctxItem,
+ ownerID: "owner_ok",
+ queueWait: 5 * time.Millisecond,
+ connPick: 8 * time.Millisecond,
+ reused: true,
+ scheduleLayer: openAIWSIngressScheduleLayerExact,
+ stickiness: openAIWSIngressStickinessBalanced,
+ migrationUsed: true,
+ }
+
+ require.Equal(t, "ok", lease.HandshakeHeader("x-test"))
+ require.Equal(t, 5*time.Millisecond, lease.QueueWaitDuration())
+ require.Equal(t, 8*time.Millisecond, lease.ConnPickDuration())
+ require.True(t, lease.Reused())
+ require.Equal(t, openAIWSIngressScheduleLayerExact, lease.ScheduleLayer())
+ require.Equal(t, openAIWSIngressStickinessBalanced, lease.StickinessLevel())
+ require.True(t, lease.MigrationUsed())
+ require.NoError(t, lease.PingWithTimeout(0), "timeout=0 应回退默认 ping 超时")
+
+ lease.released.Store(true)
+ require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed)
+ lease.released.Store(false)
+
+ ctxItem.mu.Lock()
+ ctxItem.ownerID = "other_owner"
+ ctxItem.mu.Unlock()
+ lease.cachedConn = nil // clear cache to force re-validation (simulates migration)
+ require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed, "owner 不匹配时应拒绝访问")
+
+ ctxItem.mu.Lock()
+ ctxItem.ownerID = "owner_ok"
+ ctxItem.upstream = &openAIWSPingFailConn{}
+ ctxItem.mu.Unlock()
+ lease.cachedConn = nil // clear cache to pick up new upstream
+ require.Error(t, lease.PingWithTimeout(time.Millisecond), "上游 ping 失败应透传错误")
+
+ lease.Release()
+ lease.Release()
+ require.Equal(t, "", lease.context.ownerID, "重复 Release 应幂等且不会 panic")
+}
+
+// Exercises ensureContextUpstream branches: reuse of a healthy upstream, the
+// nil-dialer error, and a dial failure that wraps into openAIWSDialError
+// (status 503), marks the context broken, and bumps failureStreak.
+func TestOpenAIWSIngressContextPool_EnsureContextUpstreamBranches(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_ensure",
+ accountID: 1,
+ ownerID: "owner",
+ upstream: &openAIWSFakeConn{},
+ }
+
+ reused, err := pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{
+ WSURL: "ws://test",
+ })
+ require.NoError(t, err)
+ require.True(t, reused, "已有可用 upstream 时应直接复用")
+
+ pool.dialer = nil
+ ctxItem.mu.Lock()
+ ctxItem.broken = true
+ ctxItem.mu.Unlock()
+ _, err = pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{
+ WSURL: "ws://test",
+ })
+ require.ErrorContains(t, err, "dialer is nil")
+
+ failDialer := &openAIWSAlwaysFailDialer{}
+ pool.setClientDialerForTest(failDialer)
+ _, err = pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{
+ WSURL: "ws://test",
+ })
+ require.Error(t, err)
+ var dialErr *openAIWSDialError
+ require.ErrorAs(t, err, &dialErr, "dial 失败应包装为 openAIWSDialError")
+ require.Equal(t, 503, dialErr.StatusCode)
+ ctxItem.mu.Lock()
+ broken := ctxItem.broken
+ failureStreak := ctxItem.failureStreak
+ ctxItem.mu.Unlock()
+ require.True(t, broken)
+ require.GreaterOrEqual(t, failureStreak, 1, "dial 失败后应累计 failure_streak")
+}
+
+// TestOpenAIWSIngressContextPool_MarkBrokenDoesNotSignalWaiterBeforeRelease
+// verifies that markContextBroken closes the upstream and flags the context
+// without waking waiters; only releaseContext signals releaseDone and resets
+// the owner/broken state.
+func TestOpenAIWSIngressContextPool_MarkBrokenDoesNotSignalWaiterBeforeRelease(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	upstream := &openAIWSCaptureConn{}
+	ctxItem := &openAIWSIngressContext{
+		id:          "ctx_mark_broken",
+		ownerID:     "owner_broken",
+		releaseDone: make(chan struct{}, 1),
+		upstream:    upstream,
+	}
+
+	pool.markContextBroken(ctxItem)
+
+	// Marking broken must not signal waiters while the owner still holds the
+	// context; a premature wake-up would let two owners race for it.
+	select {
+	case <-ctxItem.releaseDone:
+		t.Fatal("markContextBroken should not wake waiters before owner is released")
+	default:
+	}
+
+	ctxItem.mu.Lock()
+	require.True(t, ctxItem.broken)
+	require.Equal(t, "owner_broken", ctxItem.ownerID)
+	require.Nil(t, ctxItem.upstream)
+	ctxItem.mu.Unlock()
+
+	upstream.mu.Lock()
+	require.True(t, upstream.closed, "mark broken should close current upstream connection")
+	upstream.mu.Unlock()
+
+	pool.releaseContext(ctxItem, "owner_broken")
+
+	// Release by the current owner should wake exactly one waiting acquire.
+	select {
+	case <-ctxItem.releaseDone:
+	case <-time.After(200 * time.Millisecond):
+		t.Fatal("releaseContext should signal one waiting acquire after owner is released")
+	}
+
+	// After release the context is ownerless and the broken flag is cleared.
+	ctxItem.mu.Lock()
+	require.Equal(t, "", ctxItem.ownerID)
+	require.False(t, ctxItem.broken)
+	ctxItem.mu.Unlock()
+}
+
+// openAIWSWriteDisconnectConn is a stub upstream connection whose every IO
+// operation fails with net.ErrClosed, emulating a disconnected peer so tests
+// can exercise the disconnect-style error handling path.
+type openAIWSWriteDisconnectConn struct{}
+
+func (c *openAIWSWriteDisconnectConn) WriteJSON(context.Context, any) error {
+	return net.ErrClosed
+}
+
+func (c *openAIWSWriteDisconnectConn) ReadMessage(context.Context) ([]byte, error) {
+	return nil, net.ErrClosed
+}
+
+func (c *openAIWSWriteDisconnectConn) Ping(context.Context) error {
+	return net.ErrClosed
+}
+
+func (c *openAIWSWriteDisconnectConn) Close() error {
+	return nil
+}
+
+// openAIWSWriteGenericFailConn is a stub upstream connection whose every IO
+// operation fails with a plain (non-disconnect) error, letting tests verify
+// that generic failures are treated differently from net.ErrClosed.
+type openAIWSWriteGenericFailConn struct{}
+
+func (c *openAIWSWriteGenericFailConn) WriteJSON(context.Context, any) error {
+	return errors.New("writer failed")
+}
+
+func (c *openAIWSWriteGenericFailConn) ReadMessage(context.Context) ([]byte, error) {
+	return nil, errors.New("reader failed")
+}
+
+func (c *openAIWSWriteGenericFailConn) Ping(context.Context) error {
+	return errors.New("ping failed")
+}
+
+func (c *openAIWSWriteGenericFailConn) Close() error {
+	return nil
+}
+
+// TestOpenAIWSIngressContextLease_IOErrorInvalidatesCachedConn verifies that
+// a disconnect-style write failure (net.ErrClosed) both invalidates the
+// lease's cached connection and marks the whole context broken, dropping its
+// upstream reference.
+func TestOpenAIWSIngressContextLease_IOErrorInvalidatesCachedConn(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	upstream := &openAIWSWriteDisconnectConn{}
+	ctxItem := &openAIWSIngressContext{
+		id:        "ctx_io_err",
+		accountID: 7,
+		ownerID:   "owner_io",
+		upstream:  upstream,
+	}
+	lease := &openAIWSIngressContextLease{
+		pool:    pool,
+		context: ctxItem,
+		ownerID: "owner_io",
+	}
+	lease.cachedConn = upstream
+
+	// The stub always fails with net.ErrClosed, which must surface to the
+	// caller unchanged.
+	err := lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
+	require.Error(t, err)
+	require.ErrorIs(t, err, net.ErrClosed)
+
+	lease.cachedConnMu.RLock()
+	cached := lease.cachedConn
+	lease.cachedConnMu.RUnlock()
+	require.Nil(t, cached, "write failure should invalidate cachedConn")
+
+	ctxItem.mu.Lock()
+	require.True(t, ctxItem.broken, "disconnect-style IO failure should mark context broken")
+	require.Nil(t, ctxItem.upstream, "broken context should drop upstream reference")
+	ctxItem.mu.Unlock()
+}
+
+// TestOpenAIWSIngressContextLease_GenericIOErrorKeepsContextButInvalidatesCache
+// verifies the softer failure path: a generic (non-disconnect) IO error drops
+// the lease's cached connection but leaves the context usable — it stays
+// non-broken and keeps its upstream reference.
+func TestOpenAIWSIngressContextLease_GenericIOErrorKeepsContextButInvalidatesCache(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	upstream := &openAIWSWriteGenericFailConn{}
+	ctxItem := &openAIWSIngressContext{
+		id:        "ctx_generic_err",
+		accountID: 8,
+		ownerID:   "owner_generic",
+		upstream:  upstream,
+	}
+	lease := &openAIWSIngressContextLease{
+		pool:    pool,
+		context: ctxItem,
+		ownerID: "owner_generic",
+	}
+	lease.cachedConn = upstream
+
+	err := lease.PingWithTimeout(time.Second)
+	require.Error(t, err)
+
+	lease.cachedConnMu.RLock()
+	cached := lease.cachedConn
+	lease.cachedConnMu.RUnlock()
+	require.Nil(t, cached, "generic IO failure should still invalidate cachedConn")
+
+	ctxItem.mu.Lock()
+	require.False(t, ctxItem.broken, "non-disconnect IO failure should not force-broken context")
+	require.Equal(t, upstream, ctxItem.upstream, "upstream should remain for non-disconnect errors")
+	ctxItem.mu.Unlock()
+}
+
+// TestOpenAIWSIngressContextPool_EnsureContextUpstream_SerializesConcurrentDial
+// checks that two concurrent Acquires for the same session/owner share a
+// single upstream dial instead of racing to open two connections.
+func TestOpenAIWSIngressContextPool_EnsureContextUpstream_SerializesConcurrentDial(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	// The blocking dialer parks every Dial until releaseDial is closed, so we
+	// can observe whether a second dial starts while the first is in flight.
+	releaseDial := make(chan struct{})
+	blockingDialer := &openAIWSBlockingDialer{
+		release:     releaseDial,
+		dialStarted: make(chan struct{}, 4),
+	}
+	pool.setClientDialerForTest(blockingDialer)
+
+	account := &Account{ID: 1301, Concurrency: 1}
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	type acquireResult struct {
+		lease *openAIWSIngressContextLease
+		err   error
+	}
+	resultCh := make(chan acquireResult, 2)
+	acquireOnce := func() {
+		lease, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+			Account:     account,
+			GroupID:     23,
+			SessionHash: "session_same_owner",
+			OwnerID:     "owner_same",
+			WSURL:       "ws://test-upstream",
+		})
+		resultCh <- acquireResult{lease: lease, err: err}
+	}
+
+	go acquireOnce()
+	select {
+	case <-blockingDialer.dialStarted:
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("首个 dial 未按预期启动")
+	}
+	go acquireOnce()
+
+	// With the first dial still blocked, the second acquire for the same
+	// context must not trigger another dial.
+	select {
+	case <-blockingDialer.dialStarted:
+		t.Fatal("同一 context 并发 acquire 不应触发第二次 dial")
+	case <-time.After(120 * time.Millisecond):
+	}
+
+	close(releaseDial)
+
+	// Both acquires should now complete successfully off the single dial.
+	results := make([]acquireResult, 0, 2)
+	for i := 0; i < 2; i++ {
+		select {
+		case result := <-resultCh:
+			require.NoError(t, result.err)
+			require.NotNil(t, result.lease)
+			results = append(results, result)
+		case <-time.After(2 * time.Second):
+			t.Fatal("等待并发 acquire 结果超时")
+		}
+	}
+
+	for _, result := range results {
+		result.lease.Release()
+	}
+	require.Equal(t, 1, blockingDialer.DialCount(), "同一 context 并发获取应只发生一次上游拨号")
+}
+
+// TestOpenAIWSIngressContextPool_EnsureContextUpstream_WaiterTimeoutDoesNotReleaseOwner
+// verifies that a second acquire timing out while the first is mid-dial does
+// not tear down the eventual owner's connection: the first acquire must still
+// complete and be able to write.
+func TestOpenAIWSIngressContextPool_EnsureContextUpstream_WaiterTimeoutDoesNotReleaseOwner(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	releaseDial := make(chan struct{})
+	blockingDialer := &openAIWSBlockingDialer{
+		release:     releaseDial,
+		dialStarted: make(chan struct{}, 4),
+	}
+	pool.setClientDialerForTest(blockingDialer)
+
+	account := &Account{ID: 1302, Concurrency: 1}
+	baseReq := openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     24,
+		SessionHash: "session_waiter_timeout",
+		OwnerID:     "owner_same",
+		WSURL:       "ws://test-upstream",
+	}
+
+	// First acquire gets a generous deadline and blocks inside the dialer.
+	longCtx, longCancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer longCancel()
+	type acquireResult struct {
+		lease *openAIWSIngressContextLease
+		err   error
+	}
+	firstResultCh := make(chan acquireResult, 1)
+	go func() {
+		lease, err := pool.Acquire(longCtx, baseReq)
+		firstResultCh <- acquireResult{lease: lease, err: err}
+	}()
+
+	select {
+	case <-blockingDialer.dialStarted:
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("首个 dial 未按预期启动")
+	}
+
+	// Second acquire uses a deadline shorter than the in-flight dial, so it
+	// must fail with context.DeadlineExceeded while waiting.
+	shortCtx, shortCancel := context.WithTimeout(context.Background(), 60*time.Millisecond)
+	defer shortCancel()
+	_, waiterErr := pool.Acquire(shortCtx, baseReq)
+	require.ErrorIs(t, waiterErr, context.DeadlineExceeded, "等待中的 acquire 超时应返回 context deadline exceeded")
+
+	close(releaseDial)
+
+	// The first acquire should succeed and its connection must still be
+	// writable — the waiter's timeout must not have released it.
+	select {
+	case first := <-firstResultCh:
+		require.NoError(t, first.err)
+		require.NotNil(t, first.lease)
+		require.NoError(t, first.lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "ping"}, time.Second), "等待方超时不应释放已建连 owner")
+		first.lease.Release()
+	case <-time.After(2 * time.Second):
+		t.Fatal("等待首个 acquire 结果超时")
+	}
+
+	require.Equal(t, 1, blockingDialer.DialCount())
+}
+
+// TestOpenAIWSIngressContextPool_Acquire_QueueWaitDurationRecorded verifies
+// that a lease granted after waiting for a busy context reports the time it
+// spent queued via QueueWaitDuration.
+func TestOpenAIWSIngressContextPool_Acquire_QueueWaitDurationRecorded(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{
+			&openAIWSCaptureConn{},
+			&openAIWSCaptureConn{},
+		},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 1303, Concurrency: 1}
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	// owner_a takes the only slot (Concurrency: 1).
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     25,
+		SessionHash: "session_queue_wait",
+		OwnerID:     "owner_a",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease1)
+
+	type acquireResult struct {
+		lease *openAIWSIngressContextLease
+		err   error
+	}
+	// owner_b queues behind owner_a for the same session.
+	waiterCh := make(chan acquireResult, 1)
+	go func() {
+		lease, acquireErr := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+			Account:     account,
+			GroupID:     25,
+			SessionHash: "session_queue_wait",
+			OwnerID:     "owner_b",
+			WSURL:       "ws://test-upstream",
+		})
+		waiterCh <- acquireResult{lease: lease, err: acquireErr}
+	}()
+
+	// Hold the slot for ~120ms so the waiter accrues measurable queue time.
+	time.Sleep(120 * time.Millisecond)
+	lease1.Release()
+
+	select {
+	case waiter := <-waiterCh:
+		require.NoError(t, waiter.err)
+		require.NotNil(t, waiter.lease)
+		require.GreaterOrEqual(t, waiter.lease.QueueWaitDuration(), 100*time.Millisecond)
+		waiter.lease.Release()
+	case <-time.After(2 * time.Second):
+		t.Fatal("等待第二个 acquire 结果超时")
+	}
+}
+
+// openAIWSBlockingDialer is a test dialer that counts dial attempts and parks
+// each Dial until the release channel fires, letting tests control dial
+// timing and observe concurrent-dial behavior.
+type openAIWSBlockingDialer struct {
+	mu          sync.Mutex
+	release     <-chan struct{} // when non-nil, Dial blocks until this fires (or ctx is done)
+	dialStarted chan struct{}   // receives one best-effort token per Dial invocation
+	dialCount   int             // guarded by mu
+}
+
+// Dial records the attempt, signals dialStarted without blocking, then parks
+// until the release channel fires (or ctx is cancelled) before handing back a
+// capture connection.
+func (d *openAIWSBlockingDialer) Dial(
+	ctx context.Context,
+	wsURL string,
+	headers http.Header,
+	proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+	_, _, _ = wsURL, headers, proxyURL
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	d.mu.Lock()
+	d.dialCount++
+	d.mu.Unlock()
+
+	// Best-effort notification: never block if the channel is already full.
+	select {
+	case d.dialStarted <- struct{}{}:
+	default:
+	}
+
+	if d.release == nil {
+		return &openAIWSCaptureConn{}, 0, nil, nil
+	}
+	select {
+	case <-d.release:
+		return &openAIWSCaptureConn{}, 0, nil, nil
+	case <-ctx.Done():
+		return nil, 0, nil, ctx.Err()
+	}
+}
+
+// DialCount reports how many times Dial has been invoked so far.
+func (d *openAIWSBlockingDialer) DialCount() int {
+	d.mu.Lock()
+	n := d.dialCount
+	d.mu.Unlock()
+	return n
+}
+
+// ---------------------------------------------------------------------------
+// Load-aware migration scoring tests
+// ---------------------------------------------------------------------------
+
+// A context on an account with a high error rate should score lower for
+// migration than the same context scored without scheduler stats; the
+// deduction is expected to be roughly errorRate * 30.
+func TestScoreOpenAIWSIngressMigrationCandidate_HighErrorRatePenalty(t *testing.T) {
+	now := time.Now()
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(9001)
+
+	// Interleave three failures with one success, six times over. This drives
+	// the error rate up while consecutive failures never reach the default
+	// circuit-breaker threshold of 5.
+	for i := 0; i < 24; i++ {
+		stats.report(accountID, i%4 == 3, nil, "", 0)
+	}
+	require.False(t, stats.isCircuitOpen(accountID), "circuit breaker should remain closed for this test")
+
+	candidate := &openAIWSIngressContext{accountID: accountID}
+	candidate.setLastUsedAt(now.Add(-5 * time.Minute))
+
+	withStats, _, okStats := scoreOpenAIWSIngressMigrationCandidate(candidate, now, stats)
+	require.True(t, okStats)
+
+	// Baseline: same candidate scored with nil stats.
+	baseline, _, okBaseline := scoreOpenAIWSIngressMigrationCandidate(candidate, now, nil)
+	require.True(t, okBaseline)
+
+	require.Less(t, withStats, baseline,
+		"a context on a high-error-rate account should receive a lower migration score")
+
+	// With the circuit closed, the only load-aware deduction is errorRate*30.
+	errorRate, _, _ := stats.snapshot(accountID)
+	require.InDelta(t, errorRate*30, baseline-withStats, 1.0,
+		"penalty should be approximately errorRate * 30")
+}
+
+// A circuit-open account carries a flat -50 deduction on top of errorRate*30,
+// which should push its migration score below even the weakest policy
+// threshold (40).
+func TestScoreOpenAIWSIngressMigrationCandidate_CircuitOpenPenalty(t *testing.T) {
+	now := time.Now()
+	stats := newOpenAIAccountRuntimeStats()
+	accountID := int64(9002)
+
+	// Exceed the default consecutive-failure threshold (5) to open the breaker.
+	failures := defaultCircuitBreakerFailThreshold + 1
+	for i := 0; i < failures; i++ {
+		stats.report(accountID, false, nil, "", 0)
+	}
+	require.True(t, stats.isCircuitOpen(accountID), "circuit breaker should be open after many failures")
+
+	candidate := &openAIWSIngressContext{accountID: accountID}
+	candidate.setLastUsedAt(now.Add(-5 * time.Minute))
+
+	withOpenCircuit, _, ok := scoreOpenAIWSIngressMigrationCandidate(candidate, now, stats)
+	require.True(t, ok)
+
+	// Baseline: same candidate scored with nil stats.
+	baseline, _, okBase := scoreOpenAIWSIngressMigrationCandidate(candidate, now, nil)
+	require.True(t, okBase)
+
+	// Circuit-open deducts 50, plus errorRate*30 on top of that.
+	require.Less(t, withOpenCircuit, baseline-45,
+		"a context on a circuit-open account should have a very low migration score")
+
+	// The combined deduction lands below the weakest migration minimum (40).
+	_, weakMin := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessWeak)
+	require.Less(t, withOpenCircuit, weakMin,
+		"circuit-open accounts should score below even the weakest migration threshold")
+}
+
+// Stats with no recorded data for an account must behave exactly like nil
+// stats: the load-aware path adds zero penalty.
+func TestScoreOpenAIWSIngressMigrationCandidate_NilStatsFallback(t *testing.T) {
+	now := time.Now()
+
+	candidate := &openAIWSIngressContext{accountID: 9003}
+	candidate.setLastUsedAt(now.Add(-5 * time.Minute))
+
+	baseline, _, okNil := scoreOpenAIWSIngressMigrationCandidate(candidate, now, nil)
+	require.True(t, okNil)
+
+	// Fresh stats hold nothing for this account, so snapshot yields zeros.
+	emptyStats := newOpenAIAccountRuntimeStats()
+	withEmpty, _, okEmpty := scoreOpenAIWSIngressMigrationCandidate(candidate, now, emptyStats)
+	require.True(t, okEmpty)
+
+	require.InDelta(t, baseline, withEmpty, 0.001,
+		"when scheduler stats have no data for the account, score should match nil-stats baseline")
+}
+
+// A nil context is never a migration candidate: ok=false with a zero score.
+func TestScoreOpenAIWSIngressMigrationCandidate_NilContext(t *testing.T) {
+	score, _, ok := scoreOpenAIWSIngressMigrationCandidate(nil, time.Now(), nil)
+	require.False(t, ok)
+	require.Equal(t, 0.0, score)
+}
+
+// Exercises the three idle-duration branches of the scoring function:
+// idle ≤15s → -15 penalty, 15s..3min → idleSeconds/12 bonus, ≥3min → +16.
+func TestScoreOpenAIWSIngressMigrationCandidate_IdleDurationBranches(t *testing.T) {
+	now := time.Now()
+
+	// scoreIdle builds a fresh context last used `idle` ago and scores it.
+	scoreIdle := func(idle time.Duration) float64 {
+		c := &openAIWSIngressContext{}
+		c.setLastUsedAt(now.Add(-idle))
+		score, _, ok := scoreOpenAIWSIngressMigrationCandidate(c, now, nil)
+		require.True(t, ok)
+		return score
+	}
+
+	scoreRecent := scoreIdle(5 * time.Second)
+	require.InDelta(t, 100.0-15.0, scoreRecent, 0.5, "very recently used should get -15 penalty")
+
+	scoreMedium := scoreIdle(90 * time.Second)
+	// 90s idle → bonus of 90/12 = 7.5.
+	require.InDelta(t, 100.0+90.0/12.0, scoreMedium, 0.5, "medium idle should get proportional bonus")
+
+	scoreLong := scoreIdle(5 * time.Minute)
+	require.InDelta(t, 100.0+16.0, scoreLong, 0.5, "long idle should get +16 bonus")
+
+	// Scores must grow monotonically with idle time: long > medium > recent.
+	require.Greater(t, scoreLong, scoreMedium, "longer idle should score higher than medium")
+	require.Greater(t, scoreMedium, scoreRecent, "medium idle should score higher than very recent")
+}
+
+// Verifies broken/failure deductions: broken -30, failure streak at 12 per
+// failure capped at -40, a failure within the last two minutes -18, and
+// failures older than that cost nothing.
+func TestScoreOpenAIWSIngressMigrationCandidate_BrokenAndFailures(t *testing.T) {
+	now := time.Now()
+
+	// scoreOf gives every candidate the same 5-minute idle, then scores it.
+	scoreOf := func(c *openAIWSIngressContext) float64 {
+		c.setLastUsedAt(now.Add(-5 * time.Minute))
+		score, _, ok := scoreOpenAIWSIngressMigrationCandidate(c, now, nil)
+		require.True(t, ok)
+		return score
+	}
+
+	scoreClean := scoreOf(&openAIWSIngressContext{})
+
+	// Broken flag: flat -30.
+	scoreBroken := scoreOf(&openAIWSIngressContext{broken: true})
+	require.InDelta(t, 30.0, scoreClean-scoreBroken, 0.5, "broken should subtract 30")
+
+	// Streak of 5 would cost 5*12=60, but the deduction caps at 40.
+	scoreHighFail := scoreOf(&openAIWSIngressContext{failureStreak: 5})
+	require.InDelta(t, 40.0, scoreClean-scoreHighFail, 0.5, "failure streak penalty should cap at 40")
+
+	// A failure inside the 2-minute window costs 18.
+	scoreRecentFail := scoreOf(&openAIWSIngressContext{lastFailureAt: now.Add(-30 * time.Second)})
+	require.InDelta(t, 18.0, scoreClean-scoreRecentFail, 0.5, "recent failure should subtract 18")
+
+	// Failures older than 2 minutes are forgiven entirely.
+	scoreOldFail := scoreOf(&openAIWSIngressContext{lastFailureAt: now.Add(-5 * time.Minute)})
+	require.InDelta(t, scoreClean, scoreOldFail, 0.5, "old failure should have no penalty")
+}
+
+// Verifies migration-history deductions: a migration within the last minute
+// costs 10, and the per-migration deduction (4 each) caps at 20.
+func TestScoreOpenAIWSIngressMigrationCandidate_MigrationPenalties(t *testing.T) {
+	now := time.Now()
+
+	cleanCtx := &openAIWSIngressContext{}
+	cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+	scoreClean, _, _ := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil)
+
+	// scoreOf gives every candidate the same 5-minute idle, then scores it.
+	scoreOf := func(c *openAIWSIngressContext) float64 {
+		c.setLastUsedAt(now.Add(-5 * time.Minute))
+		score, _, ok := scoreOpenAIWSIngressMigrationCandidate(c, now, nil)
+		require.True(t, ok)
+		return score
+	}
+
+	scoreRecentMig := scoreOf(&openAIWSIngressContext{lastMigrationAt: now.Add(-30 * time.Second)})
+	require.InDelta(t, 10.0, scoreClean-scoreRecentMig, 0.5, "recent migration should subtract 10")
+
+	// Six migrations would cost 6*4=24, but the cap keeps it at 20.
+	scoreHighMig := scoreOf(&openAIWSIngressContext{migrationCount: 6})
+	require.InDelta(t, 20.0, scoreClean-scoreHighMig, 0.5, "migration count penalty should cap at 20")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_CombinedPenalties(t *testing.T) {
+ now := time.Now()
+ // All penalties combined: broken(-30) + failStreak 1*12(-12) + recentFail(-18) + recentMig(-10) + migCount 1*4(-4) + recentIdle(-15)
+ worstCtx := &openAIWSIngressContext{
+ broken: true,
+ failureStreak: 1,
+ lastFailureAt: now.Add(-30 * time.Second),
+ migrationCount: 1,
+ lastMigrationAt: now.Add(-30 * time.Second),
+ }
+ worstCtx.setLastUsedAt(now.Add(-5 * time.Second))
+ score, _, ok := scoreOpenAIWSIngressMigrationCandidate(worstCtx, now, nil)
+ require.True(t, ok)
+ expected := 100.0 - 30.0 - 12.0 - 18.0 - 10.0 - 4.0 - 15.0 // = 11.0
+ require.InDelta(t, expected, score, 0.5, "all penalties should stack correctly")
+}
+
+// TestOpenAIWSIngressContextPool_MigrationBlockedByCircuitBreaker verifies
+// that a context on a circuit-open account cannot be migrated to a new
+// session: its migration score falls below the minimum threshold, so the
+// acquire fails with errOpenAIWSConnQueueFull.
+func TestOpenAIWSIngressContextPool_MigrationBlockedByCircuitBreaker(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	stats := newOpenAIAccountRuntimeStats()
+	pool.schedulerStats = stats
+
+	accountID := int64(9004)
+
+	// Trip circuit breaker for this account.
+	for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ {
+		stats.report(accountID, false, nil, "", 0)
+	}
+	require.True(t, stats.isCircuitOpen(accountID))
+
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{
+			&openAIWSCaptureConn{},
+		},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: accountID, Concurrency: 1}
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	// Acquire the only slot.
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     30,
+		SessionHash: "session_cb_a",
+		OwnerID:     "owner_a",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease1)
+	lease1.Release()
+
+	// Now try a different session -- migration should fail because the only
+	// candidate context is on a circuit-open account, whose score will be
+	// below the minimum threshold.
+	_, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     30,
+		SessionHash: "session_cb_b",
+		OwnerID:     "owner_b",
+		WSURL:       "ws://test-upstream",
+	})
+	require.ErrorIs(t, err, errOpenAIWSConnQueueFull,
+		"migration to a circuit-open account should be blocked")
+}
+
+// ---------- Connection lifecycle management (age-based rotation) tests ----------
+
+// upstreamConnAge must report zero when no creation timestamp was recorded.
+func TestOpenAIWSIngressContext_UpstreamConnAge_ZeroValue(t *testing.T) {
+	c := &openAIWSIngressContext{}
+	require.Equal(t, time.Duration(0), c.upstreamConnAge(time.Now()))
+}
+
+// A connection created ten minutes ago should report an age of roughly 10m.
+func TestOpenAIWSIngressContext_UpstreamConnAge_Normal(t *testing.T) {
+	c := &openAIWSIngressContext{}
+	c.upstreamConnCreatedAt.Store(time.Now().Add(-10 * time.Minute).UnixNano())
+	age := c.upstreamConnAge(time.Now())
+	require.True(t, age >= 10*time.Minute-time.Second, "connAge 应约为 10 分钟,实际=%v", age)
+	require.True(t, age < 11*time.Minute, "connAge 不应过大,实际=%v", age)
+}
+
+// upstreamConnAge must tolerate a nil receiver and simply return zero.
+func TestOpenAIWSIngressContext_UpstreamConnAge_NilSafe(t *testing.T) {
+	var nilCtx *openAIWSIngressContext
+	require.Equal(t, time.Duration(0), nilCtx.upstreamConnAge(time.Now()))
+}
+
+// When createdAt lies in the future (clock moved backwards), the reported
+// age clamps to zero instead of going negative.
+func TestOpenAIWSIngressContext_UpstreamConnAge_ClockSkew(t *testing.T) {
+	c := &openAIWSIngressContext{}
+	c.upstreamConnCreatedAt.Store(time.Now().Add(10 * time.Minute).UnixNano())
+	require.Equal(t, time.Duration(0), c.upstreamConnAge(time.Now()))
+}
+
+// UpstreamConnMaxAgeSeconds = 0 disables age-based rotation: the pool must
+// resolve upstreamMaxAge to a zero duration.
+func TestNewOpenAIWSIngressContextPool_UpstreamMaxAge_ZeroDisablesRotation(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds = 0
+
+	p := newOpenAIWSIngressContextPool(cfg)
+	defer p.Close()
+
+	require.Equal(t, time.Duration(0), p.upstreamMaxAge)
+}
+
+// TestOpenAIWSIngressContextPool_EnsureUpstream_MaxAgeRotate verifies that a
+// yielded upstream older than upstreamMaxAge is closed and redialed on the
+// next Acquire for the same session.
+func TestOpenAIWSIngressContextPool_EnsureUpstream_MaxAgeRotate(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	pool.upstreamMaxAge = 1 * time.Second // very short maxAge so the test can observe rotation
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	conn2 := &openAIWSCaptureConn{}
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{conn1, conn2},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 901, Concurrency: 2}
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	// First Acquire: dials the upstream.
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     1,
+		SessionHash: "session_age",
+		OwnerID:     "owner_age_1",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease1)
+	require.Equal(t, 1, dialer.DialCount(), "首次 Acquire 应 dial 一次")
+
+	// Yield keeps the connection cached for the session.
+	lease1.Yield()
+
+	// Sleep past maxAge so the cached connection becomes over-age.
+	time.Sleep(1200 * time.Millisecond)
+
+	// Second Acquire: the over-age connection should be rotated out.
+	lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     1,
+		SessionHash: "session_age",
+		OwnerID:     "owner_age_2",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease2)
+	require.Equal(t, 2, dialer.DialCount(), "超龄轮换应触发重新 dial")
+	require.True(t, conn1.closed, "旧连接应被关闭")
+	lease2.Release()
+}
+
+// TestOpenAIWSIngressContextPool_EnsureUpstream_YoungConnNotRotated verifies
+// that a connection well within upstreamMaxAge is reused as-is: no redial and
+// no close on re-acquire.
+func TestOpenAIWSIngressContextPool_EnsureUpstream_YoungConnNotRotated(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	pool.upstreamMaxAge = 10 * time.Minute // far longer than the test runtime
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{conn1},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 902, Concurrency: 2}
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     1,
+		SessionHash: "session_young",
+		OwnerID:     "owner_young_1",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	lease1.Yield()
+
+	// Re-acquire immediately: the connection is young and must not rotate.
+	lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     1,
+		SessionHash: "session_young",
+		OwnerID:     "owner_young_2",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease2)
+	require.Equal(t, 1, dialer.DialCount(), "年轻连接不应触发重新 dial")
+	require.True(t, lease2.Reused(), "年轻连接应复用")
+	require.False(t, conn1.closed, "年轻连接不应被关闭")
+	lease2.Release()
+}
+
+// TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_ClosesAgedIdle checks
+// that the idle sweeper collects an ownerless upstream older than maxAge and
+// clears the context's upstream reference and createdAt stamp.
+func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_ClosesAgedIdle(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	pool.upstreamMaxAge = 1 * time.Second
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	ap := &openAIWSIngressAccountPool{
+		contexts:  make(map[string]*openAIWSIngressContext),
+		bySession: make(map[string]string),
+	}
+
+	ctx := &openAIWSIngressContext{
+		id:        "ctx_aged_1",
+		accountID: 903,
+		upstream:  conn1,
+	}
+	// Backdate createdAt by 2 seconds so the conn exceeds maxAge (1s).
+	ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano())
+	ctx.touchLease(time.Now(), pool.idleTTL)
+	ap.contexts["ctx_aged_1"] = ctx
+
+	now := time.Now()
+	ap.mu.Lock()
+	toClose := pool.closeAgedIdleUpstreamsLocked(ap, now)
+	ap.mu.Unlock()
+
+	require.Len(t, toClose, 1, "应关闭超龄空闲连接")
+	closeOpenAIWSClientConns(toClose)
+	require.True(t, conn1.closed)
+
+	// The context should have dropped its upstream and reset createdAt.
+	ctx.mu.Lock()
+	require.Nil(t, ctx.upstream)
+	require.Equal(t, int64(0), ctx.upstreamConnCreatedAt.Load())
+	ctx.mu.Unlock()
+}
+
+// TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsOwnedContext
+// verifies the sweeper never closes an over-age upstream while a context is
+// still owned by an active lease.
+func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsOwnedContext(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	pool.upstreamMaxAge = 1 * time.Second
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	ap := &openAIWSIngressAccountPool{
+		contexts:  make(map[string]*openAIWSIngressContext),
+		bySession: make(map[string]string),
+	}
+
+	// The context is over-age but has an owner, so it must be skipped.
+	ctx := &openAIWSIngressContext{
+		id:        "ctx_owned_1",
+		accountID: 904,
+		ownerID:   "active_owner",
+		upstream:  conn1,
+	}
+	ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano())
+	ap.contexts["ctx_owned_1"] = ctx
+
+	now := time.Now()
+	ap.mu.Lock()
+	toClose := pool.closeAgedIdleUpstreamsLocked(ap, now)
+	ap.mu.Unlock()
+
+	require.Len(t, toClose, 0, "有 owner 的超龄连接不应被关闭")
+	require.False(t, conn1.closed)
+}
+
+// TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsYoungConn
+// verifies the sweeper leaves connections younger than maxAge untouched.
+func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsYoungConn(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	pool.upstreamMaxAge = 10 * time.Minute
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	ap := &openAIWSIngressAccountPool{
+		contexts:  make(map[string]*openAIWSIngressContext),
+		bySession: make(map[string]string),
+	}
+
+	// Fresh createdAt: well within the 10-minute maxAge.
+	ctx := &openAIWSIngressContext{
+		id:        "ctx_young_1",
+		accountID: 905,
+		upstream:  conn1,
+	}
+	ctx.upstreamConnCreatedAt.Store(time.Now().UnixNano())
+	ctx.touchLease(time.Now(), pool.idleTTL)
+	ap.contexts["ctx_young_1"] = ctx
+
+	now := time.Now()
+	ap.mu.Lock()
+	toClose := pool.closeAgedIdleUpstreamsLocked(ap, now)
+	ap.mu.Unlock()
+
+	require.Len(t, toClose, 0, "年轻连接不应被关闭")
+	require.False(t, conn1.closed)
+}
+
+// TestOpenAIWSIngressContextPool_E2E_AcquireYieldAgedReconnect is an
+// end-to-end check of the acquire → yield → age-out → reconnect flow using
+// the production default maxAge, with createdAt backdated manually instead of
+// sleeping.
+func TestOpenAIWSIngressContextPool_E2E_AcquireYieldAgedReconnect(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	pool.upstreamMaxAge = 55 * time.Minute // matches the real default
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	conn2 := &openAIWSCaptureConn{}
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{conn1, conn2},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 906, Concurrency: 2}
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	// Acquire → Yield
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     1,
+		SessionHash: "session_e2e",
+		OwnerID:     "owner_e2e_1",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease1)
+	lease1.Yield()
+
+	// Manually backdate createdAt by 56 minutes to simulate an over-age conn.
+	pool.mu.Lock()
+	ap := pool.accounts[account.ID]
+	pool.mu.Unlock()
+	require.NotNil(t, ap)
+
+	ap.mu.Lock()
+	for _, c := range ap.contexts {
+		c.upstreamConnCreatedAt.Store(time.Now().Add(-56 * time.Minute).UnixNano())
+	}
+	ap.mu.Unlock()
+
+	// Re-acquire: the over-age connection should be detected and redialed.
+	lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     1,
+		SessionHash: "session_e2e",
+		OwnerID:     "owner_e2e_2",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease2)
+	require.Equal(t, 2, dialer.DialCount(), "超龄 56 分钟的连接应触发重连")
+	require.True(t, conn1.closed, "旧的超龄连接应被关闭")
+	lease2.Release()
+}
+
+// Regression test: when the account pool is at capacity but holds an expired
+// context, Acquire must evict the expired one and still succeed.
+func TestOpenAIWSIngressContextPool_Acquire_EvictsExpiredWhenFull(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	conn2 := &openAIWSCaptureConn{}
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{conn1, conn2},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 901, Concurrency: 1}
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	// Take one lease to fill the account's capacity (Concurrency: 1).
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     3,
+		SessionHash: "session_old",
+		OwnerID:     "owner_old",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease1)
+	lease1.Release()
+
+	// Manually expire that context.
+	pool.mu.Lock()
+	ap := pool.accounts[account.ID]
+	pool.mu.Unlock()
+	ap.mu.Lock()
+	for _, c := range ap.contexts {
+		c.setExpiresAt(time.Now().Add(-2 * time.Second))
+	}
+	ap.mu.Unlock()
+
+	// Capacity is full (one expired context); a new session's Acquire should
+	// succeed by evicting it.
+	lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     3,
+		SessionHash: "session_new",
+		OwnerID:     "owner_new",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err, "容量满但有过期 context 时 Acquire 应成功")
+	require.NotNil(t, lease2)
+	lease2.Release()
+}
+
+// Regression test: when Acquire finds a context that is expired but still in
+// the bySession map, it must take ownership and refresh the lease.
+func TestOpenAIWSIngressContextPool_Acquire_ReusesExpiredContextBySession(t *testing.T) {
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	conn1 := &openAIWSCaptureConn{}
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{conn1},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 902, Concurrency: 2}
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	// First acquire creates the context.
+	lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     4,
+		SessionHash: "session_reuse",
+		OwnerID:     "owner_1",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err)
+	require.NotNil(t, lease1)
+	lease1.Release()
+
+	// Expire the context without cleaning it up, mimicking the hot path which
+	// no longer prunes expired entries.
+	pool.mu.Lock()
+	ap := pool.accounts[account.ID]
+	pool.mu.Unlock()
+	ap.mu.Lock()
+	for _, c := range ap.contexts {
+		c.setExpiresAt(time.Now().Add(-1 * time.Second))
+	}
+	ap.mu.Unlock()
+
+	// A second Acquire for the same session should reuse the expired-but-
+	// present context.
+	lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+		Account:     account,
+		GroupID:     4,
+		SessionHash: "session_reuse",
+		OwnerID:     "owner_2",
+		WSURL:       "ws://test-upstream",
+	})
+	require.NoError(t, err, "过期但未清理的 context 应被同 session 的 Acquire 复用")
+	require.NotNil(t, lease2)
+
+	// The lease must have been refreshed: expiresAt lies in the future.
+	pool.mu.Lock()
+	ap2 := pool.accounts[account.ID]
+	pool.mu.Unlock()
+	ap2.mu.Lock()
+	for _, c := range ap2.contexts {
+		c.mu.Lock()
+		ea := c.expiresAt()
+		c.mu.Unlock()
+		require.True(t, ea.After(time.Now()), "复用后租约应被刷新到未来")
+	}
+	ap2.mu.Unlock()
+
+	lease2.Release()
+}
+
+// === P3: 后台主动 Ping 检测测试 ===
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_MarksBrokenOnPingFailure(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_fail_1",
+ accountID: 2001,
+ upstream: failConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_fail_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ ctx.mu.Unlock()
+ require.True(t, broken, "Ping 失败应标记 context 为 broken")
+ require.Equal(t, 1, streak, "Ping 失败应增加 failureStreak")
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsOwnedContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_owned_1",
+ accountID: 2002,
+ ownerID: "active_owner",
+ upstream: failConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_owned_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ ctx.mu.Unlock()
+ require.False(t, broken, "有 owner 的 context 不应被 Ping 探测")
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsBrokenContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_broken_1",
+ accountID: 2003,
+ upstream: failConn,
+ broken: true,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_broken_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ streak := ctx.failureStreak
+ ctx.mu.Unlock()
+ require.Equal(t, 0, streak, "已 broken 的 context 不应被再次 Ping")
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_HealthyConnStaysHealthy(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ healthyConn := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_healthy_1",
+ accountID: 2004,
+ upstream: healthyConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_healthy_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ hasUpstream := ctx.upstream != nil
+ ctx.mu.Unlock()
+ require.False(t, broken, "Ping 成功的 context 不应被标记 broken")
+ require.True(t, hasUpstream, "Ping 成功的 upstream 应保持")
+}
+
+func TestOpenAIWSIngressContextPool_SweepTriggersPingOnIdleContexts(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ healthyConn := &openAIWSCaptureConn{}
+
+ pool.mu.Lock()
+ ap := pool.getOrCreateAccountPoolLocked(3001)
+ pool.mu.Unlock()
+
+ ap.mu.Lock()
+ ctxFail := &openAIWSIngressContext{
+ id: "ctx_sweep_ping_fail",
+ accountID: 3001,
+ upstream: failConn,
+ }
+ ctxFail.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_sweep_ping_fail"] = ctxFail
+
+ ctxOk := &openAIWSIngressContext{
+ id: "ctx_sweep_ping_ok",
+ accountID: 3001,
+ upstream: healthyConn,
+ }
+ ctxOk.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_sweep_ping_ok"] = ctxOk
+ ap.dynamicCap.Store(2)
+ ap.mu.Unlock()
+
+ // 手动触发 sweep
+ pool.sweepExpiredIdleContexts()
+
+ ctxFail.mu.Lock()
+ failBroken := ctxFail.broken
+ ctxFail.mu.Unlock()
+ require.True(t, failBroken, "sweep 后 Ping 失败的空闲 context 应被标记 broken")
+
+ ctxOk.mu.Lock()
+ okBroken := ctxOk.broken
+ ctxOk.mu.Unlock()
+ require.False(t, okBroken, "sweep 后 Ping 成功的空闲 context 应保持健康")
+}
+
+func TestOpenAIWSIngressContextPool_YieldSchedulesDelayedPing(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{failConn},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 4001, Concurrency: 2}
+ bCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ lease, err := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_yield_ping",
+ OwnerID: "owner_yield_ping",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+
+ ingressCtx := lease.context
+ // Yield 触发延迟 Ping
+ lease.Yield()
+
+ // 等待延迟 Ping 执行完毕(默认 5s + 余量)
+ time.Sleep(6 * time.Second)
+
+ ingressCtx.mu.Lock()
+ broken := ingressCtx.broken
+ ingressCtx.mu.Unlock()
+ require.True(t, broken, "Yield 后延迟 Ping 应检测到 PingFailConn 并标记 broken")
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_CancelledOnPoolClose(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_delayed_cancel_1",
+ accountID: 5001,
+ upstream: failConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 安排 10s 延迟(远大于测试等待时间)
+ pool.scheduleDelayedPing(ctx, 10*time.Second)
+
+ // 立刻关闭 pool,应取消延迟 Ping
+ pool.Close()
+ time.Sleep(200 * time.Millisecond)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ ctx.mu.Unlock()
+ require.False(t, broken, "pool 关闭后延迟 Ping 不应执行")
+}
+
+// === effectiveDynamicCapacity 边界测试 ===
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NilAccountPool(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ require.Equal(t, 4, pool.effectiveDynamicCapacity(nil, 4), "ap==nil 时应返回 hardCap")
+ require.Equal(t, 0, pool.effectiveDynamicCapacity(nil, 0), "ap==nil && hardCap==0")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_ZeroHardCap(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(3)
+ require.Equal(t, 0, pool.effectiveDynamicCapacity(ap, 0), "hardCap<=0 应返回 hardCap")
+ require.Equal(t, -1, pool.effectiveDynamicCapacity(ap, -1), "hardCap<0 应返回 hardCap")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapBelowOne(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(0) // 异常值
+ result := pool.effectiveDynamicCapacity(ap, 4)
+ require.Equal(t, 1, result, "dynCap<1 应自动修复为 1")
+ require.Equal(t, int32(1), ap.dynamicCap.Load(), "dynCap 应被修复为 1")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapExceedsHardCap(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(10)
+ require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap>hardCap 应 clamp 到 hardCap")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapEqualsHardCap(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(4)
+ require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap==hardCap 应返回 hardCap")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NormalPath(t *testing.T) {
+	t.Parallel()
+	pool := &openAIWSIngressContextPool{}
+	ap := &openAIWSIngressAccountPool{}
+	ap.dynamicCap.Store(2)
+	require.Equal(t, 2, pool.effectiveDynamicCapacity(ap, 8), "正常 dynCap<hardCap 应返回 dynCap")
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_GrowsDynamicCap(t *testing.T) {
+	t.Parallel()
+	cfg := &config.Config{}
+	cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+	pool := newOpenAIWSIngressContextPool(cfg)
+	defer pool.Close()
+
+	dialer := &openAIWSQueueDialer{
+		conns: []openAIWSClientConn{&openAIWSCaptureConn{}, &openAIWSCaptureConn{}},
+	}
+	pool.setClientDialerForTest(dialer)
+
+	account := &Account{ID: 6001, Concurrency: 4}
+	bCtx := context.Background()
+
+	lease1, err := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+		Account: account, GroupID: 1, SessionHash: "s1", OwnerID: "o1", WSURL: "ws://t",
+	})
+	require.NoError(t, err)
+	lease2, err := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+		Account: account, GroupID: 1, SessionHash: "s2", OwnerID: "o2", WSURL: "ws://t",
+	})
+	require.NoError(t, err)
+
+	pool.mu.Lock()
+	ap := pool.accounts[account.ID]
+	pool.mu.Unlock()
+	require.True(t, ap.dynamicCap.Load() >= 2, "第二次 Acquire 应触发 dynamicCap 增长")
+
+	lease1.Release()
+	lease2.Release()
+}
+
+func TestOpenAIWSIngressContextPool_Sweeper_ShrinksDynamicCap(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1 // 1 秒 TTL 让 context 快速过期
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{}, &openAIWSCaptureConn{}, &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 6002, Concurrency: 4}
+ bCtx := context.Background()
+
+ // 创建两个 context
+ lease1, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+ Account: account, GroupID: 1, SessionHash: "s1", OwnerID: "o1", WSURL: "ws://t",
+ })
+ lease2, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+ Account: account, GroupID: 1, SessionHash: "s2", OwnerID: "o2", WSURL: "ws://t",
+ })
+ lease1.Release()
+ lease2.Release()
+
+ pool.mu.Lock()
+ ap := pool.accounts[account.ID]
+ pool.mu.Unlock()
+ require.True(t, ap.dynamicCap.Load() >= 2)
+
+ // 等待 context 过期
+ time.Sleep(2 * time.Second)
+
+ // 手动 sweep
+ pool.sweepExpiredIdleContexts()
+
+ // sweep 后 dynamicCap 应缩减
+ ap.mu.Lock()
+ ctxCount := len(ap.contexts)
+ ap.mu.Unlock()
+ dynCap := ap.dynamicCap.Load()
+ // 如果所有 context 都被 evict,dynamicCap 应缩到 1(min)
+ if ctxCount == 0 {
+ require.Equal(t, int32(1), dynCap, "context 全部 evict 后 dynamicCap 应缩至 1")
+ } else {
+ require.LessOrEqual(t, dynCap, int32(ctxCount), "dynamicCap 应缩至当前 context 数")
+ }
+}
+
+// === Ping 额外边界测试 ===
+
+func TestOpenAIWSIngressContextPool_PingContextUpstream_NilPoolOrContext(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})}
+ // nil context 不应 panic
+ pool.pingContextUpstream(nil)
+ // nil pool 不应 panic
+ var nilPool *openAIWSIngressContextPool
+ nilPool.pingContextUpstream(&openAIWSIngressContext{})
+}
+
+func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsDialingContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_dialing_1", accountID: 7001,
+ upstream: failConn, dialing: true,
+ }
+ pool.pingContextUpstream(ctx)
+ ctx.mu.Lock()
+ require.False(t, ctx.broken, "dialing 中的 context 不应被 Ping")
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsNoUpstream(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_no_upstream", accountID: 7002, upstream: nil,
+ }
+ pool.pingContextUpstream(ctx)
+ ctx.mu.Lock()
+ require.False(t, ctx.broken, "无 upstream 的 context 不应被 Ping")
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_NilPoolOrAP(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})}
+ pool.pingIdleUpstreams(nil)
+ var nilPool *openAIWSIngressContextPool
+ nilPool.pingIdleUpstreams(&openAIWSIngressAccountPool{})
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsNilContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ap := &openAIWSIngressAccountPool{
+ contexts: map[string]*openAIWSIngressContext{"nil_ctx": nil},
+ bySession: make(map[string]string),
+ }
+ // 不应 panic
+ pool.pingIdleUpstreams(ap)
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ZeroDelay(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_zero_delay", accountID: 8001, upstream: failConn,
+ }
+ // delay <= 0 应为 no-op
+ pool.scheduleDelayedPing(ctx, 0)
+ pool.scheduleDelayedPing(ctx, -1*time.Second)
+ time.Sleep(100 * time.Millisecond)
+ ctx.mu.Lock()
+ require.False(t, ctx.broken, "delay<=0 不应触发 Ping")
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_NilParams(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})}
+ // nil context
+ pool.scheduleDelayedPing(nil, 5*time.Second)
+ // nil pool
+ var nilPool *openAIWSIngressContextPool
+ nilPool.scheduleDelayedPing(&openAIWSIngressContext{}, 5*time.Second)
+}
+
+// === P1 并发回归:旧连接 Ping 失败不应误杀新连接 ===
+
+func TestOpenAIWSIngressContextPool_PingFailDoesNotKillNewConn(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ // 旧连接:Ping 带 200ms 延迟后失败
+ oldConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond)
+ // 新连接:正常的 Ping
+ newConn := &openAIWSCaptureConn{}
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_race_test",
+ accountID: 9001,
+ upstream: oldConn,
+ upstreamConnID: "old_conn_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 在后台对旧连接发起 Ping 探测
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ pool.pingContextUpstream(ctx)
+ }()
+
+ // 等待 Ping 开始执行
+ <-oldConn.pingDone
+
+ // 模拟连接重建:在 Ping 执行期间将 upstream 替换为新连接
+ ctx.mu.Lock()
+ ctx.upstream = newConn
+ ctx.upstreamConnID = "new_conn_2"
+ ctx.broken = false
+ ctx.failureStreak = 0
+ ctx.mu.Unlock()
+
+ // 等待 Ping goroutine 完成
+ wg.Wait()
+
+ // 验证:新连接不应被标记为 broken
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ upstream := ctx.upstream
+ connID := ctx.upstreamConnID
+ ctx.mu.Unlock()
+
+ require.False(t, broken, "新连接不应被旧 Ping 失败标记为 broken")
+ require.Equal(t, 0, streak, "failureStreak 不应增加")
+ require.Equal(t, newConn, upstream, "upstream 应仍是新连接")
+ require.Equal(t, "new_conn_2", connID, "upstreamConnID 应仍是新连接的 ID")
+ require.False(t, newConn.Closed(), "新连接不应被关闭")
+}
+
+func TestOpenAIWSIngressContextPool_PingFailKillsSameConn(t *testing.T) {
+ // 对照测试:connID 未变时应正常标记 broken
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_same_conn",
+ accountID: 9002,
+ upstream: failConn,
+ upstreamConnID: "conn_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ pool.pingContextUpstream(ctx)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ connID := ctx.upstreamConnID
+ ctx.mu.Unlock()
+
+ require.True(t, broken, "同一连接 Ping 失败应标记 broken")
+ require.Equal(t, 1, streak, "failureStreak 应为 1")
+ require.Empty(t, connID, "upstreamConnID 应被清空")
+}
+
+func TestOpenAIWSIngressContextPool_PingFailDoesNotKillBusyConn(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond)
+ ctx := &openAIWSIngressContext{
+ id: "ctx_busy_conn",
+ accountID: 9005,
+ upstream: failConn,
+ upstreamConnID: "conn_busy_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ pool.pingContextUpstream(ctx)
+ }()
+
+ <-failConn.pingDone
+
+ ctx.mu.Lock()
+ ctx.ownerID = "active_owner"
+ ctx.mu.Unlock()
+
+ wg.Wait()
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ upstream := ctx.upstream
+ connID := ctx.upstreamConnID
+ ownerID := ctx.ownerID
+ ctx.mu.Unlock()
+
+ require.False(t, broken, "busy context 不应被后台 Ping 标记 broken")
+ require.Equal(t, 0, streak, "failureStreak 不应增加")
+ require.Equal(t, failConn, upstream, "busy context 的 upstream 不应被替换")
+ require.Equal(t, "conn_busy_1", connID, "busy context 的 connID 不应变化")
+ require.Equal(t, "active_owner", ownerID, "owner 应保持不变")
+ require.False(t, failConn.Closed(), "busy context 的连接不应被后台 Ping 关闭")
+}
+
+// === P2a 去重:连续多次 Yield 仅触发一次延迟 Ping ===
+
+func TestOpenAIWSIngressContextPool_ConsecutiveYieldsOnlyOnePing(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ // 使用 Ping 失败连接以便观察是否被标记 broken
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_yield_dedup",
+ accountID: 9003,
+ upstream: failConn,
+ upstreamConnID: "conn_dedup_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 连续调用 5 次 scheduleDelayedPing
+ for i := 0; i < 5; i++ {
+ pool.scheduleDelayedPing(ctx, 100*time.Millisecond)
+ }
+
+ // 验证:只有一个 pendingPingTimer(不应堆积多个 goroutine)
+ ctx.mu.Lock()
+ hasPending := ctx.pendingPingTimer != nil
+ ctx.mu.Unlock()
+ require.True(t, hasPending, "应有一个 pendingPingTimer")
+
+ // 等待 timer 到期并执行 Ping
+ time.Sleep(300 * time.Millisecond)
+
+ // Ping 失败应标记 broken(证明延迟 Ping 确实执行了)
+ ctx.mu.Lock()
+ broken := ctx.broken
+ pendingTimer := ctx.pendingPingTimer
+ ctx.mu.Unlock()
+
+ require.True(t, broken, "延迟 Ping 应已执行并标记 broken")
+ require.Nil(t, pendingTimer, "Ping 执行后 pendingPingTimer 应被清理")
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ResetExtendsDelay(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_reset_delay",
+ accountID: 9004,
+ upstream: failConn,
+ upstreamConnID: "conn_reset_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 第一次调度 200ms 延迟
+ pool.scheduleDelayedPing(ctx, 200*time.Millisecond)
+
+ // 100ms 后再次调度 200ms(应 Reset timer,从此刻起再等 200ms)
+ time.Sleep(100 * time.Millisecond)
+ pool.scheduleDelayedPing(ctx, 200*time.Millisecond)
+
+ // 150ms 后(距第一次 250ms,距 Reset 150ms)应未执行
+ time.Sleep(150 * time.Millisecond)
+ ctx.mu.Lock()
+ broken := ctx.broken
+ ctx.mu.Unlock()
+ require.False(t, broken, "Reset 后 150ms 不应触发 Ping")
+
+ // 再等 100ms(距 Reset 250ms)应已执行
+ time.Sleep(100 * time.Millisecond)
+ ctx.mu.Lock()
+ broken = ctx.broken
+ ctx.mu.Unlock()
+ require.True(t, broken, "Reset 后 250ms 应已触发 Ping")
+}
diff --git a/backend/internal/service/openai_ws_ingress_normalizer.go b/backend/internal/service/openai_ws_ingress_normalizer.go
new file mode 100644
index 000000000..74cd67374
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_normalizer.go
@@ -0,0 +1,42 @@
+package service
+
+type openAIWSIngressPreSendNormalizeInput struct {
+ accountID int64
+ turn int
+ connID string
+
+ currentPayload []byte
+ currentPayloadBytes int
+ currentPreviousResponseID string
+ expectedPreviousResponse string
+ pendingExpectedCallIDs []string
+}
+
+type openAIWSIngressPreSendNormalizeOutput struct {
+ currentPayload []byte
+ currentPayloadBytes int
+ currentPreviousResponseID string
+ expectedPreviousResponseID string
+ pendingExpectedCallIDs []string
+ functionCallOutputCallIDs []string
+ hasFunctionCallOutputCallID bool
+}
+
+// normalizeOpenAIWSIngressPayloadBeforeSend 纯透传 + callID 提取。
+// proxy 只负责转发、认证替换、计费,所有边缘场景由 recoverIngressPrevResponseNotFound 兜底。
+func normalizeOpenAIWSIngressPayloadBeforeSend(input openAIWSIngressPreSendNormalizeInput) openAIWSIngressPreSendNormalizeOutput {
+ _ = input.accountID
+ _ = input.turn
+ _ = input.connID
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(input.currentPayload)
+
+ return openAIWSIngressPreSendNormalizeOutput{
+ currentPayload: input.currentPayload,
+ currentPayloadBytes: input.currentPayloadBytes,
+ currentPreviousResponseID: input.currentPreviousResponseID,
+ expectedPreviousResponseID: input.expectedPreviousResponse,
+ pendingExpectedCallIDs: input.pendingExpectedCallIDs,
+ functionCallOutputCallIDs: callIDs,
+ hasFunctionCallOutputCallID: len(callIDs) > 0,
+ }
+}
diff --git a/backend/internal/service/openai_ws_ingress_normalizer_test.go b/backend/internal/service/openai_ws_ingress_normalizer_test.go
new file mode 100644
index 000000000..ff68f4b5b
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_normalizer_test.go
@@ -0,0 +1,193 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// === 纯透传 normalizer 行为测试 ===
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_BasicPassthrough(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_prev",
+ "input":[{"type":"input_text","text":"hello"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 1,
+ turn: 2,
+ connID: "conn_1",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_prev",
+ expectedPreviousResponse: "resp_expected",
+ pendingExpectedCallIDs: []string{"call_1"},
+ })
+
+ require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传")
+ require.Equal(t, len(payload), out.currentPayloadBytes)
+ require.Equal(t, "resp_prev", out.currentPreviousResponseID, "currentPreviousResponseID 应原样透传")
+ require.Equal(t, "resp_expected", out.expectedPreviousResponseID, "expectedPreviousResponseID 应原样透传")
+ require.Equal(t, []string{"call_1"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样透传")
+ require.False(t, out.hasFunctionCallOutputCallID, "无 FCO 时应为 false")
+ require.Empty(t, out.functionCallOutputCallIDs)
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ExtractsCallID(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_prev",
+ "input":[{"type":"function_call_output","call_id":"call_abc","output":"{}"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 2,
+ turn: 3,
+ connID: "conn_2",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_prev",
+ expectedPreviousResponse: "resp_prev",
+ })
+
+ require.True(t, out.hasFunctionCallOutputCallID, "有 FCO 时应为 true")
+ require.Equal(t, []string{"call_abc"}, out.functionCallOutputCallIDs, "应正确提取 call_id")
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_NoFCO(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"input_text","text":"hello"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 3,
+ turn: 1,
+ connID: "conn_3",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "",
+ })
+
+ require.False(t, out.hasFunctionCallOutputCallID)
+ require.Empty(t, out.functionCallOutputCallIDs)
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_MultipleFCO(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_ok",
+ "input":[
+ {"type":"function_call_output","call_id":"call_a","output":"{}"},
+ {"type":"function_call_output","call_id":"call_b","output":"{}"},
+ {"type":"function_call_output","call_id":"call_c","output":"{}"}
+ ]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 4,
+ turn: 2,
+ connID: "conn_4",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_ok",
+ expectedPreviousResponse: "resp_ok",
+ })
+
+ require.True(t, out.hasFunctionCallOutputCallID)
+ require.ElementsMatch(t, []string{"call_a", "call_b", "call_c"}, out.functionCallOutputCallIDs, "应提取所有 call_id")
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_EmptyInput(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 5,
+ turn: 1,
+ connID: "conn_5",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "",
+ })
+
+ require.JSONEq(t, string(payload), string(out.currentPayload), "空 input 不应 panic")
+ require.False(t, out.hasFunctionCallOutputCallID)
+ require.Empty(t, out.functionCallOutputCallIDs)
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ESCInterruptPassthrough(t *testing.T) {
+ t.Parallel()
+
+ // 场景:ESC 中断后客户端有意不传 previous_response_id,有 pendingCallIDs。
+ // 透传不补 prev、不注入 aborted output。
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"input_text","text":"new task after ESC"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 6,
+ turn: 5,
+ connID: "conn_esc",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "resp_prev_turn4",
+ pendingExpectedCallIDs: []string{"call_pending_1", "call_pending_2"},
+ })
+
+ require.Empty(t, out.currentPreviousResponseID, "透传不应补 previous_response_id")
+ require.Equal(t, "resp_prev_turn4", out.expectedPreviousResponseID)
+ require.False(t, out.hasFunctionCallOutputCallID, "透传不应注入 function_call_output")
+ require.Empty(t, out.functionCallOutputCallIDs)
+ require.Equal(t, []string{"call_pending_1", "call_pending_2"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样传递")
+ require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传")
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_StalePrevPassthrough(t *testing.T) {
+ t.Parallel()
+
+ // 场景:客户端传了过期 previous_response_id,透传不对齐。
+ // 由下游 recoverIngressPrevResponseNotFound 处理。
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"{}"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 7,
+ turn: 4,
+ connID: "conn_stale",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_stale",
+ expectedPreviousResponse: "resp_latest",
+ })
+
+ require.Equal(t, "resp_stale", out.currentPreviousResponseID, "透传不应对齐 previous_response_id")
+ require.Equal(t, "resp_latest", out.expectedPreviousResponseID)
+ require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传")
+ require.True(t, out.hasFunctionCallOutputCallID)
+ require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs)
+}
diff --git a/backend/internal/service/openai_ws_passthrough_layout_test.go b/backend/internal/service/openai_ws_passthrough_layout_test.go
new file mode 100644
index 000000000..0cf611faa
--- /dev/null
+++ b/backend/internal/service/openai_ws_passthrough_layout_test.go
@@ -0,0 +1,39 @@
+package service
+
+import (
+ "os"
+ "path/filepath"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSPassthroughDataPlaneLayout(t *testing.T) {
+ t.Parallel()
+
+ serviceDir, err := os.Getwd()
+ require.NoError(t, err)
+ v2Dir := filepath.Join(serviceDir, "openai_ws_v2")
+ forwarderFile := filepath.Join(serviceDir, "openai_ws_forwarder.go")
+
+ requiredFiles := []string{
+ filepath.Join(v2Dir, "entry.go"),
+ filepath.Join(v2Dir, "caddy_adapter.go"),
+ filepath.Join(v2Dir, "passthrough_relay.go"),
+ }
+ for _, file := range requiredFiles {
+ info, err := os.Stat(file)
+ require.NoError(t, err)
+ require.False(t, info.IsDir())
+ }
+
+ content, err := os.ReadFile(forwarderFile)
+ require.NoError(t, err)
+ forwarder := string(content)
+
+ // openai_ws_forwarder 允许分流入口,不承载 passthrough 数据面函数实现。
+ require.Contains(t, forwarder, "proxyResponsesWebSocketV2Passthrough(")
+ require.NotContains(t, forwarder, "func runUpstreamToClient(")
+ require.NotContains(t, forwarder, "func observeUpstreamMessage(")
+ require.NotContains(t, forwarder, "func parseUsageAndAccumulate(")
+}
diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go
deleted file mode 100644
index db6a96a7d..000000000
--- a/backend/internal/service/openai_ws_pool.go
+++ /dev/null
@@ -1,1706 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "fmt"
- "math"
- "net/http"
- "sort"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "golang.org/x/sync/errgroup"
-)
-
-const (
- openAIWSConnMaxAge = 60 * time.Minute
- openAIWSConnHealthCheckIdle = 90 * time.Second
- openAIWSConnHealthCheckTO = 2 * time.Second
- openAIWSConnPrewarmExtraDelay = 2 * time.Second
- openAIWSAcquireCleanupInterval = 3 * time.Second
- openAIWSBackgroundPingInterval = 30 * time.Second
- openAIWSBackgroundSweepTicker = 30 * time.Second
-
- openAIWSPrewarmFailureWindow = 30 * time.Second
- openAIWSPrewarmFailureSuppress = 2
-)
-
-var (
- errOpenAIWSConnClosed = errors.New("openai ws connection closed")
- errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
- errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
-)
-
-type openAIWSDialError struct {
- StatusCode int
- ResponseHeaders http.Header
- Err error
-}
-
-func (e *openAIWSDialError) Error() string {
- if e == nil {
- return ""
- }
- if e.StatusCode > 0 {
- return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
- }
- return fmt.Sprintf("openai ws dial failed: %v", e.Err)
-}
-
-func (e *openAIWSDialError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.Err
-}
-
-type openAIWSAcquireRequest struct {
- Account *Account
- WSURL string
- Headers http.Header
- ProxyURL string
- PreferredConnID string
- // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。
- ForceNewConn bool
- // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。
- ForcePreferredConn bool
-}
-
-type openAIWSConnLease struct {
- pool *openAIWSConnPool
- accountID int64
- conn *openAIWSConn
- queueWait time.Duration
- connPick time.Duration
- reused bool
- released atomic.Bool
-}
-
-func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) {
- if l == nil || l.conn == nil {
- return nil, errOpenAIWSConnClosed
- }
- if l.released.Load() {
- return nil, errOpenAIWSConnClosed
- }
- return l.conn, nil
-}
-
-func (l *openAIWSConnLease) ConnID() string {
- if l == nil || l.conn == nil {
- return ""
- }
- return l.conn.id
-}
-
-func (l *openAIWSConnLease) QueueWaitDuration() time.Duration {
- if l == nil {
- return 0
- }
- return l.queueWait
-}
-
-func (l *openAIWSConnLease) ConnPickDuration() time.Duration {
- if l == nil {
- return 0
- }
- return l.connPick
-}
-
-func (l *openAIWSConnLease) Reused() bool {
- if l == nil {
- return false
- }
- return l.reused
-}
-
-func (l *openAIWSConnLease) HandshakeHeader(name string) string {
- if l == nil || l.conn == nil {
- return ""
- }
- return l.conn.handshakeHeader(name)
-}
-
-func (l *openAIWSConnLease) IsPrewarmed() bool {
- if l == nil || l.conn == nil {
- return false
- }
- return l.conn.isPrewarmed()
-}
-
-func (l *openAIWSConnLease) MarkPrewarmed() {
- if l == nil || l.conn == nil {
- return
- }
- l.conn.markPrewarmed()
-}
-
-func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.writeJSONWithTimeout(context.Background(), value, timeout)
-}
-
-func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.writeJSONWithTimeout(ctx, value, timeout)
-}
-
-func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.writeJSON(value, ctx)
-}
-
-func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) {
- conn, err := l.activeConn()
- if err != nil {
- return nil, err
- }
- return conn.readMessageWithTimeout(timeout)
-}
-
-func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) {
- conn, err := l.activeConn()
- if err != nil {
- return nil, err
- }
- return conn.readMessage(ctx)
-}
-
-func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
- conn, err := l.activeConn()
- if err != nil {
- return nil, err
- }
- return conn.readMessageWithContextTimeout(ctx, timeout)
-}
-
-func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.pingWithTimeout(timeout)
-}
-
-func (l *openAIWSConnLease) MarkBroken() {
- if l == nil || l.pool == nil || l.conn == nil || l.released.Load() {
- return
- }
- l.pool.evictConn(l.accountID, l.conn.id)
-}
-
-func (l *openAIWSConnLease) Release() {
- if l == nil || l.conn == nil {
- return
- }
- if !l.released.CompareAndSwap(false, true) {
- return
- }
- l.conn.release()
-}
-
-type openAIWSConn struct {
- id string
- ws openAIWSClientConn
-
- handshakeHeaders http.Header
-
- leaseCh chan struct{}
- closedCh chan struct{}
- closeOnce sync.Once
-
- readMu sync.Mutex
- writeMu sync.Mutex
-
- waiters atomic.Int32
- createdAtNano atomic.Int64
- lastUsedNano atomic.Int64
- prewarmed atomic.Bool
-}
-
-func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn {
- now := time.Now()
- conn := &openAIWSConn{
- id: id,
- ws: ws,
- handshakeHeaders: cloneHeader(handshakeHeaders),
- leaseCh: make(chan struct{}, 1),
- closedCh: make(chan struct{}),
- }
- conn.leaseCh <- struct{}{}
- conn.createdAtNano.Store(now.UnixNano())
- conn.lastUsedNano.Store(now.UnixNano())
- return conn
-}
-
-func (c *openAIWSConn) tryAcquire() bool {
- if c == nil {
- return false
- }
- select {
- case <-c.closedCh:
- return false
- default:
- }
- select {
- case <-c.leaseCh:
- select {
- case <-c.closedCh:
- c.release()
- return false
- default:
- }
- return true
- default:
- return false
- }
-}
-
-func (c *openAIWSConn) acquire(ctx context.Context) error {
- if c == nil {
- return errOpenAIWSConnClosed
- }
- for {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-c.closedCh:
- return errOpenAIWSConnClosed
- case <-c.leaseCh:
- select {
- case <-c.closedCh:
- c.release()
- return errOpenAIWSConnClosed
- default:
- }
- return nil
- }
- }
-}
-
-func (c *openAIWSConn) release() {
- if c == nil {
- return
- }
- select {
- case c.leaseCh <- struct{}{}:
- default:
- }
- c.touch()
-}
-
-func (c *openAIWSConn) close() {
- if c == nil {
- return
- }
- c.closeOnce.Do(func() {
- close(c.closedCh)
- if c.ws != nil {
- _ = c.ws.Close()
- }
- select {
- case c.leaseCh <- struct{}{}:
- default:
- }
- })
-}
-
-func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error {
- if c == nil {
- return errOpenAIWSConnClosed
- }
- select {
- case <-c.closedCh:
- return errOpenAIWSConnClosed
- default:
- }
-
- writeCtx := parent
- if writeCtx == nil {
- writeCtx = context.Background()
- }
- if timeout <= 0 {
- return c.writeJSON(value, writeCtx)
- }
- var cancel context.CancelFunc
- writeCtx, cancel = context.WithTimeout(writeCtx, timeout)
- defer cancel()
- return c.writeJSON(value, writeCtx)
-}
-
-func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error {
- c.writeMu.Lock()
- defer c.writeMu.Unlock()
- if c.ws == nil {
- return errOpenAIWSConnClosed
- }
- if writeCtx == nil {
- writeCtx = context.Background()
- }
- if err := c.ws.WriteJSON(writeCtx, value); err != nil {
- return err
- }
- c.touch()
- return nil
-}
-
-func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) {
- return c.readMessageWithContextTimeout(context.Background(), timeout)
-}
-
-func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) {
- if c == nil {
- return nil, errOpenAIWSConnClosed
- }
- select {
- case <-c.closedCh:
- return nil, errOpenAIWSConnClosed
- default:
- }
-
- if parent == nil {
- parent = context.Background()
- }
- if timeout <= 0 {
- return c.readMessage(parent)
- }
- readCtx, cancel := context.WithTimeout(parent, timeout)
- defer cancel()
- return c.readMessage(readCtx)
-}
-
-func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) {
- c.readMu.Lock()
- defer c.readMu.Unlock()
- if c.ws == nil {
- return nil, errOpenAIWSConnClosed
- }
- if readCtx == nil {
- readCtx = context.Background()
- }
- payload, err := c.ws.ReadMessage(readCtx)
- if err != nil {
- return nil, err
- }
- c.touch()
- return payload, nil
-}
-
-func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error {
- if c == nil {
- return errOpenAIWSConnClosed
- }
- select {
- case <-c.closedCh:
- return errOpenAIWSConnClosed
- default:
- }
-
- c.writeMu.Lock()
- defer c.writeMu.Unlock()
- if c.ws == nil {
- return errOpenAIWSConnClosed
- }
- if timeout <= 0 {
- timeout = openAIWSConnHealthCheckTO
- }
- pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if err := c.ws.Ping(pingCtx); err != nil {
- return err
- }
- return nil
-}
-
-func (c *openAIWSConn) touch() {
- if c == nil {
- return
- }
- c.lastUsedNano.Store(time.Now().UnixNano())
-}
-
-func (c *openAIWSConn) createdAt() time.Time {
- if c == nil {
- return time.Time{}
- }
- nano := c.createdAtNano.Load()
- if nano <= 0 {
- return time.Time{}
- }
- return time.Unix(0, nano)
-}
-
-func (c *openAIWSConn) lastUsedAt() time.Time {
- if c == nil {
- return time.Time{}
- }
- nano := c.lastUsedNano.Load()
- if nano <= 0 {
- return time.Time{}
- }
- return time.Unix(0, nano)
-}
-
-func (c *openAIWSConn) idleDuration(now time.Time) time.Duration {
- if c == nil {
- return 0
- }
- last := c.lastUsedAt()
- if last.IsZero() {
- return 0
- }
- return now.Sub(last)
-}
-
-func (c *openAIWSConn) age(now time.Time) time.Duration {
- if c == nil {
- return 0
- }
- created := c.createdAt()
- if created.IsZero() {
- return 0
- }
- return now.Sub(created)
-}
-
-func (c *openAIWSConn) isLeased() bool {
- if c == nil {
- return false
- }
- return len(c.leaseCh) == 0
-}
-
-func (c *openAIWSConn) handshakeHeader(name string) string {
- if c == nil || c.handshakeHeaders == nil {
- return ""
- }
- return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name)))
-}
-
-func (c *openAIWSConn) isPrewarmed() bool {
- if c == nil {
- return false
- }
- return c.prewarmed.Load()
-}
-
-func (c *openAIWSConn) markPrewarmed() {
- if c == nil {
- return
- }
- c.prewarmed.Store(true)
-}
-
-type openAIWSAccountPool struct {
- mu sync.Mutex
- conns map[string]*openAIWSConn
- pinnedConns map[string]int
- creating int
- lastCleanupAt time.Time
- lastAcquire *openAIWSAcquireRequest
- prewarmActive bool
- prewarmUntil time.Time
- prewarmFails int
- prewarmFailAt time.Time
-}
-
-type OpenAIWSPoolMetricsSnapshot struct {
- AcquireTotal int64
- AcquireReuseTotal int64
- AcquireCreateTotal int64
- AcquireQueueWaitTotal int64
- AcquireQueueWaitMsTotal int64
- ConnPickTotal int64
- ConnPickMsTotal int64
- ScaleUpTotal int64
- ScaleDownTotal int64
-}
-
-type openAIWSPoolMetrics struct {
- acquireTotal atomic.Int64
- acquireReuseTotal atomic.Int64
- acquireCreateTotal atomic.Int64
- acquireQueueWaitTotal atomic.Int64
- acquireQueueWaitMs atomic.Int64
- connPickTotal atomic.Int64
- connPickMs atomic.Int64
- scaleUpTotal atomic.Int64
- scaleDownTotal atomic.Int64
-}
-
-type openAIWSConnPool struct {
- cfg *config.Config
- // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。
- clientDialer openAIWSClientDialer
-
- accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool
- seq atomic.Uint64
-
- metrics openAIWSPoolMetrics
-
- workerStopCh chan struct{}
- workerWg sync.WaitGroup
- closeOnce sync.Once
-}
-
-func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool {
- pool := &openAIWSConnPool{
- cfg: cfg,
- clientDialer: newDefaultOpenAIWSClientDialer(),
- workerStopCh: make(chan struct{}),
- }
- pool.startBackgroundWorkers()
- return pool
-}
-
-func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot {
- if p == nil {
- return OpenAIWSPoolMetricsSnapshot{}
- }
- return OpenAIWSPoolMetricsSnapshot{
- AcquireTotal: p.metrics.acquireTotal.Load(),
- AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(),
- AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(),
- AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(),
- AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(),
- ConnPickTotal: p.metrics.connPickTotal.Load(),
- ConnPickMsTotal: p.metrics.connPickMs.Load(),
- ScaleUpTotal: p.metrics.scaleUpTotal.Load(),
- ScaleDownTotal: p.metrics.scaleDownTotal.Load(),
- }
-}
-
-func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
- if p == nil {
- return OpenAIWSTransportMetricsSnapshot{}
- }
- if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok {
- return dialer.SnapshotTransportMetrics()
- }
- return OpenAIWSTransportMetricsSnapshot{}
-}
-
-func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) {
- if p == nil || dialer == nil {
- return
- }
- p.clientDialer = dialer
-}
-
-// Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。
-func (p *openAIWSConnPool) Close() {
- if p == nil {
- return
- }
- p.closeOnce.Do(func() {
- if p.workerStopCh != nil {
- close(p.workerStopCh)
- }
- p.workerWg.Wait()
- // 遍历所有账户池,关闭全部空闲连接。
- p.accounts.Range(func(key, value any) bool {
- ap, ok := value.(*openAIWSAccountPool)
- if !ok || ap == nil {
- return true
- }
- ap.mu.Lock()
- for _, conn := range ap.conns {
- if conn != nil && !conn.isLeased() {
- conn.close()
- }
- }
- ap.mu.Unlock()
- return true
- })
- })
-}
-
-func (p *openAIWSConnPool) startBackgroundWorkers() {
- if p == nil || p.workerStopCh == nil {
- return
- }
- p.workerWg.Add(2)
- go func() {
- defer p.workerWg.Done()
- p.runBackgroundPingWorker()
- }()
- go func() {
- defer p.workerWg.Done()
- p.runBackgroundCleanupWorker()
- }()
-}
-
-type openAIWSIdlePingCandidate struct {
- accountID int64
- conn *openAIWSConn
-}
-
-func (p *openAIWSConnPool) runBackgroundPingWorker() {
- if p == nil {
- return
- }
- ticker := time.NewTicker(openAIWSBackgroundPingInterval)
- defer ticker.Stop()
- for {
- select {
- case <-ticker.C:
- p.runBackgroundPingSweep()
- case <-p.workerStopCh:
- return
- }
- }
-}
-
-func (p *openAIWSConnPool) runBackgroundPingSweep() {
- if p == nil {
- return
- }
- candidates := p.snapshotIdleConnsForPing()
- var g errgroup.Group
- g.SetLimit(10)
- for _, item := range candidates {
- item := item
- if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 {
- continue
- }
- g.Go(func() error {
- if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- p.evictConn(item.accountID, item.conn.id)
- }
- return nil
- })
- }
- _ = g.Wait()
-}
-
-func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate {
- if p == nil {
- return nil
- }
- candidates := make([]openAIWSIdlePingCandidate, 0)
- p.accounts.Range(func(key, value any) bool {
- accountID, ok := key.(int64)
- if !ok || accountID <= 0 {
- return true
- }
- ap, ok := value.(*openAIWSAccountPool)
- if !ok || ap == nil {
- return true
- }
- ap.mu.Lock()
- for _, conn := range ap.conns {
- if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 {
- continue
- }
- candidates = append(candidates, openAIWSIdlePingCandidate{
- accountID: accountID,
- conn: conn,
- })
- }
- ap.mu.Unlock()
- return true
- })
- return candidates
-}
-
-func (p *openAIWSConnPool) runBackgroundCleanupWorker() {
- if p == nil {
- return
- }
- ticker := time.NewTicker(openAIWSBackgroundSweepTicker)
- defer ticker.Stop()
- for {
- select {
- case <-ticker.C:
- p.runBackgroundCleanupSweep(time.Now())
- case <-p.workerStopCh:
- return
- }
- }
-}
-
-func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) {
- if p == nil {
- return
- }
- type cleanupResult struct {
- evicted []*openAIWSConn
- }
- results := make([]cleanupResult, 0)
- p.accounts.Range(func(_ any, value any) bool {
- ap, ok := value.(*openAIWSAccountPool)
- if !ok || ap == nil {
- return true
- }
- maxConns := p.maxConnsHardCap()
- ap.mu.Lock()
- if ap.lastAcquire != nil && ap.lastAcquire.Account != nil {
- maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account)
- }
- evicted := p.cleanupAccountLocked(ap, now, maxConns)
- ap.lastCleanupAt = now
- ap.mu.Unlock()
- if len(evicted) > 0 {
- results = append(results, cleanupResult{evicted: evicted})
- }
- return true
- })
- for _, result := range results {
- closeOpenAIWSConns(result.evicted)
- }
-}
-
-func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) {
- if p != nil {
- p.metrics.acquireTotal.Add(1)
- }
- return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0)
-}
-
-func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) {
- if p == nil || req.Account == nil || req.Account.ID <= 0 {
- return nil, errors.New("invalid ws acquire request")
- }
- if stringsTrim(req.WSURL) == "" {
- return nil, errors.New("ws url is empty")
- }
-
- accountID := req.Account.ID
- effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account)
- if effectiveMaxConns <= 0 {
- return nil, errOpenAIWSConnQueueFull
- }
- var evicted []*openAIWSConn
- ap := p.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req)
- now := time.Now()
- if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
- evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
- ap.lastCleanupAt = now
- }
- pickStartedAt := time.Now()
- allowReuse := !req.ForceNewConn
- preferredConnID := stringsTrim(req.PreferredConnID)
- forcePreferredConn := allowReuse && req.ForcePreferredConn
-
- if allowReuse {
- if forcePreferredConn {
- if preferredConnID == "" {
- p.recordConnPickDuration(time.Since(pickStartedAt))
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSPreferredConnUnavailable
- }
- preferredConn, ok := ap.conns[preferredConnID]
- if !ok || preferredConn == nil {
- p.recordConnPickDuration(time.Since(pickStartedAt))
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSPreferredConnUnavailable
- }
- if preferredConn.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(preferredConn) {
- if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- preferredConn.close()
- p.evictConn(accountID, preferredConn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{
- pool: p,
- accountID: accountID,
- conn: preferredConn,
- connPick: connPick,
- reused: true,
- }
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
-
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() {
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnQueueFull
- }
- preferredConn.waiters.Add(1)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- defer preferredConn.waiters.Add(-1)
- waitStart := time.Now()
- p.metrics.acquireQueueWaitTotal.Add(1)
-
- if err := preferredConn.acquire(ctx); err != nil {
- if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- if p.shouldHealthCheckConn(preferredConn) {
- if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- preferredConn.release()
- preferredConn.close()
- p.evictConn(accountID, preferredConn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
-
- queueWait := time.Since(waitStart)
- p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
- lease := &openAIWSConnLease{
- pool: p,
- accountID: accountID,
- conn: preferredConn,
- queueWait: queueWait,
- connPick: connPick,
- reused: true,
- }
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
-
- if preferredConnID != "" {
- if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(conn) {
- if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- conn.close()
- p.evictConn(accountID, conn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
- }
-
- best := p.pickLeastBusyConnLocked(ap, "")
- if best != nil && best.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(best) {
- if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- best.close()
- p.evictConn(accountID, best.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
- for _, conn := range ap.conns {
- if conn == nil || conn == best {
- continue
- }
- if conn.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(conn) {
- if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- conn.close()
- p.evictConn(accountID, conn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
- }
- }
-
- if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns {
- if idle := p.pickOldestIdleConnLocked(ap); idle != nil {
- delete(ap.conns, idle.id)
- evicted = append(evicted, idle)
- p.metrics.scaleDownTotal.Add(1)
- }
- }
-
- if len(ap.conns)+ap.creating < effectiveMaxConns {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.creating++
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
-
- conn, dialErr := p.dialConn(ctx, req)
-
- ap = p.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.creating--
- if dialErr != nil {
- ap.prewarmFails++
- ap.prewarmFailAt = time.Now()
- ap.mu.Unlock()
- return nil, dialErr
- }
- ap.conns[conn.id] = conn
- ap.prewarmFails = 0
- ap.prewarmFailAt = time.Time{}
- ap.mu.Unlock()
- p.metrics.acquireCreateTotal.Add(1)
-
- if !conn.tryAcquire() {
- if err := conn.acquire(ctx); err != nil {
- conn.close()
- p.evictConn(accountID, conn.id)
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick}
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
-
- if req.ForceNewConn {
- p.recordConnPickDuration(time.Since(pickStartedAt))
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnQueueFull
- }
-
- target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID)
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- if target == nil {
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnClosed
- }
- if int(target.waiters.Load()) >= p.queueLimitPerConn() {
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnQueueFull
- }
- target.waiters.Add(1)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- defer target.waiters.Add(-1)
- waitStart := time.Now()
- p.metrics.acquireQueueWaitTotal.Add(1)
-
- if err := target.acquire(ctx); err != nil {
- if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- if p.shouldHealthCheckConn(target) {
- if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- target.release()
- target.close()
- p.evictConn(accountID, target.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
-
- queueWait := time.Since(waitStart)
- p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
-}
-
-func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) {
- if p == nil {
- return
- }
- if duration < 0 {
- duration = 0
- }
- p.metrics.connPickTotal.Add(1)
- p.metrics.connPickMs.Add(duration.Milliseconds())
-}
-
-func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn {
- if ap == nil || len(ap.conns) == 0 {
- return nil
- }
- var oldest *openAIWSConn
- for _, conn := range ap.conns {
- if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
- continue
- }
- if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) {
- oldest = conn
- }
- }
- return oldest
-}
-
-func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool {
- if p == nil || accountID <= 0 {
- return nil
- }
- if existing, ok := p.accounts.Load(accountID); ok {
- if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil {
- return ap
- }
- }
- ap := &openAIWSAccountPool{
- conns: make(map[string]*openAIWSConn),
- pinnedConns: make(map[string]int),
- }
- actual, _ := p.accounts.LoadOrStore(accountID, ap)
- if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil {
- return typed
- }
- return ap
-}
-
-// ensureAccountPoolLocked 兼容旧调用。
-func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool {
- return p.getOrCreateAccountPool(accountID)
-}
-
-func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) {
- if p == nil || accountID <= 0 {
- return nil, false
- }
- value, ok := p.accounts.Load(accountID)
- if !ok || value == nil {
- return nil, false
- }
- ap, typed := value.(*openAIWSAccountPool)
- return ap, typed && ap != nil
-}
-
-func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool {
- if ap == nil || connID == "" || len(ap.pinnedConns) == 0 {
- return false
- }
- return ap.pinnedConns[connID] > 0
-}
-
-func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn {
- if ap == nil {
- return nil
- }
- maxAge := p.maxConnAge()
-
- evicted := make([]*openAIWSConn, 0)
- for id, conn := range ap.conns {
- if conn == nil {
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- continue
- }
- select {
- case <-conn.closedCh:
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- evicted = append(evicted, conn)
- continue
- default:
- }
- if p.isConnPinnedLocked(ap, id) {
- continue
- }
- if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge {
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- evicted = append(evicted, conn)
- }
- }
-
- if maxConns <= 0 {
- maxConns = p.maxConnsHardCap()
- }
- maxIdle := p.maxIdlePerAccount()
- if maxIdle < 0 || maxIdle > maxConns {
- maxIdle = maxConns
- }
- if maxIdle >= 0 && len(ap.conns) > maxIdle {
- idleConns := make([]*openAIWSConn, 0, len(ap.conns))
- for id, conn := range ap.conns {
- if conn == nil {
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- continue
- }
- // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。
- if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
- continue
- }
- idleConns = append(idleConns, conn)
- }
- sort.SliceStable(idleConns, func(i, j int) bool {
- return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt())
- })
- redundant := len(ap.conns) - maxIdle
- if redundant > len(idleConns) {
- redundant = len(idleConns)
- }
- for i := 0; i < redundant; i++ {
- conn := idleConns[i]
- delete(ap.conns, conn.id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, conn.id)
- }
- evicted = append(evicted, conn)
- }
- if redundant > 0 {
- p.metrics.scaleDownTotal.Add(int64(redundant))
- }
- }
-
- return evicted
-}
-
-func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn {
- if ap == nil || len(ap.conns) == 0 {
- return nil
- }
- preferredConnID = stringsTrim(preferredConnID)
- if preferredConnID != "" {
- if conn, ok := ap.conns[preferredConnID]; ok {
- return conn
- }
- }
- var best *openAIWSConn
- var bestWaiters int32
- var bestLastUsed time.Time
- for _, conn := range ap.conns {
- if conn == nil {
- continue
- }
- waiters := conn.waiters.Load()
- lastUsed := conn.lastUsedAt()
- if best == nil ||
- waiters < bestWaiters ||
- (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) {
- best = conn
- bestWaiters = waiters
- bestLastUsed = lastUsed
- }
- }
- return best
-}
-
-func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) {
- if ap == nil {
- return 0, 0
- }
- for _, conn := range ap.conns {
- if conn == nil {
- continue
- }
- if conn.isLeased() {
- inflight++
- }
- waiters += int(conn.waiters.Load())
- }
- return inflight, waiters
-}
-
-// AccountPoolLoad 返回指定账号连接池的并发与排队快照。
-func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) {
- if p == nil || accountID <= 0 {
- return 0, 0, 0
- }
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return 0, 0, 0
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- inflight, waiters = accountPoolLoadLocked(ap)
- return inflight, waiters, len(ap.conns)
-}
-
-func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) {
- if p == nil || accountID <= 0 {
- return
- }
-
- var req openAIWSAcquireRequest
- need := 0
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- if ap.lastAcquire == nil {
- return
- }
- if ap.prewarmActive {
- return
- }
- now := time.Now()
- if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) {
- return
- }
- if p.shouldSuppressPrewarmLocked(ap, now) {
- return
- }
- effectiveMaxConns := p.maxConnsHardCap()
- if ap.lastAcquire != nil && ap.lastAcquire.Account != nil {
- effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account)
- }
- target := p.targetConnCountLocked(ap, effectiveMaxConns)
- current := len(ap.conns) + ap.creating
- if current >= target {
- return
- }
- need = target - current
- if need <= 0 {
- return
- }
- req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire)
- ap.prewarmActive = true
- if cooldown := p.prewarmCooldown(); cooldown > 0 {
- ap.prewarmUntil = now.Add(cooldown)
- }
- ap.creating += need
- p.metrics.scaleUpTotal.Add(int64(need))
-
- go p.prewarmConns(accountID, req, need)
-}
-
-func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int {
- if ap == nil {
- return 0
- }
-
- if maxConns <= 0 {
- return 0
- }
-
- minIdle := p.minIdlePerAccount()
- if minIdle < 0 {
- minIdle = 0
- }
- if minIdle > maxConns {
- minIdle = maxConns
- }
-
- inflight, waiters := accountPoolLoadLocked(ap)
- utilization := p.targetUtilization()
- demand := inflight + waiters
- if demand <= 0 {
- return minIdle
- }
-
- target := 1
- if demand > 1 {
- target = int(math.Ceil(float64(demand) / utilization))
- }
- if waiters > 0 && target < len(ap.conns)+1 {
- target = len(ap.conns) + 1
- }
- if target < minIdle {
- target = minIdle
- }
- if target > maxConns {
- target = maxConns
- }
- return target
-}
-
-func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) {
- defer func() {
- if ap, ok := p.getAccountPool(accountID); ok && ap != nil {
- ap.mu.Lock()
- ap.prewarmActive = false
- ap.mu.Unlock()
- }
- }()
-
- for i := 0; i < total; i++ {
- ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay)
- conn, err := p.dialConn(ctx, req)
- cancel()
-
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- if conn != nil {
- conn.close()
- }
- return
- }
- ap.mu.Lock()
- if ap.creating > 0 {
- ap.creating--
- }
- if err != nil {
- ap.prewarmFails++
- ap.prewarmFailAt = time.Now()
- ap.mu.Unlock()
- continue
- }
- if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) {
- ap.mu.Unlock()
- conn.close()
- continue
- }
- ap.conns[conn.id] = conn
- ap.prewarmFails = 0
- ap.prewarmFailAt = time.Time{}
- ap.mu.Unlock()
- }
-}
-
-func (p *openAIWSConnPool) evictConn(accountID int64, connID string) {
- if p == nil || accountID <= 0 || stringsTrim(connID) == "" {
- return
- }
- var conn *openAIWSConn
- ap, ok := p.getAccountPool(accountID)
- if ok && ap != nil {
- ap.mu.Lock()
- if c, exists := ap.conns[connID]; exists {
- conn = c
- delete(ap.conns, connID)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, connID)
- }
- }
- ap.mu.Unlock()
- }
- if conn != nil {
- conn.close()
- }
-}
-
-func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool {
- if p == nil || accountID <= 0 {
- return false
- }
- connID = stringsTrim(connID)
- if connID == "" {
- return false
- }
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- if _, exists := ap.conns[connID]; !exists {
- return false
- }
- if ap.pinnedConns == nil {
- ap.pinnedConns = make(map[string]int)
- }
- ap.pinnedConns[connID]++
- return true
-}
-
-func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) {
- if p == nil || accountID <= 0 {
- return
- }
- connID = stringsTrim(connID)
- if connID == "" {
- return
- }
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- if len(ap.pinnedConns) == 0 {
- return
- }
- count := ap.pinnedConns[connID]
- if count <= 1 {
- delete(ap.pinnedConns, connID)
- return
- }
- ap.pinnedConns[connID] = count - 1
-}
-
-func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) {
- if p == nil || p.clientDialer == nil {
- return nil, errors.New("openai ws client dialer is nil")
- }
- conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
- if err != nil {
- return nil, &openAIWSDialError{
- StatusCode: status,
- ResponseHeaders: cloneHeader(handshakeHeaders),
- Err: err,
- }
- }
- if conn == nil {
- return nil, &openAIWSDialError{
- StatusCode: status,
- ResponseHeaders: cloneHeader(handshakeHeaders),
- Err: errors.New("openai ws dialer returned nil connection"),
- }
- }
- id := p.nextConnID(req.Account.ID)
- return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil
-}
-
-func (p *openAIWSConnPool) nextConnID(accountID int64) string {
- seq := p.seq.Add(1)
- buf := make([]byte, 0, 32)
- buf = append(buf, "oa_ws_"...)
- buf = strconv.AppendInt(buf, accountID, 10)
- buf = append(buf, '_')
- buf = strconv.AppendUint(buf, seq, 10)
- return string(buf)
-}
-
-func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool {
- if conn == nil {
- return false
- }
- return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle
-}
-
-func (p *openAIWSConnPool) maxConnsHardCap() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 {
- return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount
- }
- return 8
-}
-
-func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool {
- if p != nil && p.cfg != nil {
- return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled
- }
- return false
-}
-
-func (p *openAIWSConnPool) modeRouterV2Enabled() bool {
- if p != nil && p.cfg != nil {
- return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
- }
- return false
-}
-
-func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 {
- if p == nil || p.cfg == nil || account == nil {
- return 1.0
- }
- switch account.Type {
- case AccountTypeOAuth:
- if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 {
- return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor
- }
- case AccountTypeAPIKey:
- if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 {
- return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor
- }
- }
- return 1.0
-}
-
-func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int {
- hardCap := p.maxConnsHardCap()
- if hardCap <= 0 {
- return 0
- }
- if p.modeRouterV2Enabled() {
- if account == nil {
- return hardCap
- }
- if account.Concurrency <= 0 {
- return 0
- }
- return account.Concurrency
- }
- if account == nil || !p.dynamicMaxConnsEnabled() {
- return hardCap
- }
- if account.Concurrency <= 0 {
- // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。
- return hardCap
- }
- factor := p.maxConnsFactorByAccount(account)
- if factor <= 0 {
- factor = 1.0
- }
- effective := int(math.Ceil(float64(account.Concurrency) * factor))
- if effective < 1 {
- effective = 1
- }
- if effective > hardCap {
- effective = hardCap
- }
- return effective
-}
-
-func (p *openAIWSConnPool) minIdlePerAccount() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 {
- return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount
- }
- return 0
-}
-
-func (p *openAIWSConnPool) maxIdlePerAccount() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 {
- return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount
- }
- return 4
-}
-
-func (p *openAIWSConnPool) maxConnAge() time.Duration {
- return openAIWSConnMaxAge
-}
-
-func (p *openAIWSConnPool) queueLimitPerConn() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 {
- return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn
- }
- return 256
-}
-
-func (p *openAIWSConnPool) targetUtilization() float64 {
- if p != nil && p.cfg != nil {
- ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization
- if ratio > 0 && ratio <= 1 {
- return ratio
- }
- }
- return 0.7
-}
-
-func (p *openAIWSConnPool) prewarmCooldown() time.Duration {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 {
- return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond
- }
- return 0
-}
-
-func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool {
- if ap == nil {
- return true
- }
- if ap.prewarmFails <= 0 {
- return false
- }
- if ap.prewarmFailAt.IsZero() {
- ap.prewarmFails = 0
- return false
- }
- if now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow {
- ap.prewarmFails = 0
- ap.prewarmFailAt = time.Time{}
- return false
- }
- return ap.prewarmFails >= openAIWSPrewarmFailureSuppress
-}
-
-func (p *openAIWSConnPool) dialTimeout() time.Duration {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 {
- return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second
- }
- return 10 * time.Second
-}
-
-func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest {
- copied := req
- copied.Headers = cloneHeader(req.Headers)
- copied.WSURL = stringsTrim(req.WSURL)
- copied.ProxyURL = stringsTrim(req.ProxyURL)
- copied.PreferredConnID = stringsTrim(req.PreferredConnID)
- return copied
-}
-
-func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest {
- if req == nil {
- return nil
- }
- copied := cloneOpenAIWSAcquireRequest(*req)
- return &copied
-}
-
-func cloneHeader(src http.Header) http.Header {
- if src == nil {
- return nil
- }
- dst := make(http.Header, len(src))
- for k, vals := range src {
- if len(vals) == 0 {
- dst[k] = nil
- continue
- }
- copied := make([]string, len(vals))
- copy(copied, vals)
- dst[k] = copied
- }
- return dst
-}
-
-func closeOpenAIWSConns(conns []*openAIWSConn) {
- if len(conns) == 0 {
- return
- }
- for _, conn := range conns {
- if conn == nil {
- continue
- }
- conn.close()
- }
-}
-
-func stringsTrim(value string) string {
- return strings.TrimSpace(value)
-}
diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go
deleted file mode 100644
index bff74b626..000000000
--- a/backend/internal/service/openai_ws_pool_benchmark_test.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "testing"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
-)
-
-func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
-
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(&openAIWSCountingDialer{})
-
- account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- req := openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ctx := context.Background()
-
- lease, err := pool.Acquire(ctx, req)
- if err != nil {
- b.Fatalf("warm acquire failed: %v", err)
- }
- lease.Release()
-
- b.ReportAllocs()
- b.ResetTimer()
- b.RunParallel(func(pb *testing.PB) {
- for pb.Next() {
- var (
- got *openAIWSConnLease
- acquireErr error
- )
- for retry := 0; retry < 3; retry++ {
- got, acquireErr = pool.Acquire(ctx, req)
- if acquireErr == nil {
- break
- }
- if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
- break
- }
- }
- if acquireErr != nil {
- b.Fatalf("acquire failed: %v", acquireErr)
- }
- got.Release()
- }
- })
-}
diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go
deleted file mode 100644
index b2683ee04..000000000
--- a/backend/internal/service/openai_ws_pool_test.go
+++ /dev/null
@@ -1,1709 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "net/http"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/stretchr/testify/require"
-)
-
-func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- pool := newOpenAIWSConnPool(cfg)
-
- accountID := int64(10)
- ap := pool.getOrCreateAccountPool(accountID)
-
- stale := newOpenAIWSConn("stale", accountID, nil, nil)
- stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
-
- idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil)
- idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
-
- idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil)
- idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
-
- ap.conns[stale.id] = stale
- ap.conns[idleOld.id] = idleOld
- ap.conns[idleNew.id] = idleNew
-
- evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
- closeOpenAIWSConns(evicted)
-
- require.Nil(t, ap.conns["stale"], "stale connection should be rotated")
- require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle")
- require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept")
-}
-
-func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) {
- pool := newOpenAIWSConnPool(&config.Config{})
- id1 := pool.nextConnID(42)
- id2 := pool.nextConnID(42)
-
- require.True(t, strings.HasPrefix(id1, "oa_ws_42_"))
- require.True(t, strings.HasPrefix(id2, "oa_ws_42_"))
- require.NotEqual(t, id1, id2)
- require.Equal(t, "oa_ws_42_1", id1)
- require.Equal(t, "oa_ws_42_2", id2)
-}
-
-func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) {
- require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval)
- require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker)
-}
-
-func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) {
- conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
- require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0))
-
- var nilLease *openAIWSConnLease
- err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-
- err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) {
- probe := &openAIWSContextProbeConn{}
- conn := newOpenAIWSConn("ctx_probe", 1, probe, nil)
- require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0))
- require.NotNil(t, probe.lastWriteCtx)
-}
-
-func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5
-
- pool := newOpenAIWSConnPool(cfg)
- ap := pool.getOrCreateAccountPool(88)
-
- conn1 := newOpenAIWSConn("c1", 88, nil, nil)
- conn2 := newOpenAIWSConn("c2", 88, nil, nil)
- require.True(t, conn1.tryAcquire())
- require.True(t, conn2.tryAcquire())
- conn1.waiters.Store(1)
- conn2.waiters.Store(1)
-
- ap.conns[conn1.id] = conn1
- ap.conns[conn2.id] = conn2
-
- target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
- require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限")
-
- conn1.release()
- conn2.release()
- conn1.waiters.Store(0)
- conn2.waiters.Store(0)
- target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
- require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接")
-}
-
-func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
-
- pool := newOpenAIWSConnPool(cfg)
- ap := pool.getOrCreateAccountPool(66)
-
- target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
- require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0")
-}
-
-func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
-
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(&openAIWSFakeDialer{})
-
- accountID := int64(77)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
-
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return len(ap.conns) >= 2
- }, 2*time.Second, 20*time.Millisecond)
-
- metrics := pool.SnapshotMetrics()
- require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2))
-}
-
-func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
- cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500
-
- pool := newOpenAIWSConnPool(cfg)
- dialer := &openAIWSCountingDialer{}
- pool.setClientDialerForTest(dialer)
-
- accountID := int64(178)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return len(ap.conns) >= 2 && !ap.prewarmActive
- }, 2*time.Second, 20*time.Millisecond)
- firstDialCount := dialer.DialCount()
- require.GreaterOrEqual(t, firstDialCount, 2)
-
- // 人工制造缺口触发新一轮预热需求。
- ap, ok := pool.getAccountPool(accountID)
- require.True(t, ok)
- require.NotNil(t, ap)
- ap.mu.Lock()
- for id := range ap.conns {
- delete(ap.conns, id)
- break
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
- time.Sleep(120 * time.Millisecond)
- require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热")
-
- time.Sleep(450 * time.Millisecond)
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- return dialer.DialCount() > firstDialCount
- }, 2*time.Second, 20*time.Millisecond)
-}
-
-func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
- cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0
-
- pool := newOpenAIWSConnPool(cfg)
- dialer := &openAIWSAlwaysFailDialer{}
- pool.setClientDialerForTest(dialer)
-
- accountID := int64(279)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return !ap.prewarmActive
- }, 2*time.Second, 20*time.Millisecond)
-
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return !ap.prewarmActive
- }, 2*time.Second, 20*time.Millisecond)
- require.Equal(t, 2, dialer.DialCount())
-
- // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。
- pool.ensureTargetIdleAsync(accountID)
- time.Sleep(120 * time.Millisecond)
- require.Equal(t, 2, dialer.DialCount())
-}
-
-func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
-
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(99)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队
-
- ap := pool.ensureAccountPoolLocked(accountID)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- go func() {
- time.Sleep(60 * time.Millisecond)
- conn.release()
- }()
-
- lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.NoError(t, err)
- require.NotNil(t, lease)
- require.True(t, lease.Reused())
- require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond)
- lease.Release()
-
- metrics := pool.SnapshotMetrics()
- require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1))
- require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0))
- require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
-}
-
-func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
-
- pool := newOpenAIWSConnPool(cfg)
- dialer := &openAIWSCountingDialer{}
- pool.setClientDialerForTest(dialer)
-
- account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
-
- lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.NoError(t, err)
- require.NotNil(t, lease1)
- lease1.Release()
-
- lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- ForceNewConn: true,
- })
- require.NoError(t, err)
- require.NotNil(t, lease2)
- lease2.Release()
-
- require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接")
-}
-
-func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(account.ID)
- otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[otherConn.id] = otherConn
- ap.mu.Unlock()
-
- _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- ForcePreferredConn: true,
- })
- require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
-
- _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: "missing_conn",
- ForcePreferredConn: true,
- })
- require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
-}
-
-func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(account.ID)
- preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil)
- otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil)
- require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取")
- ap.mu.Lock()
- ap.conns[preferredConn.id] = preferredConn
- ap.conns[otherConn.id] = otherConn
- ap.lastCleanupAt = time.Now()
- ap.mu.Unlock()
-
- go func() {
- time.Sleep(60 * time.Millisecond)
- preferredConn.release()
- }()
-
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: preferredConn.id,
- ForcePreferredConn: true,
- })
- require.NoError(t, err)
- require.NotNil(t, lease)
- require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移")
- require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond)
- lease.Release()
- require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占")
- otherConn.release()
-}
-
-func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(account.ID)
- preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
- otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[preferredConn.id] = preferredConn
- ap.conns[otherConn.id] = otherConn
- ap.lastCleanupAt = time.Now()
- ap.mu.Unlock()
-
- lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: preferredConn.id,
- ForcePreferredConn: true,
- })
- require.NoError(t, err)
- require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中")
- lease.Release()
-
- require.True(t, preferredConn.tryAcquire())
- preferredConn.waiters.Store(1)
- _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: preferredConn.id,
- ForcePreferredConn: true,
- })
- require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移")
- preferredConn.waiters.Store(0)
- preferredConn.release()
-}
-
-func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
-
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(126)
- ap := pool.getOrCreateAccountPool(accountID)
- pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil)
- idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[pinnedConn.id] = pinnedConn
- ap.conns[idleConn.id] = idleConn
- ap.mu.Unlock()
-
- require.True(t, pool.PinConn(accountID, pinnedConn.id))
- evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
- closeOpenAIWSConns(evicted)
-
- ap.mu.Lock()
- _, pinnedExists := ap.conns[pinnedConn.id]
- _, idleExists := ap.conns[idleConn.id]
- ap.mu.Unlock()
- require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收")
- require.False(t, idleExists, "非绑定的空闲连接应被回收")
-
- pool.UnpinConn(accountID, pinnedConn.id)
- evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
- closeOpenAIWSConns(evicted)
- ap.mu.Lock()
- _, pinnedExists = ap.conns[pinnedConn.id]
- ap.mu.Unlock()
- require.False(t, pinnedExists, "解绑后连接应可被正常回收")
-}
-
-func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.False(t, nilPool.PinConn(1, "x"))
- nilPool.UnpinConn(1, "x")
-
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(128)
- ap := &openAIWSAccountPool{
- conns: map[string]*openAIWSConn{},
- }
- pool.accounts.Store(accountID, ap)
-
- require.False(t, pool.PinConn(0, "x"))
- require.False(t, pool.PinConn(999, "x"))
- require.False(t, pool.PinConn(accountID, ""))
- require.False(t, pool.PinConn(accountID, "missing"))
-
- conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
- require.True(t, pool.PinConn(accountID, conn.id))
- require.True(t, pool.PinConn(accountID, conn.id))
-
- ap.mu.Lock()
- require.Equal(t, 2, ap.pinnedConns[conn.id])
- ap.mu.Unlock()
-
- pool.UnpinConn(accountID, conn.id)
- ap.mu.Lock()
- require.Equal(t, 1, ap.pinnedConns[conn.id])
- ap.mu.Unlock()
-
- pool.UnpinConn(accountID, conn.id)
- ap.mu.Lock()
- _, exists := ap.pinnedConns[conn.id]
- ap.mu.Unlock()
- require.False(t, exists)
-
- pool.UnpinConn(accountID, conn.id)
- pool.UnpinConn(accountID, "")
- pool.UnpinConn(0, conn.id)
- pool.UnpinConn(999, conn.id)
-}
-
-func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
- cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
- cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
-
- pool := newOpenAIWSConnPool(cfg)
-
- oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10}
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束")
-
- oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3}
- require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow))
-
- apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10}
- require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放")
-
- apiKeyLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1}
- require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1")
-
- unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限")
-
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限")
-}
-
-func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false
- cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
- cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2}
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为")
-}
-
-func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
- cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3
- cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
-
- pool := newOpenAIWSConnPool(cfg)
-
- high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20}
- require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限")
-
- nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0}
- require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度")
-}
-
-func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- pool := newOpenAIWSConnPool(cfg)
-
- account := &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
- _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
-}
-
-func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) {
- conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil)
- lease := &openAIWSConnLease{conn: conn}
-
- _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond)
- require.Error(t, err)
- require.ErrorIs(t, err, context.DeadlineExceeded)
-
- payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond)
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- parentCtx, cancel := context.WithCancel(context.Background())
- cancel()
- _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond)
- require.Error(t, err)
- require.ErrorIs(t, err, context.Canceled)
-}
-
-func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) {
- conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
-
- parentCtx, cancel := context.WithCancel(context.Background())
- go func() {
- time.Sleep(20 * time.Millisecond)
- cancel()
- }()
-
- start := time.Now()
- err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute)
- elapsed := time.Since(start)
-
- require.Error(t, err)
- require.ErrorIs(t, err, context.Canceled)
- require.Less(t, elapsed, 200*time.Millisecond)
-}
-
-func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) {
- conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
- require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
-
- var nilLease *openAIWSConnLease
- err := nilLease.PingWithTimeout(50 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) {
- conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil)
-
- readDone := make(chan error, 1)
- go func() {
- _, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond)
- readDone <- err
- }()
-
- // 让读取先占用 readMu。
- time.Sleep(20 * time.Millisecond)
-
- start := time.Now()
- err := conn.pingWithTimeout(50 * time.Millisecond)
- elapsed := time.Since(start)
-
- require.NoError(t, err)
- require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞")
- require.NoError(t, <-readDone)
-}
-
-func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- pool := newOpenAIWSConnPool(cfg)
-
- accountID := int64(301)
- ap := pool.getOrCreateAccountPool(accountID)
- conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
-
- pool.runBackgroundPingSweep()
-
- ap.mu.Lock()
- _, exists := ap.conns[conn.id]
- ap.mu.Unlock()
- require.False(t, exists, "后台 ping 失败的空闲连接应被回收")
-}
-
-func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- pool := newOpenAIWSConnPool(cfg)
-
- accountID := int64(302)
- ap := pool.getOrCreateAccountPool(accountID)
- stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil)
- stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- ap.mu.Lock()
- ap.conns[stale.id] = stale
- ap.mu.Unlock()
-
- pool.runBackgroundCleanupSweep(time.Now())
-
- ap.mu.Lock()
- _, exists := ap.conns[stale.id]
- ap.mu.Unlock()
- require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接")
-}
-
-func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.NotPanics(t, func() {
- nilPool.startBackgroundWorkers()
- nilPool.runBackgroundPingWorker()
- nilPool.runBackgroundPingSweep()
- _ = nilPool.snapshotIdleConnsForPing()
- nilPool.runBackgroundCleanupWorker()
- nilPool.runBackgroundCleanupSweep(time.Now())
- })
-
- poolNoStop := &openAIWSConnPool{}
- require.NotPanics(t, func() {
- poolNoStop.startBackgroundWorkers()
- })
-
- poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})}
- pingDone := make(chan struct{})
- go func() {
- poolStopPing.runBackgroundPingWorker()
- close(pingDone)
- }()
- close(poolStopPing.workerStopCh)
- select {
- case <-pingDone:
- case <-time.After(500 * time.Millisecond):
- t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出")
- }
-
- poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})}
- cleanupDone := make(chan struct{})
- go func() {
- poolStopCleanup.runBackgroundCleanupWorker()
- close(cleanupDone)
- }()
- close(poolStopCleanup.workerStopCh)
- select {
- case <-cleanupDone:
- case <-time.After(500 * time.Millisecond):
- t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出")
- }
-}
-
-func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) {
- pool := &openAIWSConnPool{}
- pool.accounts.Store("invalid-key", &openAIWSAccountPool{})
- pool.accounts.Store(int64(123), "invalid-value")
-
- accountID := int64(123)
- ap := &openAIWSAccountPool{
- conns: make(map[string]*openAIWSConn),
- }
- ap.conns["nil_conn"] = nil
-
- leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, leased.tryAcquire())
- ap.conns[leased.id] = leased
-
- waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil)
- waiting.waiters.Store(1)
- ap.conns[waiting.id] = waiting
-
- idle := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil)
- ap.conns[idle.id] = idle
-
- pool.accounts.Store(accountID, ap)
- candidates := pool.snapshotIdleConnsForPing()
- require.Len(t, candidates, 1)
- require.Equal(t, idle.id, candidates[0].conn.id)
-}
-
-func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
-
- pool := &openAIWSConnPool{cfg: cfg}
- pool.accounts.Store("bad-key", "bad-value")
-
- accountID := int64(2026)
- ap := &openAIWSAccountPool{
- conns: make(map[string]*openAIWSConn),
- }
- ap.conns["nil_conn"] = nil
- stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil)
- stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- ap.conns[stale.id] = stale
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: &Account{
- ID: accountID,
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Concurrency: 1,
- },
- }
- pool.accounts.Store(accountID, ap)
-
- now := time.Now()
- require.NotPanics(t, func() {
- pool.runBackgroundCleanupSweep(now)
- })
-
- ap.mu.Lock()
- _, nilConnExists := ap.conns["nil_conn"]
- _, exists := ap.conns[stale.id]
- lastCleanupAt := ap.lastCleanupAt
- ap.mu.Unlock()
-
- require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目")
- require.False(t, exists, "后台清理应清理过期连接")
- require.Equal(t, now, lastCleanupAt)
-}
-
-func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.Equal(t, 256, nilPool.queueLimitPerConn())
-
- pool := &openAIWSConnPool{cfg: &config.Config{}}
- require.Equal(t, 256, pool.queueLimitPerConn())
-
- pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9
- require.Equal(t, 9, pool.queueLimitPerConn())
-}
-
-func TestOpenAIWSConnPool_Close(t *testing.T) {
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
-
- // Close 应该可以安全调用
- pool.Close()
-
- // workerStopCh 应已关闭
- select {
- case <-pool.workerStopCh:
- // 预期:channel 已关闭
- default:
- t.Fatal("Close 后 workerStopCh 应已关闭")
- }
-
- // 多次调用 Close 不应 panic
- pool.Close()
-
- // nil pool 调用 Close 不应 panic
- var nilPool *openAIWSConnPool
- nilPool.Close()
-}
-
-func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) {
- baseErr := errors.New("boom")
- dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr}
- require.Contains(t, dialErr.Error(), "status=502")
- require.ErrorIs(t, dialErr.Unwrap(), baseErr)
-
- noStatus := &openAIWSDialError{Err: baseErr}
- require.Contains(t, noStatus.Error(), "boom")
-
- var nilDialErr *openAIWSDialError
- require.Equal(t, "", nilDialErr.Error())
- require.NoError(t, nilDialErr.Unwrap())
-}
-
-func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) {
- conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{
- "X-Test": []string{" value "},
- })
- lease := &openAIWSConnLease{conn: conn}
-
- require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"}))
- payload, err := lease.ReadMessage(100 * time.Millisecond)
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- payload, err = lease.ReadMessageContext(context.Background())
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- payload, err = conn.readMessageWithTimeout(100 * time.Millisecond)
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- require.Equal(t, "value", conn.handshakeHeader(" X-Test "))
- require.NotZero(t, conn.createdAt())
- require.NotZero(t, conn.lastUsedAt())
- require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0))
- require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0))
- require.False(t, conn.isLeased())
-
- // 覆盖空上下文路径
- _, err = conn.readMessage(context.Background())
- require.NoError(t, err)
-
- // 覆盖 nil 保护分支
- var nilConn *openAIWSConn
- require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed)
- _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) {
- pool := &openAIWSConnPool{}
- accountID := int64(404)
- ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
-
- idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil)
- idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
- idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil)
- idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
- leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, leased.tryAcquire())
- leased.waiters.Store(2)
-
- ap.conns[idleOld.id] = idleOld
- ap.conns[idleNew.id] = idleNew
- ap.conns[leased.id] = leased
-
- oldest := pool.pickOldestIdleConnLocked(ap)
- require.NotNil(t, oldest)
- require.Equal(t, idleOld.id, oldest.id)
-
- inflight, waiters := accountPoolLoadLocked(ap)
- require.Equal(t, 1, inflight)
- require.Equal(t, 2, waiters)
-
- pool.accounts.Store(accountID, ap)
- loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID)
- require.Equal(t, 1, loadInflight)
- require.Equal(t, 2, loadWaiters)
- require.Equal(t, 3, conns)
-
- zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0)
- require.Equal(t, 0, zeroInflight)
- require.Equal(t, 0, zeroWaiters)
- require.Equal(t, 0, zeroConns)
-}
-
-func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) {
- pool := &openAIWSConnPool{}
- release := make(chan struct{})
- pool.workerWg.Add(1)
- go func() {
- defer pool.workerWg.Done()
- <-release
- }()
-
- closed := make(chan struct{})
- go func() {
- pool.Close()
- close(closed)
- }()
-
- select {
- case <-closed:
- t.Fatal("Close 不应在 WaitGroup 未完成时提前返回")
- case <-time.After(30 * time.Millisecond):
- }
-
- close(release)
- select {
- case <-closed:
- case <-time.After(time.Second):
- t.Fatal("Close 未等待 workerWg 完成")
- }
-}
-
-func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) {
- pool := &openAIWSConnPool{
- workerStopCh: make(chan struct{}),
- }
-
- accountID := int64(606)
- ap := &openAIWSAccountPool{
- conns: map[string]*openAIWSConn{},
- }
- idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
- leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, leased.tryAcquire())
-
- ap.conns[idle.id] = idle
- ap.conns[leased.id] = leased
- pool.accounts.Store(accountID, ap)
- pool.accounts.Store("invalid-key", "invalid-value")
-
- pool.Close()
-
- select {
- case <-idle.closedCh:
- // idle should be closed
- default:
- t.Fatal("空闲连接应在 Close 时被关闭")
- }
-
- select {
- case <-leased.closedCh:
- t.Fatal("已租赁连接不应在 Close 时被关闭")
- default:
- }
-
- leased.release()
- pool.Close()
-}
-
-func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) {
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(505)
- ap := pool.getOrCreateAccountPool(accountID)
-
- var current atomic.Int32
- var maxConcurrent atomic.Int32
- release := make(chan struct{})
- for i := 0; i < 25; i++ {
- conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{
- current: ¤t,
- maxConcurrent: &maxConcurrent,
- release: release,
- }, nil)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
- }
-
- done := make(chan struct{})
- go func() {
- pool.runBackgroundPingSweep()
- close(done)
- }()
-
- require.Eventually(t, func() bool {
- return maxConcurrent.Load() >= 10
- }, time.Second, 10*time.Millisecond)
-
- close(release)
- select {
- case <-done:
- case <-time.After(2 * time.Second):
- t.Fatal("runBackgroundPingSweep 未在释放后完成")
- }
-
- require.LessOrEqual(t, maxConcurrent.Load(), int32(10))
-}
-
-func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) {
- var nilLease *openAIWSConnLease
- require.Equal(t, "", nilLease.ConnID())
- require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration())
- require.Equal(t, time.Duration(0), nilLease.ConnPickDuration())
- require.False(t, nilLease.Reused())
- require.Equal(t, "", nilLease.HandshakeHeader("x-test"))
- require.False(t, nilLease.IsPrewarmed())
- nilLease.MarkPrewarmed()
- nilLease.Release()
-
- conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}})
- lease := &openAIWSConnLease{
- conn: conn,
- queueWait: 3 * time.Millisecond,
- connPick: 4 * time.Millisecond,
- reused: true,
- }
- require.Equal(t, "getter_conn", lease.ConnID())
- require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration())
- require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration())
- require.True(t, lease.Reused())
- require.Equal(t, "ok", lease.HandshakeHeader("x-test"))
- require.False(t, lease.IsPrewarmed())
- lease.MarkPrewarmed()
- require.True(t, lease.IsPrewarmed())
- lease.Release()
-}
-
-func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics())
- require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics())
-
- pool := &openAIWSConnPool{cfg: &config.Config{}}
- pool.metrics.acquireTotal.Store(7)
- pool.metrics.acquireReuseTotal.Store(3)
- metrics := pool.SnapshotMetrics()
- require.Equal(t, int64(7), metrics.AcquireTotal)
- require.Equal(t, int64(3), metrics.AcquireReuseTotal)
-
- // 非 transport metrics dialer 路径
- pool.clientDialer = &openAIWSFakeDialer{}
- require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics())
- pool.setClientDialerForTest(nil)
- require.NotNil(t, pool.clientDialer)
-
- require.Equal(t, 8, nilPool.maxConnsHardCap())
- require.False(t, nilPool.dynamicMaxConnsEnabled())
- require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil))
- require.Equal(t, 0, nilPool.minIdlePerAccount())
- require.Equal(t, 4, nilPool.maxIdlePerAccount())
- require.Equal(t, 256, nilPool.queueLimitPerConn())
- require.Equal(t, 0.7, nilPool.targetUtilization())
- require.Equal(t, time.Duration(0), nilPool.prewarmCooldown())
- require.Equal(t, 10*time.Second, nilPool.dialTimeout())
-
- // shouldSuppressPrewarmLocked 覆盖 3 条分支
- now := time.Now()
- apNilFail := &openAIWSAccountPool{prewarmFails: 1}
- require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now))
- apZeroTime := &openAIWSAccountPool{prewarmFails: 2}
- require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now))
- require.Equal(t, 0, apZeroTime.prewarmFails)
- apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)}
- require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now))
- apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now}
- require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now))
-
- // recordConnPickDuration 的保护分支
- nilPool.recordConnPickDuration(10 * time.Millisecond)
- pool.recordConnPickDuration(-10 * time.Millisecond)
- require.Equal(t, int64(1), pool.metrics.connPickTotal.Load())
-
- // account pool 读写分支
- require.Nil(t, nilPool.getOrCreateAccountPool(1))
- require.Nil(t, pool.getOrCreateAccountPool(0))
- pool.accounts.Store(int64(7), "invalid")
- ap := pool.getOrCreateAccountPool(7)
- require.NotNil(t, ap)
- _, ok := pool.getAccountPool(0)
- require.False(t, ok)
- _, ok = pool.getAccountPool(12345)
- require.False(t, ok)
- pool.accounts.Store(int64(8), "bad-type")
- _, ok = pool.getAccountPool(8)
- require.False(t, ok)
-
- // health check 条件
- require.False(t, pool.shouldHealthCheckConn(nil))
- conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil)
- conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano())
- require.True(t, pool.shouldHealthCheckConn(conn))
-}
-
-func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) {
- var nilConn *openAIWSConn
- nilConn.touch()
- require.Equal(t, time.Time{}, nilConn.createdAt())
- require.Equal(t, time.Time{}, nilConn.lastUsedAt())
- require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now()))
- require.Equal(t, time.Duration(0), nilConn.age(time.Now()))
- require.False(t, nilConn.isLeased())
- require.False(t, nilConn.isPrewarmed())
- nilConn.markPrewarmed()
-
- conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil)
- require.True(t, conn.tryAcquire())
- require.True(t, conn.isLeased())
- conn.release()
- require.False(t, conn.isLeased())
- conn.close()
- require.False(t, conn.tryAcquire())
-
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- err := conn.acquire(ctx)
- require.Error(t, err)
-}
-
-func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) {
- lease := &openAIWSConnLease{}
- require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
- require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
- _, err := lease.ReadMessage(10 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageContext(context.Background())
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) {
- conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
-
- require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
-
- lease.Release()
- lease.Release() // idempotent
-
- require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
- require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
- require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
-
- _, err := lease.ReadMessage(10 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageContext(context.Background())
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-
- require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) {
- conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil)
- ap := &openAIWSAccountPool{
- conns: map[string]*openAIWSConn{
- conn.id: conn,
- },
- }
- pool := &openAIWSConnPool{}
- pool.accounts.Store(int64(7), ap)
-
- lease := &openAIWSConnLease{
- pool: pool,
- accountID: 7,
- conn: conn,
- }
-
- lease.Release()
- lease.MarkBroken()
-
- ap.mu.Lock()
- _, exists := ap.conns[conn.id]
- ap.mu.Unlock()
- require.True(t, exists, "released lease should not evict active pool connection")
-}
-
-func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) {
- var nilConn *openAIWSConn
- require.False(t, nilConn.tryAcquire())
- require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed)
- nilConn.release()
- nilConn.close()
- require.Equal(t, "", nilConn.handshakeHeader("x-test"))
-
- connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil)
- require.True(t, connBusy.tryAcquire())
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled)
- connBusy.release()
-
- connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil)
- connClosed.close()
- require.ErrorIs(
- t,
- connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second),
- errOpenAIWSConnClosed,
- )
- _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
-
- connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil)
- require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed)
- _, err = connNoWS.readMessage(context.Background())
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
- require.Equal(t, "", connNoWS.handshakeHeader("x-test"))
-
- connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil)
- require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil))
- _, err = connOK.readMessageWithContextTimeout(context.Background(), 0)
- require.NoError(t, err)
- require.NoError(t, connOK.pingWithTimeout(0))
-
- connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil)
- connZero.createdAtNano.Store(0)
- connZero.lastUsedNano.Store(0)
- require.True(t, connZero.createdAt().IsZero())
- require.True(t, connZero.lastUsedAt().IsZero())
- require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now()))
- require.Equal(t, time.Duration(0), connZero.age(time.Now()))
-
- require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil))
- copied := cloneHeader(http.Header{
- "X-Empty": []string{},
- "X-Test": []string{"v1"},
- })
- require.Contains(t, copied, "X-Empty")
- require.Nil(t, copied["X-Empty"])
- require.Equal(t, "v1", copied.Get("X-Test"))
-
- closeOpenAIWSConns([]*openAIWSConn{nil, connOK})
-}
-
-func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) {
- pool := newOpenAIWSConnPool(&config.Config{})
- accountID := int64(5001)
- conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil)
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
-
- lease := &openAIWSConnLease{
- pool: pool,
- accountID: accountID,
- conn: conn,
- }
- lease.MarkBroken()
-
- ap.mu.Lock()
- _, exists := ap.conns[conn.id]
- ap.mu.Unlock()
- require.False(t, exists)
- require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭")
-}
-
-func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- pool := newOpenAIWSConnPool(cfg)
-
- require.Equal(t, 0, pool.targetConnCountLocked(nil, 1))
- ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
- require.Equal(t, 0, pool.targetConnCountLocked(ap, 0))
-
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3
- require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断")
-
- // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9
- busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil)
- require.True(t, busy.tryAcquire())
- busy.waiters.Store(1)
- ap.conns[busy.id] = busy
- target := pool.targetConnCountLocked(ap, 4)
- require.GreaterOrEqual(t, target, len(ap.conns)+1)
-
- // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回
- req := openAIWSAcquireRequest{
- Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
- WSURL: "wss://example.com/v1/responses",
- }
- pool.prewarmConns(999, req, 1)
-
- // prewarm: 拨号失败分支(prewarmFails 累加)
- accountID := int64(1000)
- failPool := newOpenAIWSConnPool(cfg)
- failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{})
- apFail := failPool.getOrCreateAccountPool(accountID)
- apFail.mu.Lock()
- apFail.creating = 1
- apFail.mu.Unlock()
- req.Account.ID = accountID
- failPool.prewarmConns(accountID, req, 1)
- apFail.mu.Lock()
- require.GreaterOrEqual(t, apFail.prewarmFails, 1)
- apFail.mu.Unlock()
-}
-
-func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- _, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{})
- require.Error(t, err)
-
- pool := newOpenAIWSConnPool(&config.Config{})
- _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: &Account{ID: 1},
- WSURL: " ",
- })
- require.Error(t, err)
- require.Contains(t, err.Error(), "ws url is empty")
-
- // target=nil 分支:池满且仅有 nil 连接
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
- fullPool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := fullPool.getOrCreateAccountPool(account.ID)
- ap.mu.Lock()
- ap.conns["nil"] = nil
- ap.lastCleanupAt = time.Now()
- ap.mu.Unlock()
- _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-
- // queue full 分支:waiters 达上限
- account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap2 := fullPool.getOrCreateAccountPool(account2.ID)
- conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil)
- require.True(t, conn.tryAcquire())
- conn.waiters.Store(1)
- ap2.mu.Lock()
- ap2.conns[conn.id] = conn
- ap2.lastCleanupAt = time.Now()
- ap2.mu.Unlock()
- _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account2,
- WSURL: "wss://example.com/v1/responses",
- })
- require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
-}
-
-type openAIWSFakeDialer struct{}
-
-func (d *openAIWSFakeDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- return &openAIWSFakeConn{}, 0, nil, nil
-}
-
-type openAIWSCountingDialer struct {
- mu sync.Mutex
- dialCount int
-}
-
-type openAIWSAlwaysFailDialer struct {
- mu sync.Mutex
- dialCount int
-}
-
-type openAIWSPingBlockingConn struct {
- current *atomic.Int32
- maxConcurrent *atomic.Int32
- release <-chan struct{}
-}
-
-func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error {
- return nil
-}
-
-func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil
-}
-
-func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error {
- if c.current == nil || c.maxConcurrent == nil {
- return nil
- }
-
- now := c.current.Add(1)
- for {
- prev := c.maxConcurrent.Load()
- if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) {
- break
- }
- }
- defer c.current.Add(-1)
-
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-c.release:
- return nil
- }
-}
-
-func (c *openAIWSPingBlockingConn) Close() error {
- return nil
-}
-
-func (d *openAIWSCountingDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- d.mu.Lock()
- d.dialCount++
- d.mu.Unlock()
- return &openAIWSFakeConn{}, 0, nil, nil
-}
-
-func (d *openAIWSCountingDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-func (d *openAIWSAlwaysFailDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- d.mu.Lock()
- d.dialCount++
- d.mu.Unlock()
- return nil, 503, nil, errors.New("dial failed")
-}
-
-func (d *openAIWSAlwaysFailDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-type openAIWSFakeConn struct {
- mu sync.Mutex
- closed bool
- payload [][]byte
-}
-
-func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
- _ = ctx
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return errors.New("closed")
- }
- c.payload = append(c.payload, []byte("ok"))
- _ = value
- return nil
-}
-
-func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
- _ = ctx
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return nil, errors.New("closed")
- }
- return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
-}
-
-func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
- _ = ctx
- return nil
-}
-
-func (c *openAIWSFakeConn) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.closed = true
- return nil
-}
-
-type openAIWSBlockingConn struct {
- readDelay time.Duration
-}
-
-func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error {
- _ = ctx
- _ = value
- return nil
-}
-
-func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) {
- delay := c.readDelay
- if delay <= 0 {
- delay = 10 * time.Millisecond
- }
- timer := time.NewTimer(delay)
- defer timer.Stop()
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-timer.C:
- return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil
- }
-}
-
-func (c *openAIWSBlockingConn) Ping(ctx context.Context) error {
- _ = ctx
- return nil
-}
-
-func (c *openAIWSBlockingConn) Close() error {
- return nil
-}
-
-type openAIWSWriteBlockingConn struct{}
-
-func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error {
- <-ctx.Done()
- return ctx.Err()
-}
-
-func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil
-}
-
-func (c *openAIWSWriteBlockingConn) Ping(context.Context) error {
- return nil
-}
-
-func (c *openAIWSWriteBlockingConn) Close() error {
- return nil
-}
-
-type openAIWSPingFailConn struct{}
-
-func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error {
- return nil
-}
-
-func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil
-}
-
-func (c *openAIWSPingFailConn) Ping(context.Context) error {
- return errors.New("ping failed")
-}
-
-func (c *openAIWSPingFailConn) Close() error {
- return nil
-}
-
-type openAIWSContextProbeConn struct {
- lastWriteCtx context.Context
-}
-
-func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error {
- c.lastWriteCtx = ctx
- return nil
-}
-
-func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil
-}
-
-func (c *openAIWSContextProbeConn) Ping(context.Context) error {
- return nil
-}
-
-func (c *openAIWSContextProbeConn) Close() error {
- return nil
-}
-
-type openAIWSNilConnDialer struct{}
-
-func (d *openAIWSNilConnDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- return nil, 200, nil, nil
-}
-
-func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
-
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(&openAIWSNilConnDialer{})
- account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
-
- _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.Error(t, err)
- require.Contains(t, err.Error(), "nil connection")
-}
-
-func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) {
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
-
- dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer)
- require.True(t, ok)
-
- _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080")
- require.NoError(t, err)
- _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080")
- require.NoError(t, err)
- _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081")
- require.NoError(t, err)
-
- snapshot := pool.SnapshotTransportMetrics()
- require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
- require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
- require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
-}
diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go
index df4d4871c..69aaeaec0 100644
--- a/backend/internal/service/openai_ws_protocol_forward_test.go
+++ b/backend/internal/service/openai_ws_protocol_forward_test.go
@@ -30,6 +30,7 @@ func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -201,7 +202,9 @@ func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *
func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
+ var wsAttempts atomic.Int32
ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
w.WriteHeader(http.StatusUpgradeRequired)
_, _ = w.Write([]byte(`upgrade required`))
}))
@@ -211,6 +214,7 @@ func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -260,6 +264,7 @@ func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP")
require.Equal(t, http.StatusUpgradeRequired, rec.Code)
require.Contains(t, rec.Body.String(), "426")
+ require.Equal(t, int32(1), wsAttempts.Load(), "426 upgrade_required 应快速失败,不应进行 WS 重试")
}
func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
@@ -273,6 +278,7 @@ func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -332,6 +338,7 @@ func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -419,6 +426,7 @@ func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenRetu
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
c.String(http.StatusAccepted, "already-written")
upstream := &httpUpstreamRecorder{
@@ -508,6 +516,7 @@ func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testin
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -590,6 +599,7 @@ func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *test
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -672,6 +682,7 @@ func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *tes
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -759,6 +770,7 @@ func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbac
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -866,6 +878,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDrop
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -966,6 +979,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryF
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -1064,6 +1078,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryW
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -1161,6 +1176,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go
index 368643bea..d4fcb472c 100644
--- a/backend/internal/service/openai_ws_protocol_resolver.go
+++ b/backend/internal/service/openai_ws_protocol_resolver.go
@@ -69,7 +69,7 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch mode {
case OpenAIWSIngressModeOff:
return openAIWSHTTPDecision("account_mode_off")
- case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+ case OpenAIWSIngressModeCtxPool, OpenAIWSIngressModePassthrough:
// continue
default:
return openAIWSHTTPDecision("account_mode_off")
diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go
index 5be76e28f..da23f0eb1 100644
--- a/backend/internal/service/openai_ws_protocol_resolver_test.go
+++ b/backend/internal/service/openai_ws_protocol_resolver_test.go
@@ -143,21 +143,49 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeOff
- account := &Account{
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Extra: map[string]any{
- "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
- },
- }
-
- t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
+ t.Run("dedicated mode maps to ctx_pool and routes to ws v2", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+ },
+ }
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
+ // dedicated is now mapped to ctx_pool for backward compatibility
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
+ })
+
+ t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
+ ctxPoolAccount := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(ctxPoolAccount)
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
+ })
+
+ t.Run("passthrough mode routes to ws v2", func(t *testing.T) {
+ passthroughAccount := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(passthroughAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
- require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
+ require.Equal(t, "ws_v2_mode_passthrough", decision.Reason)
})
t.Run("off mode routes to http", func(t *testing.T) {
@@ -174,7 +202,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require.Equal(t, "account_mode_off", decision.Reason)
})
- t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
+ t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
legacyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -185,7 +213,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
- require.Equal(t, "ws_v2_mode_shared", decision.Reason)
+ require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
@@ -193,7 +221,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
- "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
diff --git a/backend/internal/service/openai_ws_recovery.go b/backend/internal/service/openai_ws_recovery.go
new file mode 100644
index 000000000..c8528de0f
--- /dev/null
+++ b/backend/internal/service/openai_ws_recovery.go
@@ -0,0 +1,760 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ coderws "github.com/coder/websocket"
+)
+
+// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。
+type openAIWSFallbackError struct {
+ Reason string
+ Err error
+}
+
+func (e *openAIWSFallbackError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.Err == nil {
+ return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
+ }
+ return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
+}
+
+func (e *openAIWSFallbackError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+func wrapOpenAIWSFallback(reason string, err error) error {
+ return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
+}
+
// OpenAIWSClientCloseError signals that the client WebSocket connection should
// be actively closed with the given close code.
type OpenAIWSClientCloseError struct {
	statusCode coderws.StatusCode // WebSocket close code to send to the client
	reason     string             // close reason text; accessors trim whitespace
	err        error              // underlying cause, exposed via Unwrap
}
+
// openAIWSIngressTurnError records why a single WS ingress turn aborted: the
// failing pipeline stage, the underlying cause, whether any bytes already
// reached the client, and an optional usage-bearing partial result.
type openAIWSIngressTurnError struct {
	stage           string               // failing stage, e.g. "write_upstream" / "read_upstream"
	cause           error                // underlying error, exposed via Unwrap
	wroteDownstream bool                 // true once any data was written to the client this turn
	partialResult   *OpenAIForwardResult // partial result captured before the abort; may be nil
}

// openAIWSIngressUpstreamLease is the per-turn handle on an upstream WS
// connection leased from the pool: scheduling metadata accessors, timed I/O,
// liveness probing, and lifecycle control (yield / release / mark-broken).
type openAIWSIngressUpstreamLease interface {
	ConnID() string
	QueueWaitDuration() time.Duration
	ConnPickDuration() time.Duration
	Reused() bool
	ScheduleLayer() string
	StickinessLevel() string
	MigrationUsed() bool
	HandshakeHeader(name string) string
	IsPrewarmed() bool
	MarkPrewarmed()
	WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error
	ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error)
	PingWithTimeout(timeout time.Duration) error
	MarkBroken()
	Yield()
	Release()
}
+
+func (e *openAIWSIngressTurnError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.cause == nil {
+ return strings.TrimSpace(e.stage)
+ }
+ return e.cause.Error()
+}
+
+func (e *openAIWSIngressTurnError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.cause
+}
+
+func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
+ return wrapOpenAIWSIngressTurnErrorWithPartial(stage, cause, wroteDownstream, nil)
+}
+
+func cloneOpenAIForwardResult(result *OpenAIForwardResult) *OpenAIForwardResult {
+ if result == nil {
+ return nil
+ }
+ cloned := *result
+ if result.PendingFunctionCallIDs != nil {
+ cloned.PendingFunctionCallIDs = make([]string, len(result.PendingFunctionCallIDs))
+ copy(cloned.PendingFunctionCallIDs, result.PendingFunctionCallIDs)
+ }
+ return &cloned
+}
+
+func wrapOpenAIWSIngressTurnErrorWithPartial(stage string, cause error, wroteDownstream bool, partialResult *OpenAIForwardResult) error {
+ if cause == nil {
+ return nil
+ }
+ return &openAIWSIngressTurnError{
+ stage: strings.TrimSpace(stage),
+ cause: cause,
+ wroteDownstream: wroteDownstream,
+ partialResult: cloneOpenAIForwardResult(partialResult),
+ }
+}
+
+// OpenAIWSIngressTurnPartialResult returns usage-bearing partial turn result
+// when WS ingress turn aborts after receiving upstream events.
+func OpenAIWSIngressTurnPartialResult(err error) (*OpenAIForwardResult, bool) {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil || turnErr.partialResult == nil {
+ return nil, false
+ }
+ return cloneOpenAIForwardResult(turnErr.partialResult), true
+}
+
+func isOpenAIWSIngressTurnRetryable(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
+ return false
+ }
+ if turnErr.wroteDownstream {
+ return false
+ }
+ switch turnErr.stage {
+ case "write_upstream", "read_upstream":
+ return true
+ default:
+ return false
+ }
+}
+
+func openAIWSIngressTurnRetryReason(err error) string {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return "unknown"
+ }
+ if turnErr.stage == "" {
+ return "unknown"
+ }
+ return turnErr.stage
+}
+
+func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
+ return false
+ }
+ return !turnErr.wroteDownstream
+}
+
+func isOpenAIWSIngressToolOutputNotFound(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if strings.TrimSpace(turnErr.stage) != openAIWSIngressStageToolOutputNotFound {
+ return false
+ }
+ return !turnErr.wroteDownstream
+}
+
+// openAIWSIngressTurnWroteDownstream 返回本次 turn 是否已向客户端写入过数据。
+// 用于 ContinueTurn abort 时判断是否需要补发 error 事件。
+func openAIWSIngressTurnWroteDownstream(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ return turnErr.wroteDownstream
+}
+
+func isOpenAIWSIngressUpstreamErrorEvent(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ return strings.TrimSpace(turnErr.stage) == "upstream_error_event"
+}
+
+func isOpenAIWSContinuationUnavailableCloseError(err error) bool {
+ var closeErr *OpenAIWSClientCloseError
+ if !errors.As(err, &closeErr) || closeErr == nil {
+ return false
+ }
+ if closeErr.StatusCode() != coderws.StatusPolicyViolation {
+ return false
+ }
+ return strings.Contains(closeErr.Reason(), openAIWSContinuationUnavailableReason)
+}
+
+// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
+func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
+ return &OpenAIWSClientCloseError{
+ statusCode: statusCode,
+ reason: strings.TrimSpace(reason),
+ err: err,
+ }
+}
+
+func (e *OpenAIWSClientCloseError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.err == nil {
+ return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
+ }
+ return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
+}
+
+func (e *OpenAIWSClientCloseError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.err
+}
+
+func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
+ if e == nil {
+ return coderws.StatusInternalError
+ }
+ return e.statusCode
+}
+
+func (e *OpenAIWSClientCloseError) Reason() string {
+ if e == nil {
+ return ""
+ }
+ return strings.TrimSpace(e.reason)
+}
+
+func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
+ if err == nil {
+ return "-", "-"
+ }
+ statusCode := coderws.CloseStatus(err)
+ if statusCode == -1 {
+ return "-", "-"
+ }
+ closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
+ closeReason := "-"
+ var closeErr coderws.CloseError
+ if errors.As(err, &closeErr) {
+ reasonText := strings.TrimSpace(closeErr.Reason)
+ if reasonText != "" {
+ closeReason = normalizeOpenAIWSLogValue(reasonText)
+ }
+ }
+ return normalizeOpenAIWSLogValue(closeStatus), closeReason
+}
+
+func unwrapOpenAIWSDialBaseError(err error) error {
+ if err == nil {
+ return nil
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
+ return dialErr.Err
+ }
+ return err
+}
+
+func openAIWSDialRespHeaderForLog(err error, key string) string {
+ var dialErr *openAIWSDialError
+ if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
+ return "-"
+ }
+ return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+func classifyOpenAIWSDialError(err error) string {
+ if err == nil {
+ return "-"
+ }
+ baseErr := unwrapOpenAIWSDialBaseError(err)
+ if baseErr == nil {
+ return "-"
+ }
+ if errors.Is(baseErr, context.DeadlineExceeded) {
+ return "ctx_deadline_exceeded"
+ }
+ if errors.Is(baseErr, context.Canceled) {
+ return "ctx_canceled"
+ }
+ var netErr net.Error
+ if errors.As(baseErr, &netErr) && netErr.Timeout() {
+ return "net_timeout"
+ }
+ if status := coderws.CloseStatus(baseErr); status != -1 {
+ return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
+ }
+ message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
+ switch {
+ case strings.Contains(message, "handshake not finished"):
+ return "handshake_not_finished"
+ case strings.Contains(message, "bad handshake"):
+ return "bad_handshake"
+ case strings.Contains(message, "connection refused"):
+ return "connection_refused"
+ case strings.Contains(message, "no such host"):
+ return "dns_not_found"
+ case strings.Contains(message, "tls"):
+ return "tls_error"
+ case strings.Contains(message, "i/o timeout"):
+ return "io_timeout"
+ case strings.Contains(message, "context deadline exceeded"):
+ return "ctx_deadline_exceeded"
+ default:
+ return "dial_error"
+ }
+}
+
// summarizeOpenAIWSDialError flattens a WS dial failure into the individual
// log fields: upstream handshake status code, dial error class, WS close
// status/reason, and notable response headers. String fields default to "-";
// statusCode defaults to its zero value when err is not an openAIWSDialError.
func summarizeOpenAIWSDialError(err error) (
	statusCode int,
	dialClass string,
	closeStatus string,
	closeReason string,
	respServer string,
	respVia string,
	respCFRay string,
	respRequestID string,
) {
	dialClass = "-"
	closeStatus = "-"
	closeReason = "-"
	respServer = "-"
	respVia = "-"
	respCFRay = "-"
	respRequestID = "-"
	if err == nil {
		return
	}
	// Handshake status and response headers are only available when the error
	// is an openAIWSDialError carrying the upstream HTTP response.
	var dialErr *openAIWSDialError
	if errors.As(err, &dialErr) && dialErr != nil {
		statusCode = dialErr.StatusCode
		respServer = openAIWSDialRespHeaderForLog(err, "server")
		respVia = openAIWSDialRespHeaderForLog(err, "via")
		respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
		respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
	}
	// Class and close status are derived from the unwrapped base error even
	// when the dial wrapper is absent.
	dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
	closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
	return
}
+
+func isOpenAIWSClientDisconnectError(err error) bool {
+ if err == nil {
+ return false
+ }
+ if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
+ return true
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
+ return true
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ if message == "" {
+ return false
+ }
+ return strings.Contains(message, "failed to read frame header: eof") ||
+ strings.Contains(message, "unexpected eof") ||
+ strings.Contains(message, "use of closed network connection") ||
+ strings.Contains(message, "connection reset by peer") ||
+ strings.Contains(message, "broken pipe")
+}
+
+func classifyOpenAIWSIngressReadErrorClass(err error) string {
+ if err == nil {
+ return "unknown"
+ }
+ if errors.Is(err, context.Canceled) {
+ return "context_canceled"
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return "deadline_exceeded"
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusServiceRestart:
+ return "service_restart"
+ case coderws.StatusTryAgainLater:
+ return "try_again_later"
+ }
+ if isOpenAIWSClientDisconnectError(err) {
+ return "upstream_closed"
+ }
+ if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+ return "eof"
+ }
+ return "unknown"
+}
+
+func isOpenAIWSStreamWriteDisconnectError(err error, reqCtx context.Context) bool {
+ if err == nil {
+ return false
+ }
+ if reqCtx != nil && reqCtx.Err() != nil {
+ return true
+ }
+ return isOpenAIWSClientDisconnectError(err)
+}
+
+func openAIWSIngressResolveDrainReadTimeout(
+ baseTimeout time.Duration,
+ disconnectDeadline time.Time,
+ now time.Time,
+) (time.Duration, bool) {
+ if disconnectDeadline.IsZero() {
+ return baseTimeout, false
+ }
+ remaining := disconnectDeadline.Sub(now)
+ if remaining <= 0 {
+ return 0, true
+ }
+ if baseTimeout <= 0 || remaining < baseTimeout {
+ return remaining, false
+ }
+ return baseTimeout, false
+}
+
+func openAIWSIngressClientDisconnectedDrainTimeoutError(timeout time.Duration) error {
+ if timeout <= 0 {
+ timeout = openAIWSIngressClientDisconnectDrainTimeout
+ }
+ return fmt.Errorf("client disconnected before upstream terminal event (drain timeout=%s): %w", timeout, context.Canceled)
+}
+
+func openAIWSIngressPumpClosedTurnError(
+ clientDisconnected bool,
+ wroteDownstream bool,
+ partialResult *OpenAIForwardResult,
+) error {
+ if clientDisconnected {
+ return wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+ wroteDownstream,
+ partialResult,
+ )
+ }
+ return wrapOpenAIWSIngressTurnErrorWithPartial(
+ "read_upstream",
+ errors.New("upstream event pump closed unexpectedly"),
+ wroteDownstream,
+ partialResult,
+ )
+}
+
// shouldFlushOpenAIWSBufferedEventsOnError reports whether buffered events
// should still be flushed after a turn error: only for streaming requests
// that already wrote downstream and whose client is still connected.
func shouldFlushOpenAIWSBufferedEventsOnError(reqStream bool, wroteDownstream bool, clientDisconnected bool) bool {
	if !reqStream || !wroteDownstream {
		return false
	}
	return !clientDisconnected
}
+
+// errOpenAIWSClientPreempted 表示客户端在当前 turn 尚未完成时发送了新的 response.create 请求。
+var errOpenAIWSClientPreempted = errors.New("client preempted current turn with new request")
+
+var errOpenAIWSAdvanceClientReadUnavailable = errors.New("client reader channels unavailable")
+
+func openAIWSAdvanceConsumePendingClientReadErr(pendingErr *error) error {
+ if pendingErr == nil || *pendingErr == nil {
+ return nil
+ }
+ readErr := *pendingErr
+ *pendingErr = nil
+ return fmt.Errorf("read client websocket request: %w", readErr)
+}
+
+func openAIWSAdvanceClientReadUnavailable(clientMsgCh <-chan []byte, clientReadErrCh <-chan error) bool {
+ return clientMsgCh == nil && clientReadErrCh == nil
+}
+
+// isOpenAIWSUpstreamRestartCloseError 检测上游是否因服务重启/维护关闭了连接。
+// 1012=ServiceRestart, 1013=TryAgainLater,都是临时性上游维护,proxy 应视为可恢复错误。
+func isOpenAIWSUpstreamRestartCloseError(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if turnErr.stage != "read_upstream" {
+ return false
+ }
+ status := coderws.CloseStatus(turnErr.cause)
+ return status == 1012 || status == 1013 // ServiceRestart, TryAgainLater
+}
+
// classifyOpenAIWSIngressTurnAbortReason maps a turn abort error to a metrics
// reason; the second return reports whether the abort is "expected" (benign).
// The check order is load-bearing: specific recoverable conditions must be
// matched before the generic stage-based fallbacks at the bottom.
func classifyOpenAIWSIngressTurnAbortReason(err error) (openAIWSIngressTurnAbortReason, bool) {
	if err == nil {
		return openAIWSIngressTurnAbortReasonUnknown, false
	}
	if isOpenAIWSIngressPreviousResponseNotFound(err) {
		return openAIWSIngressTurnAbortReasonPreviousResponse, true
	}
	if isOpenAIWSIngressToolOutputNotFound(err) {
		return openAIWSIngressTurnAbortReasonToolOutput, true
	}
	if isOpenAIWSIngressUpstreamErrorEvent(err) {
		return openAIWSIngressTurnAbortReasonUpstreamError, true
	}
	if isOpenAIWSContinuationUnavailableCloseError(err) {
		return openAIWSIngressTurnAbortReasonContinuationUnavailable, true
	}
	if errors.Is(err, errOpenAIWSClientPreempted) {
		return openAIWSIngressTurnAbortReasonClientPreempted, true
	}
	if errors.Is(err, context.Canceled) {
		return openAIWSIngressTurnAbortReasonContextCanceled, true
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return openAIWSIngressTurnAbortReasonContextDeadline, false
	}
	if isOpenAIWSClientDisconnectError(err) {
		return openAIWSIngressTurnAbortReasonClientClosed, true
	}
	// Upstream ServiceRestart/TryAgainLater must be detected before the
	// stage-based classification below, otherwise the "read_upstream" branch
	// would demote them to FailRequest.
	if isOpenAIWSUpstreamRestartCloseError(err) {
		return openAIWSIngressTurnAbortReasonUpstreamRestart, true
	}

	// Generic stage-based fallback: all of these are unexpected failures.
	var turnErr *openAIWSIngressTurnError
	if errors.As(err, &turnErr) && turnErr != nil {
		switch strings.TrimSpace(turnErr.stage) {
		case "idle_timeout":
			return openAIWSIngressTurnAbortReasonContextDeadline, false
		case "write_upstream":
			return openAIWSIngressTurnAbortReasonWriteUpstream, false
		case "read_upstream":
			return openAIWSIngressTurnAbortReasonReadUpstream, false
		case "write_client":
			return openAIWSIngressTurnAbortReasonWriteClient, false
		}
	}
	return openAIWSIngressTurnAbortReasonUnknown, false
}
+
+func openAIWSIngressTurnAbortDispositionForReason(reason openAIWSIngressTurnAbortReason) openAIWSIngressTurnAbortDisposition {
+ switch reason {
+ case openAIWSIngressTurnAbortReasonPreviousResponse,
+ openAIWSIngressTurnAbortReasonToolOutput,
+ openAIWSIngressTurnAbortReasonUpstreamError,
+ openAIWSIngressTurnAbortReasonClientPreempted,
+ openAIWSIngressTurnAbortReasonUpstreamRestart:
+ return openAIWSIngressTurnAbortDispositionContinueTurn
+ case openAIWSIngressTurnAbortReasonContextCanceled,
+ openAIWSIngressTurnAbortReasonClientClosed:
+ return openAIWSIngressTurnAbortDispositionCloseGracefully
+ default:
+ return openAIWSIngressTurnAbortDispositionFailRequest
+ }
+}
+
+func classifyOpenAIWSReadFallbackReason(err error) string {
+ if err == nil {
+ return "read_event"
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusServiceRestart:
+ return "service_restart"
+ case coderws.StatusTryAgainLater:
+ return "try_again_later"
+ case coderws.StatusPolicyViolation:
+ return "policy_violation"
+ case coderws.StatusMessageTooBig:
+ return "message_too_big"
+ default:
+ return "read_event"
+ }
+}
+
+func classifyOpenAIWSAcquireError(err error) string {
+ if err == nil {
+ return "acquire_conn"
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) {
+ switch dialErr.StatusCode {
+ case 426:
+ return "upgrade_required"
+ case 401, 403:
+ return "auth_failed"
+ case 429:
+ return "upstream_rate_limited"
+ }
+ if dialErr.StatusCode >= 500 {
+ return "upstream_5xx"
+ }
+ return "dial_failed"
+ }
+ if errors.Is(err, errOpenAIWSConnQueueFull) {
+ return "conn_queue_full"
+ }
+ if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
+ return "preferred_conn_unavailable"
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return "acquire_timeout"
+ }
+ return "acquire_conn"
+}
+
// classifyOpenAIWSErrorEventFromRaw buckets an upstream error event into a
// short class label from its raw code/type/message fields. The second return
// reports whether the event was positively recognized; unmatched events fall
// through to ("event_error", false). Exact code matches are checked first,
// then ordered substring probes on the message and type.
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
	code := strings.ToLower(strings.TrimSpace(codeRaw))
	errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
	msg := strings.ToLower(strings.TrimSpace(msgRaw))

	switch code {
	case "upgrade_required":
		return "upgrade_required", true
	case "websocket_not_supported", "websocket_unsupported":
		return "ws_unsupported", true
	case "websocket_connection_limit_reached":
		return "ws_connection_limit_reached", true
	case "previous_response_not_found":
		return "previous_response_not_found", true
	}
	if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
		return "upgrade_required", true
	}
	if strings.Contains(errType, "upgrade") {
		return "upgrade_required", true
	}
	if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
		return "ws_unsupported", true
	}
	if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
		return "ws_connection_limit_reached", true
	}
	if strings.Contains(msg, "previous_response_not_found") ||
		(strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
		return "previous_response_not_found", true
	}
	// "No tool output found for function call" / "No tool call found for
	// function call output..." mean the response referenced by
	// previous_response_id contains an unfinished function_call (e.g. the user
	// pressed ESC in Codex CLI to cancel it and then sent a new message). The
	// previous_response_id itself is the problem and must be dropped before
	// replay.
	if strings.Contains(msg, "no tool output found") ||
		strings.Contains(msg, "no tool call found for function call output") ||
		(strings.Contains(msg, "no tool call found") && strings.Contains(msg, "function call output")) {
		return openAIWSIngressStageToolOutputNotFound, true
	}
	if strings.Contains(msg, "without its required following item") ||
		strings.Contains(msg, "without its required preceding item") {
		return openAIWSIngressStageToolOutputNotFound, true
	}
	if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
		return "upstream_error_event", true
	}
	return "event_error", false
}

// classifyOpenAIWSErrorEvent parses a raw error event payload and classifies
// it; an empty payload yields ("event_error", false).
func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
	if len(message) == 0 {
		return "event_error", false
	}
	return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
}
+
// openAIWSErrorHTTPStatusFromRaw maps an upstream WS error event's code/type
// fields onto the HTTP status the gateway should report; unmatched errors
// default to 502 Bad Gateway.
func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
	code := strings.ToLower(strings.TrimSpace(codeRaw))
	errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
	if strings.Contains(errType, "invalid_request") ||
		strings.Contains(code, "invalid_request") ||
		strings.Contains(code, "bad_request") ||
		code == "previous_response_not_found" {
		return http.StatusBadRequest
	}
	if strings.Contains(errType, "authentication") ||
		strings.Contains(code, "invalid_api_key") ||
		strings.Contains(code, "unauthorized") {
		return http.StatusUnauthorized
	}
	if strings.Contains(errType, "permission") || strings.Contains(code, "forbidden") {
		return http.StatusForbidden
	}
	if strings.Contains(errType, "rate_limit") ||
		strings.Contains(code, "rate_limit") ||
		strings.Contains(code, "insufficient_quota") {
		return http.StatusTooManyRequests
	}
	return http.StatusBadGateway
}
+
+func openAIWSErrorHTTPStatus(message []byte) int {
+ if len(message) == 0 {
+ return http.StatusBadGateway
+ }
+ codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
+ return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
+}
+
+func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
+ if s == nil || s.cfg == nil {
+ return 30 * time.Second
+ }
+ seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
+ if seconds <= 0 {
+ return 0
+ }
+ return time.Duration(seconds) * time.Second
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
+ if s == nil || accountID <= 0 {
+ return false
+ }
+ cooldown := s.openAIWSFallbackCooldown()
+ if cooldown <= 0 {
+ return false
+ }
+ rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
+ if !ok || rawUntil == nil {
+ return false
+ }
+ until, ok := rawUntil.(time.Time)
+ if !ok || until.IsZero() {
+ s.openaiWSFallbackUntil.Delete(accountID)
+ return false
+ }
+ if time.Now().Before(until) {
+ return true
+ }
+ s.openaiWSFallbackUntil.Delete(accountID)
+ return false
+}
+
+func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ cooldown := s.openAIWSFallbackCooldown()
+ if cooldown <= 0 {
+ return
+ }
+ s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
+}
+
+func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ s.openaiWSFallbackUntil.Delete(accountID)
+}
diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go
index b606baa1a..cff5ab6cd 100644
--- a/backend/internal/service/openai_ws_state_store.go
+++ b/backend/internal/service/openai_ws_state_store.go
@@ -4,11 +4,13 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
- "fmt"
+ "strconv"
"strings"
"sync"
"sync/atomic"
"time"
+
+ "github.com/cespare/xxhash/v2"
)
const (
@@ -17,6 +19,7 @@ const (
openAIWSStateStoreCleanupMaxPerMap = 512
openAIWSStateStoreMaxEntriesPerMap = 65536
openAIWSStateStoreRedisTimeout = 3 * time.Second
+ openAIWSStateStoreHotCacheTTL = time.Minute
)
type openAIWSAccountBinding struct {
@@ -29,6 +32,11 @@ type openAIWSConnBinding struct {
expiresAt time.Time
}
// openAIWSResponsePendingToolCallsBinding is the in-memory record of the tool
// call IDs still awaiting outputs for a response, with its local expiry.
type openAIWSResponsePendingToolCallsBinding struct {
	callIDs   []string  // normalized (trimmed, de-duplicated) tool call IDs
	expiresAt time.Time // entry is ignored/pruned once this instant passes
}
+
type openAIWSTurnStateBinding struct {
turnState string
expiresAt time.Time
@@ -39,6 +47,23 @@ type openAIWSSessionConnBinding struct {
expiresAt time.Time
}
// openAIWSSessionLastResponseBinding is the local-cache record of the most
// recent response ID observed for a session.
type openAIWSSessionLastResponseBinding struct {
	responseID string
	expiresAt  time.Time // entry is ignored once this instant passes
}

// openAIWSStateStoreSessionLastResponseCache is the optional GatewayCache
// extension for persisting session -> last-response-ID bindings. The store
// feature-detects it via a type assertion on its configured cache.
type openAIWSStateStoreSessionLastResponseCache interface {
	SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error
	GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error)
	DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error
}

// openAIWSStateStoreResponsePendingToolCallsCache is the optional GatewayCache
// extension for persisting response -> pending-tool-call-ID bindings; also
// feature-detected via a type assertion.
type openAIWSStateStoreResponsePendingToolCallsCache interface {
	SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error
	GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error)
	DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error
}
+
// OpenAIWSStateStore 管理 WSv2 的粘连状态。
// - response_id -> account_id 用于续链路由
// - response_id -> conn_id 用于连接内上下文复用
@@ -53,44 +78,111 @@ type OpenAIWSStateStore interface {
BindResponseConn(responseID, connID string, ttl time.Duration)
GetResponseConn(responseID string) (string, bool)
DeleteResponseConn(responseID string)
+ BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration)
+ GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool)
+ DeleteResponsePendingToolCalls(groupID int64, responseID string)
BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
DeleteSessionTurnState(groupID int64, sessionHash string)
+ BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration)
+ GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool)
+ DeleteSessionLastResponseID(groupID int64, sessionHash string)
+
BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
GetSessionConn(groupID int64, sessionHash string) (string, bool)
DeleteSessionConn(groupID int64, sessionHash string)
}
// openAIWSStateStoreConnShards is the number of lock shards for the
// response->conn map; sharding reduces mutex contention on this hot path.
const openAIWSStateStoreConnShards = 16

// openAIWSConnBindingShard is one lock-protected shard of the response->conn map.
type openAIWSConnBindingShard struct {
	mu sync.RWMutex
	m  map[string]openAIWSConnBinding
}
+
// defaultOpenAIWSStateStore is the built-in OpenAIWSStateStore implementation:
// bounded local maps serve as a hot cache, and the optional GatewayCache
// (Redis) carries bindings across gateway instances.
type defaultOpenAIWSStateStore struct {
	cache GatewayCache // optional distributed backend; nil means local-only
	responseToAccountMu   sync.RWMutex
	responseToAccount     map[string]openAIWSAccountBinding
	responseToConnShards  [openAIWSStateStoreConnShards]openAIWSConnBindingShard // sharded to cut lock contention
	responsePendingToolMu sync.RWMutex
	responsePendingTool   map[string]openAIWSResponsePendingToolCallsBinding
	sessionToTurnStateMu  sync.RWMutex
	sessionToTurnState    map[string]openAIWSTurnStateBinding
	sessionToLastRespMu   sync.RWMutex
	sessionToLastResp     map[string]openAIWSSessionLastResponseBinding
	sessionToConnMu       sync.RWMutex
	sessionToConn         map[string]openAIWSSessionConnBinding
	lastCleanupUnixNano atomic.Int64 // timestamp of the last opportunistic cleanup
	stopCh   chan struct{}  // closed by Close to stop the cleanup worker
	stopOnce sync.Once      // guards the single close of stopCh
	workerWg sync.WaitGroup // lets Close wait for the cleanup goroutine to exit
}
+
+func (s *defaultOpenAIWSStateStore) connShard(key string) *openAIWSConnBindingShard {
+ h := xxhash.Sum64String(key)
+ return &s.responseToConnShards[h%openAIWSStateStoreConnShards]
}
// NewOpenAIWSStateStore creates the default WS state store with the standard
// background-cleanup interval.
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
	return newOpenAIWSStateStore(cache, openAIWSStateStoreCleanupInterval)
}
+
// newOpenAIWSStateStore builds the store and, when cleanupInterval > 0,
// starts the background cleanup worker (stopped via Close).
func newOpenAIWSStateStore(cache GatewayCache, cleanupInterval time.Duration) *defaultOpenAIWSStateStore {
	store := &defaultOpenAIWSStateStore{
		cache:               cache,
		responseToAccount:   make(map[string]openAIWSAccountBinding, 256),
		responsePendingTool: make(map[string]openAIWSResponsePendingToolCallsBinding, 256),
		sessionToTurnState:  make(map[string]openAIWSTurnStateBinding, 256),
		sessionToLastResp:   make(map[string]openAIWSSessionLastResponseBinding, 256),
		sessionToConn:       make(map[string]openAIWSSessionConnBinding, 256),
		stopCh:              make(chan struct{}),
	}
	// Shard maps start small and grow on demand.
	for i := range store.responseToConnShards {
		store.responseToConnShards[i].m = make(map[string]openAIWSConnBinding, 16)
	}
	store.lastCleanupUnixNano.Store(time.Now().UnixNano())
	store.startCleanupWorker(cleanupInterval)
	return store
}
+func (s *defaultOpenAIWSStateStore) startCleanupWorker(interval time.Duration) {
+ if s == nil || interval <= 0 {
+ return
+ }
+ s.workerWg.Add(1)
+ go func() {
+ defer s.workerWg.Done()
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-s.stopCh:
+ return
+ case <-ticker.C:
+ s.maybeCleanup()
+ }
+ }
+ }()
+}
+
+func (s *defaultOpenAIWSStateStore) Close() {
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ if s.stopCh != nil {
+ close(s.stopCh)
+ }
+ })
+ s.workerWg.Wait()
+}
+
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
id := normalizeOpenAIWSResponseID(responseID)
if id == "" || accountID <= 0 {
@@ -99,19 +191,31 @@ func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, gro
ttl = normalizeOpenAIWSTTL(ttl)
s.maybeCleanup()
- expiresAt := time.Now().Add(ttl)
+ var redisErr error
+ if s.cache != nil {
+ cacheKey := openAIWSResponseAccountCacheKey(id)
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ redisErr = s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
+ cancel()
+ if redisErr != nil {
+ logOpenAIWSModeInfo(
+ "state_store_bind_response_account_redis_fail group_id=%d response_id=%s account_id=%d cause=%s",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), accountID, truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ }
+
+ // 无论 Redis 是否写成功,都写入本地缓存作为降级保障。
+ localTTL := openAIWSStateStoreLocalHotTTL(ttl)
s.responseToAccountMu.Lock()
ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
- s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
+ s.responseToAccount[id] = openAIWSAccountBinding{
+ accountID: accountID,
+ expiresAt: time.Now().Add(localTTL),
+ }
s.responseToAccountMu.Unlock()
- if s.cache == nil {
- return nil
- }
- cacheKey := openAIWSResponseAccountCacheKey(id)
- cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
- defer cancel()
- return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
+ return redisErr
}
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
@@ -119,7 +223,6 @@ func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, grou
if id == "" {
return 0, nil
}
- s.maybeCleanup()
now := time.Now()
s.responseToAccountMu.RLock()
@@ -138,13 +241,33 @@ func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, grou
cacheKey := openAIWSResponseAccountCacheKey(id)
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
- defer cancel()
accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
- if err != nil || accountID <= 0 {
+ cancel()
+ if err == nil && accountID > 0 {
+ return accountID, nil
+ }
+
+ // Compatibility fallback for pre-v2 cache keys.
+ legacyCacheKey := openAIWSResponseAccountLegacyCacheKey(id)
+ legacyCtx, legacyCancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ legacyAccountID, legacyErr := s.cache.GetSessionAccountID(legacyCtx, groupID, legacyCacheKey)
+ legacyCancel()
+ if legacyErr != nil || legacyAccountID <= 0 {
// 缓存读取失败不阻断主流程,按未命中降级。
return 0, nil
}
- return accountID, nil
+
+ logOpenAIWSModeInfo(
+ "state_store_get_response_account_legacy_fallback group_id=%d response_id=%s account_id=%d",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), legacyAccountID,
+ )
+
+ // Best effort: backfill v2 key so subsequent reads avoid legacy fallback.
+ backfillCtx, backfillCancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ _ = s.cache.SetSessionAccountID(backfillCtx, groupID, cacheKey, legacyAccountID, openAIWSStateStoreHotCacheTTL)
+ backfillCancel()
+
+ return legacyAccountID, nil
}
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
@@ -161,7 +284,15 @@ func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, g
}
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
defer cancel()
- return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
+ primaryKey := openAIWSResponseAccountCacheKey(id)
+ if err := s.cache.DeleteSessionAccountID(cacheCtx, groupID, primaryKey); err != nil {
+ return err
+ }
+ legacyKey := openAIWSResponseAccountLegacyCacheKey(id)
+ if legacyKey == "" || legacyKey == primaryKey {
+ return nil
+ }
+ return s.cache.DeleteSessionAccountID(cacheCtx, groupID, legacyKey)
}
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
@@ -173,13 +304,14 @@ func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string,
ttl = normalizeOpenAIWSTTL(ttl)
s.maybeCleanup()
- s.responseToConnMu.Lock()
- ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
- s.responseToConn[id] = openAIWSConnBinding{
+ shard := s.connShard(id)
+ shard.mu.Lock()
+ ensureBindingCapacity(shard.m, id, openAIWSStateStoreMaxEntriesPerMap/openAIWSStateStoreConnShards)
+ shard.m[id] = openAIWSConnBinding{
connID: conn,
expiresAt: time.Now().Add(ttl),
}
- s.responseToConnMu.Unlock()
+ shard.mu.Unlock()
}
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
@@ -187,13 +319,13 @@ func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string,
if id == "" {
return "", false
}
- s.maybeCleanup()
now := time.Now()
- s.responseToConnMu.RLock()
- binding, ok := s.responseToConn[id]
- s.responseToConnMu.RUnlock()
- if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
+ shard := s.connShard(id)
+ shard.mu.RLock()
+ binding, ok := shard.m[id]
+ shard.mu.RUnlock()
+ if !ok || now.After(binding.expiresAt) || binding.connID == "" {
return "", false
}
return binding.connID, true
@@ -204,9 +336,115 @@ func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
if id == "" {
return
}
- s.responseToConnMu.Lock()
- delete(s.responseToConn, id)
- s.responseToConnMu.Unlock()
+ shard := s.connShard(id)
+ shard.mu.Lock()
+ delete(shard.m, id)
+ shard.mu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs)
+ if id == "" || len(normalizedCallIDs) == 0 {
+ return
+ }
+ key := openAIWSResponsePendingToolCallsBindingKey(groupID, id)
+ if key == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.responsePendingToolMu.Lock()
+ ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{
+ callIDs: append([]string(nil), normalizedCallIDs...),
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.responsePendingToolMu.Unlock()
+
+ if cache := s.responsePendingToolCallsCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ if redisErr := cache.SetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id, normalizedCallIDs, ttl); redisErr != nil {
+ logOpenAIWSModeInfo(
+ "state_store_bind_response_pending_tool_calls_redis_fail group_id=%d response_id=%s call_count=%d cause=%s",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ }
+}
+
// GetResponsePendingToolCalls looks up the pending tool call IDs for a
// response: local hot map first, then the distributed cache on a local miss,
// backfilling the local map after a Redis hit. The bool reports whether a
// non-empty, unexpired binding was found.
func (s *defaultOpenAIWSStateStore) GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool) {
	id := normalizeOpenAIWSResponseID(responseID)
	if id == "" {
		return nil, false
	}
	key := openAIWSResponsePendingToolCallsBindingKey(groupID, id)
	if key == "" {
		return nil, false
	}

	now := time.Now()
	s.responsePendingToolMu.RLock()
	binding, ok := s.responsePendingTool[key]
	s.responsePendingToolMu.RUnlock()
	if !ok || now.After(binding.expiresAt) || len(binding.callIDs) == 0 {
		// Local miss (absent, expired, or empty) — fall back to Redis when available.
		cache := s.responsePendingToolCallsCache()
		if cache == nil {
			return nil, false
		}
		cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
		defer cancel()
		callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id)
		normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs)
		if err != nil || len(normalizedCallIDs) == 0 {
			// Redis errors degrade to a miss; they never block the main flow.
			if err != nil {
				logOpenAIWSModeInfo(
					"state_store_get_response_pending_tool_calls_redis_fail group_id=%d response_id=%s cause=%s",
					groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
				)
			}
			return nil, false
		}

		logOpenAIWSModeInfo(
			"state_store_get_response_pending_tool_calls_redis_hit group_id=%d response_id=%s call_count=%d",
			groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs),
		)

		// After a Redis hit, backfill the local hot cache to cheapen later reads.
		s.responsePendingToolMu.Lock()
		ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap)
		s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{
			callIDs:   append([]string(nil), normalizedCallIDs...),
			expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL),
		}
		s.responsePendingToolMu.Unlock()
		return normalizedCallIDs, true
	}
	// binding.callIDs was already copied at bind time; return directly (callers are read-only).
	return binding.callIDs, true
}
+
+func (s *defaultOpenAIWSStateStore) DeleteResponsePendingToolCalls(groupID int64, responseID string) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return
+ }
+ key := openAIWSResponsePendingToolCallsBindingKey(groupID, id)
+ if key == "" {
+ return
+ }
+ s.responsePendingToolMu.Lock()
+ delete(s.responsePendingTool, key)
+ s.responsePendingToolMu.Unlock()
+
+ if cache := s.responsePendingToolCallsCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ _ = cache.DeleteOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id)
+ }
}
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
@@ -232,7 +470,6 @@ func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHa
if key == "" {
return "", false
}
- s.maybeCleanup()
now := time.Now()
s.sessionToTurnStateMu.RLock()
@@ -254,6 +491,99 @@ func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessio
s.sessionToTurnStateMu.Unlock()
}
+func (s *defaultOpenAIWSStateStore) BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ id := normalizeOpenAIWSResponseID(responseID)
+ if key == "" || id == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.sessionToLastRespMu.Lock()
+ ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{
+ responseID: id,
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.sessionToLastRespMu.Unlock()
+
+ if cache := s.sessionLastResponseCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ if redisErr := cache.SetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash), id, ttl); redisErr != nil {
+ logOpenAIWSModeInfo(
+ "state_store_bind_session_last_response_redis_fail group_id=%d session_hash=%s response_id=%s cause=%s",
+ groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ }
+}
+
// GetSessionLastResponseID returns the last response ID recorded for a
// session: local hot map first, distributed cache on a miss, with the Redis
// result backfilled locally. The bool reports whether a usable ID was found.
func (s *defaultOpenAIWSStateStore) GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool) {
	key := openAIWSSessionTurnStateKey(groupID, sessionHash)
	if key == "" {
		return "", false
	}

	now := time.Now()
	s.sessionToLastRespMu.RLock()
	binding, ok := s.sessionToLastResp[key]
	s.sessionToLastRespMu.RUnlock()
	if ok && now.Before(binding.expiresAt) && strings.TrimSpace(binding.responseID) != "" {
		return binding.responseID, true
	}

	// Local miss — fall back to the distributed cache when available.
	cache := s.sessionLastResponseCache()
	if cache == nil {
		return "", false
	}
	cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
	defer cancel()
	responseID, err := cache.GetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash))
	responseID = normalizeOpenAIWSResponseID(responseID)
	if err != nil || responseID == "" {
		// Redis errors degrade to a miss; they never block the main flow.
		if err != nil {
			logOpenAIWSModeInfo(
				"state_store_get_session_last_response_redis_fail group_id=%d session_hash=%s cause=%s",
				groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
			)
		}
		return "", false
	}

	logOpenAIWSModeInfo(
		"state_store_get_session_last_response_redis_hit group_id=%d session_hash=%s response_id=%s",
		groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
	)

	// After a Redis hit, backfill the local hot cache to cheapen later reads.
	s.sessionToLastRespMu.Lock()
	ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap)
	s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{
		responseID: responseID,
		expiresAt:  time.Now().Add(openAIWSStateStoreHotCacheTTL),
	}
	s.sessionToLastRespMu.Unlock()
	return responseID, true
}
+
+func (s *defaultOpenAIWSStateStore) DeleteSessionLastResponseID(groupID int64, sessionHash string) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return
+ }
+ s.sessionToLastRespMu.Lock()
+ delete(s.sessionToLastResp, key)
+ s.sessionToLastRespMu.Unlock()
+
+ if cache := s.sessionLastResponseCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ _ = cache.DeleteOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash))
+ }
+}
+
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
conn := strings.TrimSpace(connID)
@@ -277,7 +607,6 @@ func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash st
if key == "" {
return "", false
}
- s.maybeCleanup()
now := time.Now()
s.sessionToConnMu.RLock()
@@ -317,14 +646,29 @@ func (s *defaultOpenAIWSStateStore) maybeCleanup() {
cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
s.responseToAccountMu.Unlock()
- s.responseToConnMu.Lock()
- cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
- s.responseToConnMu.Unlock()
+ perShardLimit := openAIWSStateStoreCleanupMaxPerMap / openAIWSStateStoreConnShards
+ if perShardLimit < 32 {
+ perShardLimit = 32
+ }
+ for i := range s.responseToConnShards {
+ shard := &s.responseToConnShards[i]
+ shard.mu.Lock()
+ cleanupExpiredConnBindings(shard.m, now, perShardLimit)
+ shard.mu.Unlock()
+ }
+
+ s.responsePendingToolMu.Lock()
+ cleanupExpiredResponsePendingToolCallsBindings(s.responsePendingTool, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.responsePendingToolMu.Unlock()
s.sessionToTurnStateMu.Lock()
cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
s.sessionToTurnStateMu.Unlock()
+ s.sessionToLastRespMu.Lock()
+ cleanupExpiredSessionLastResponseBindings(s.sessionToLastResp, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.sessionToLastRespMu.Unlock()
+
s.sessionToConnMu.Lock()
cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
s.sessionToConnMu.Unlock()
@@ -362,6 +706,22 @@ func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now tim
}
}
+func cleanupExpiredResponsePendingToolCallsBindings(bindings map[string]openAIWSResponsePendingToolCallsBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
if len(bindings) == 0 || maxScan <= 0 {
return
@@ -378,6 +738,22 @@ func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBindin
}
}
+func cleanupExpiredSessionLastResponseBindings(bindings map[string]openAIWSSessionLastResponseBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
if len(bindings) == 0 || maxScan <= 0 {
return
@@ -394,17 +770,56 @@ func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBi
}
}
-func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
// expiringBinding unifies the binding value types so generic capacity/eviction
// helpers can read their expiry.
type expiringBinding interface {
	getExpiresAt() time.Time
}

func (b openAIWSAccountBinding) getExpiresAt() time.Time                 { return b.expiresAt }
func (b openAIWSConnBinding) getExpiresAt() time.Time                    { return b.expiresAt }
func (b openAIWSResponsePendingToolCallsBinding) getExpiresAt() time.Time { return b.expiresAt }
func (b openAIWSTurnStateBinding) getExpiresAt() time.Time               { return b.expiresAt }
func (b openAIWSSessionConnBinding) getExpiresAt() time.Time             { return b.expiresAt }
func (b openAIWSSessionLastResponseBinding) getExpiresAt() time.Time     { return b.expiresAt }
+
+func ensureBindingCapacity[T expiringBinding](bindings map[string]T, incomingKey string, maxEntries int) {
if len(bindings) < maxEntries || maxEntries <= 0 {
return
}
if _, exists := bindings[incomingKey]; exists {
return
}
- // 固定上限保护:淘汰任意一项,优先保证内存有界。
- for key := range bindings {
- delete(bindings, key)
- return
+ // 优先驱逐已过期条目;若不存在过期项,则按 expiresAt 最早驱逐,避免随机删除活跃绑定。
+ now := time.Now()
+ for key, val := range bindings {
+ if !val.getExpiresAt().IsZero() && now.After(val.getExpiresAt()) {
+ delete(bindings, key)
+ return
+ }
+ }
+ var (
+ evictKey string
+ evictExpireAt time.Time
+ hasCandidate bool
+ )
+ for key, val := range bindings {
+ expiresAt := val.getExpiresAt()
+ if !hasCandidate {
+ evictKey = key
+ evictExpireAt = expiresAt
+ hasCandidate = true
+ continue
+ }
+ switch {
+ case expiresAt.IsZero() && !evictExpireAt.IsZero():
+ evictKey = key
+ evictExpireAt = expiresAt
+ case !expiresAt.IsZero() && !evictExpireAt.IsZero() && expiresAt.Before(evictExpireAt):
+ evictKey = key
+ evictExpireAt = expiresAt
+ }
+ }
+ if hasCandidate {
+ delete(bindings, evictKey)
}
}
@@ -412,7 +827,46 @@ func normalizeOpenAIWSResponseID(responseID string) string {
return strings.TrimSpace(responseID)
}
// normalizeOpenAIWSPendingToolCallIDs trims whitespace from each call ID,
// drops empties, and removes duplicates while preserving first-seen order.
// A nil/empty input yields nil.
func normalizeOpenAIWSPendingToolCallIDs(callIDs []string) []string {
	if len(callIDs) == 0 {
		return nil
	}
	var (
		out  = make([]string, 0, len(callIDs))
		seen = make(map[string]struct{}, len(callIDs))
	)
	for _, raw := range callIDs {
		id := strings.TrimSpace(raw)
		if id == "" {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		out = append(out, id)
	}
	return out
}
+
+func openAIWSResponsePendingToolCallsBindingKey(groupID int64, responseID string) string {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return ""
+ }
+ return strconv.FormatInt(groupID, 10) + ":" + id
+}
+
func openAIWSResponseAccountCacheKey(responseID string) string {
+ h := xxhash.Sum64String(responseID)
+ // Pad to 16 hex chars for consistent key length.
+ hex := strconv.FormatUint(h, 16)
+ const pad = "0000000000000000"
+ if len(hex) < 16 {
+ hex = pad[:16-len(hex)] + hex
+ }
+ return openAIWSResponseAccountCachePrefix + "v2:" + hex
+}
+
+func openAIWSResponseAccountLegacyCacheKey(responseID string) string {
sum := sha256.Sum256([]byte(responseID))
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
}
@@ -424,12 +878,42 @@ func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
return ttl
}
+func openAIWSStateStoreLocalHotTTL(ttl time.Duration) time.Duration {
+ ttl = normalizeOpenAIWSTTL(ttl)
+ if ttl > openAIWSStateStoreHotCacheTTL {
+ return openAIWSStateStoreHotCacheTTL
+ }
+ return ttl
+}
+
+func (s *defaultOpenAIWSStateStore) sessionLastResponseCache() openAIWSStateStoreSessionLastResponseCache {
+ if s == nil || s.cache == nil {
+ return nil
+ }
+ cache, ok := s.cache.(openAIWSStateStoreSessionLastResponseCache)
+ if !ok {
+ return nil
+ }
+ return cache
+}
+
+func (s *defaultOpenAIWSStateStore) responsePendingToolCallsCache() openAIWSStateStoreResponsePendingToolCallsCache {
+ if s == nil || s.cache == nil {
+ return nil
+ }
+ cache, ok := s.cache.(openAIWSStateStoreResponsePendingToolCallsCache)
+ if !ok {
+ return nil
+ }
+ return cache
+}
+
// openAIWSSessionTurnStateKey builds the "<groupID>:<sessionHash>" map key;
// a blank (whitespace-only) session hash yields "".
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
	trimmed := strings.TrimSpace(sessionHash)
	if trimmed == "" {
		return ""
	}
	return strconv.FormatInt(groupID, 10) + ":" + trimmed
}
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go
index 235d42331..f11a05ec9 100644
--- a/backend/internal/service/openai_ws_state_store_test.go
+++ b/backend/internal/service/openai_ws_state_store_test.go
@@ -41,6 +41,27 @@ func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
require.False(t, ok)
}
// TestOpenAIWSStateStore_ResponsePendingToolCallsTTL covers call-ID
// normalization (trim + de-dup), group isolation, explicit deletion, and TTL
// expiry of pending-tool-call bindings using the local-only store (nil cache).
func TestOpenAIWSStateStore_ResponsePendingToolCallsTTL(t *testing.T) {
	store := NewOpenAIWSStateStore(nil)
	groupID := int64(9)
	// "call_1" repeats and " " is blank: both should be normalized away.
	store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_1", []string{"call_1", "call_2", "call_1", " "}, 30*time.Millisecond)

	callIDs, ok := store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1")
	require.True(t, ok)
	require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs)
	_, ok = store.GetResponsePendingToolCalls(groupID+1, "resp_pending_tool_1")
	require.False(t, ok, "pending tool calls should be group-isolated")

	store.DeleteResponsePendingToolCalls(groupID, "resp_pending_tool_1")
	_, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1")
	require.False(t, ok)

	// Sleeping past the 30ms TTL must make the binding invisible.
	store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_2", []string{"call_3"}, 30*time.Millisecond)
	time.Sleep(60 * time.Millisecond)
	_, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_2")
	require.False(t, ok)
}
+
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
store := NewOpenAIWSStateStore(nil)
store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)
@@ -75,6 +96,178 @@ func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
require.False(t, ok)
}
+func TestOpenAIWSStateStore_SessionLastResponseIDTTL(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ store.BindSessionLastResponseID(9, "session_hash_resp_1", "resp_1", 30*time.Millisecond)
+
+ responseID, ok := store.GetSessionLastResponseID(9, "session_hash_resp_1")
+ require.True(t, ok)
+ require.Equal(t, "resp_1", responseID)
+
+ // group 隔离
+ _, ok = store.GetSessionLastResponseID(10, "session_hash_resp_1")
+ require.False(t, ok)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetSessionLastResponseID(9, "session_hash_resp_1")
+ require.False(t, ok)
+}
+
+type openAIWSSessionLastResponseProbeCache struct {
+ sessionData map[string]string
+ setCalled bool
+ getCalled bool
+ delCalled bool
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) {
+ return 0, nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) DeleteSessionAccountID(context.Context, int64, string) error {
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) SetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash, responseID string, _ time.Duration) error {
+ if c.sessionData == nil {
+ c.sessionData = make(map[string]string)
+ }
+ c.setCalled = true
+ c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)] = responseID
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) GetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) (string, error) {
+ c.getCalled = true
+ return c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)], nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) DeleteOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) error {
+ c.delCalled = true
+ delete(c.sessionData, fmt.Sprintf("%d:%s", groupID, sessionHash))
+ return nil
+}
+
+func TestOpenAIWSStateStore_SessionLastResponseID_UsesOptionalCacheFallback(t *testing.T) {
+ probe := &openAIWSSessionLastResponseProbeCache{sessionData: make(map[string]string)}
+ raw := NewOpenAIWSStateStore(probe)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ groupID := int64(9)
+ sessionHash := "session_hash_resp_cache_1"
+ responseID := "resp_cache_1"
+ store.BindSessionLastResponseID(groupID, sessionHash, responseID, time.Minute)
+ require.True(t, probe.setCalled, "绑定 session last_response_id 时应写入可选缓存")
+
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ store.sessionToLastRespMu.Lock()
+ delete(store.sessionToLastResp, key)
+ store.sessionToLastRespMu.Unlock()
+
+ gotResponseID, found := store.GetSessionLastResponseID(groupID, sessionHash)
+ require.True(t, found, "本地缓存缺失时应降级读取可选缓存")
+ require.Equal(t, responseID, gotResponseID)
+ require.True(t, probe.getCalled)
+
+ store.DeleteSessionLastResponseID(groupID, sessionHash)
+ require.True(t, probe.delCalled, "删除 session last_response_id 时应同步删除可选缓存")
+ _, found = store.GetSessionLastResponseID(groupID, sessionHash)
+ require.False(t, found)
+}
+
+type openAIWSResponsePendingToolCallsProbeCache struct {
+ pendingData map[string][]string
+ setCalls int
+ getCalls int
+ delCalls int
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) {
+ return 0, nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteSessionAccountID(context.Context, int64, string) error {
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) SetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string, callIDs []string, _ time.Duration) error {
+ if c.pendingData == nil {
+ c.pendingData = make(map[string][]string)
+ }
+ key := fmt.Sprintf("%d:%s", groupID, responseID)
+ normalized := normalizeOpenAIWSPendingToolCallIDs(callIDs)
+ if len(normalized) == 0 {
+ delete(c.pendingData, key)
+ } else {
+ c.pendingData[key] = append([]string(nil), normalized...)
+ }
+ c.setCalls++
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) GetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) ([]string, error) {
+ c.getCalls++
+ callIDs := c.pendingData[fmt.Sprintf("%d:%s", groupID, responseID)]
+ return append([]string(nil), callIDs...), nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) error {
+ c.delCalls++
+ delete(c.pendingData, fmt.Sprintf("%d:%s", groupID, responseID))
+ return nil
+}
+
+func TestOpenAIWSStateStore_ResponsePendingToolCalls_UsesOptionalCacheFallback(t *testing.T) {
+ probe := &openAIWSResponsePendingToolCallsProbeCache{pendingData: make(map[string][]string)}
+ raw := NewOpenAIWSStateStore(probe)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ groupID := int64(11)
+ responseID := "resp_pending_tool_cache_1"
+ store.BindResponsePendingToolCalls(groupID, responseID, []string{"call_1", "call_2", "call_1"}, time.Minute)
+ require.Equal(t, 1, probe.setCalls, "绑定 pending_tool_calls 时应写入可选缓存")
+
+ store.responsePendingToolMu.Lock()
+ delete(store.responsePendingTool, openAIWSResponsePendingToolCallsBindingKey(groupID, responseID))
+ store.responsePendingToolMu.Unlock()
+
+ callIDs, found := store.GetResponsePendingToolCalls(groupID, responseID)
+ require.True(t, found, "本地缓存缺失时应降级读取可选缓存")
+ require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs)
+ require.Equal(t, 1, probe.getCalls)
+
+ // 回填后再次读取应命中本地缓存,不再触发 Redis 回源。
+ callIDs, found = store.GetResponsePendingToolCalls(groupID, responseID)
+ require.True(t, found)
+ require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs)
+ require.Equal(t, 1, probe.getCalls)
+ _, found = store.GetResponsePendingToolCalls(groupID+1, responseID)
+ require.False(t, found, "optional cache fallback should remain group-isolated")
+
+ store.DeleteResponsePendingToolCalls(groupID, responseID)
+ require.Equal(t, 1, probe.delCalls, "删除 pending_tool_calls 时应同步删除可选缓存")
+ _, found = store.GetResponsePendingToolCalls(groupID, responseID)
+ require.False(t, found)
+}
+
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
store := NewOpenAIWSStateStore(cache)
@@ -94,6 +287,42 @@ func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.
require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
}
+func TestOpenAIWSStateStore_GetResponseAccount_LegacyKeyFallback(t *testing.T) {
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ store := NewOpenAIWSStateStore(cache)
+ ctx := context.Background()
+ groupID := int64(18)
+ responseID := "resp_cache_legacy_fallback"
+
+ legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID)
+ v2Key := openAIWSResponseAccountCacheKey(responseID)
+ cache.sessionBindings[legacyKey] = 601
+
+ accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
+ require.NoError(t, err)
+ require.Equal(t, int64(601), accountID, "应支持 legacy cache key 回读")
+ require.Equal(t, int64(601), cache.sessionBindings[v2Key], "legacy 回读后应回填 v2 cache key")
+}
+
+func TestOpenAIWSStateStore_DeleteResponseAccount_DeletesLegacyAndV2Keys(t *testing.T) {
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ store := NewOpenAIWSStateStore(cache)
+ ctx := context.Background()
+ groupID := int64(19)
+ responseID := "resp_cache_delete_both_keys"
+
+ legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID)
+ v2Key := openAIWSResponseAccountCacheKey(responseID)
+ cache.sessionBindings[legacyKey] = 701
+ cache.sessionBindings[v2Key] = 701
+
+ require.NoError(t, store.DeleteResponseAccount(ctx, groupID, responseID))
+ _, legacyExists := cache.sessionBindings[legacyKey]
+ _, v2Exists := cache.sessionBindings[v2Key]
+ require.False(t, legacyExists, "删除 response account 绑定时应清理 legacy key")
+ require.False(t, v2Exists, "删除 response account 绑定时应清理 v2 key")
+}
+
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
raw := NewOpenAIWSStateStore(nil)
store, ok := raw.(*defaultOpenAIWSStateStore)
@@ -101,21 +330,27 @@ func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T
expiredAt := time.Now().Add(-time.Minute)
total := 2048
- store.responseToConnMu.Lock()
for i := 0; i < total; i++ {
- store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
+ key := fmt.Sprintf("resp_%d", i)
+ shard := store.connShard(key)
+ shard.mu.Lock()
+ shard.m[key] = openAIWSConnBinding{
connID: "conn_incremental",
expiresAt: expiredAt,
}
+ shard.mu.Unlock()
}
- store.responseToConnMu.Unlock()
store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
store.maybeCleanup()
- store.responseToConnMu.RLock()
- remainingAfterFirst := len(store.responseToConn)
- store.responseToConnMu.RUnlock()
+ remainingAfterFirst := 0
+ for i := range store.responseToConnShards {
+ shard := &store.responseToConnShards[i]
+ shard.mu.RLock()
+ remainingAfterFirst += len(shard.m)
+ shard.mu.RUnlock()
+ }
require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")
@@ -124,36 +359,99 @@ func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T
store.maybeCleanup()
}
- store.responseToConnMu.RLock()
- remaining := len(store.responseToConn)
- store.responseToConnMu.RUnlock()
+ remaining := 0
+ for i := range store.responseToConnShards {
+ shard := &store.responseToConnShards[i]
+ shard.mu.RLock()
+ remaining += len(shard.m)
+ shard.mu.RUnlock()
+ }
require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
+func TestOpenAIWSStateStore_BackgroundCleanupRemovesExpiredWithoutNewWrites(t *testing.T) {
+ store := newOpenAIWSStateStore(nil, 20*time.Millisecond)
+ defer store.Close()
+
+ expiredAt := time.Now().Add(-time.Minute)
+ store.responseToAccountMu.Lock()
+ for i := 0; i < 64; i++ {
+ key := fmt.Sprintf("bg_cleanup_resp_%d", i)
+ store.responseToAccount[key] = openAIWSAccountBinding{
+ accountID: int64(i + 1),
+ expiresAt: expiredAt,
+ }
+ }
+ store.responseToAccountMu.Unlock()
+
+ // Backdate cleanup watermark so the worker can run immediately on next tick.
+ store.lastCleanupUnixNano.Store(time.Now().Add(-time.Minute).UnixNano())
+
+ require.Eventually(t, func() bool {
+ store.responseToAccountMu.RLock()
+ remaining := len(store.responseToAccount)
+ store.responseToAccountMu.RUnlock()
+ return remaining == 0
+ }, 600*time.Millisecond, 10*time.Millisecond, "后台 cleanup 应在无新写入时清理过期项")
+}
+
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
- bindings := map[string]int{
- "a": 1,
- "b": 2,
+ bindings := map[string]openAIWSAccountBinding{
+ "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)},
+ "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)},
}
ensureBindingCapacity(bindings, "c", 2)
- bindings["c"] = 3
+ bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)}
require.Len(t, bindings, 2)
- require.Equal(t, 3, bindings["c"])
+ require.Equal(t, int64(3), bindings["c"].accountID)
}
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
- bindings := map[string]int{
- "a": 1,
- "b": 2,
+ bindings := map[string]openAIWSAccountBinding{
+ "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)},
+ "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)},
}
ensureBindingCapacity(bindings, "a", 2)
- bindings["a"] = 9
+ bindings["a"] = openAIWSAccountBinding{accountID: 9, expiresAt: time.Now().Add(time.Hour)}
+
+ require.Len(t, bindings, 2)
+ require.Equal(t, int64(9), bindings["a"].accountID)
+}
+
+func TestEnsureBindingCapacity_PrefersExpiredEntry(t *testing.T) {
+ bindings := map[string]openAIWSAccountBinding{
+ "expired": {accountID: 1, expiresAt: time.Now().Add(-time.Hour)},
+ "active": {accountID: 2, expiresAt: time.Now().Add(time.Hour)},
+ }
+
+ ensureBindingCapacity(bindings, "c", 2)
+ bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)}
require.Len(t, bindings, 2)
- require.Equal(t, 9, bindings["a"])
+ _, hasExpired := bindings["expired"]
+ require.False(t, hasExpired, "expired entry should have been evicted")
+ require.Equal(t, int64(2), bindings["active"].accountID)
+ require.Equal(t, int64(3), bindings["c"].accountID)
+}
+
+func TestEnsureBindingCapacity_EvictsEarliestExpiryWhenNoExpired(t *testing.T) {
+ now := time.Now()
+ bindings := map[string]openAIWSAccountBinding{
+ "soon": {accountID: 1, expiresAt: now.Add(30 * time.Second)},
+ "later": {accountID: 2, expiresAt: now.Add(5 * time.Minute)},
+ }
+
+ ensureBindingCapacity(bindings, "new", 2)
+ bindings["new"] = openAIWSAccountBinding{accountID: 3, expiresAt: now.Add(10 * time.Minute)}
+
+ require.Len(t, bindings, 2)
+ _, hasSoon := bindings["soon"]
+ require.False(t, hasSoon, "entry with earliest expiresAt should be evicted")
+ require.Equal(t, int64(2), bindings["later"].accountID)
+ require.Equal(t, int64(3), bindings["new"].accountID)
}
type openAIWSStateStoreTimeoutProbeCache struct {
@@ -204,13 +502,14 @@ func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
require.NoError(t, getErr)
- require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")
+ require.Equal(t, int64(11), accountID, "Redis Set 失败时本地缓存仍应保留作为降级保障")
require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))
require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
- require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
+ // 本地缓存作为降级保障保留,Get 直接命中本地缓存不会穿透到 Redis
+ require.False(t, probe.getHasDeadline, "本地缓存命中时不应穿透到 Redis 读取")
require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
@@ -226,6 +525,26 @@ func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}
+func TestOpenAIWSStateStore_BindResponseAccount_UsesShortLocalHotTTL(t *testing.T) {
+ cache := &stubGatewayCache{}
+ raw := NewOpenAIWSStateStore(cache)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ groupID := int64(23)
+ responseID := "resp_local_hot_ttl"
+ require.NoError(t, store.BindResponseAccount(context.Background(), groupID, responseID, 902, 24*time.Hour))
+
+ id := normalizeOpenAIWSResponseID(responseID)
+ require.NotEmpty(t, id)
+ store.responseToAccountMu.RLock()
+ binding, exists := store.responseToAccount[id]
+ store.responseToAccountMu.RUnlock()
+ require.True(t, exists)
+ require.Equal(t, int64(902), binding.accountID)
+ require.WithinDuration(t, time.Now().Add(openAIWSStateStoreHotCacheTTL), binding.expiresAt, 1500*time.Millisecond)
+}
+
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
defer cancel()
diff --git a/backend/internal/service/openai_ws_test_helpers_test.go b/backend/internal/service/openai_ws_test_helpers_test.go
new file mode 100644
index 000000000..d719c0eaf
--- /dev/null
+++ b/backend/internal/service/openai_ws_test_helpers_test.go
@@ -0,0 +1,277 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+)
+
+type openAIWSQueueDialer struct {
+ mu sync.Mutex
+ conns []openAIWSClientConn
+ dialCount int
+}
+
+func (d *openAIWSQueueDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.dialCount++
+ if len(d.conns) == 0 {
+ return nil, 503, nil, errors.New("no test conn")
+ }
+ conn := d.conns[0]
+ if len(d.conns) > 1 {
+ d.conns = d.conns[1:]
+ }
+ return conn, 0, nil, nil
+}
+
+func (d *openAIWSQueueDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSCaptureConn struct {
+ mu sync.Mutex
+ readDelays []time.Duration
+ events [][]byte
+ writes []map[string]any
+ closed bool
+}
+
+func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return errOpenAIWSConnClosed
+ }
+ switch payload := value.(type) {
+ case map[string]any:
+ c.writes = append(c.writes, cloneMapStringAny(payload))
+ case json.RawMessage:
+ var parsed map[string]any
+ if err := json.Unmarshal(payload, &parsed); err == nil {
+ c.writes = append(c.writes, cloneMapStringAny(parsed))
+ }
+ case []byte:
+ var parsed map[string]any
+ if err := json.Unmarshal(payload, &parsed); err == nil {
+ c.writes = append(c.writes, cloneMapStringAny(parsed))
+ }
+ }
+ return nil
+}
+
+func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return nil, errOpenAIWSConnClosed
+ }
+ if len(c.events) == 0 {
+ c.mu.Unlock()
+ return nil, io.EOF
+ }
+ delay := time.Duration(0)
+ if len(c.readDelays) > 0 {
+ delay = c.readDelays[0]
+ c.readDelays = c.readDelays[1:]
+ }
+ event := c.events[0]
+ c.events = c.events[1:]
+ c.mu.Unlock()
+ if delay > 0 {
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer.C:
+ }
+ }
+ return event, nil
+}
+
+func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSCaptureConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func (c *openAIWSCaptureConn) Closed() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.closed
+}
+
+func cloneMapStringAny(src map[string]any) map[string]any {
+ if src == nil {
+ return nil
+ }
+ dst := make(map[string]any, len(src))
+ for k, v := range src {
+ dst[k] = v
+ }
+ return dst
+}
+
+type openAIWSAlwaysFailDialer struct {
+ mu sync.Mutex
+ dialCount int
+}
+
+func (d *openAIWSAlwaysFailDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ d.dialCount++
+ d.mu.Unlock()
+ return nil, 503, nil, errors.New("dial failed")
+}
+
+func (d *openAIWSAlwaysFailDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSFakeConn struct {
+ mu sync.Mutex
+ closed bool
+ payload [][]byte
+}
+
+func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ _ = value
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return errors.New("closed")
+ }
+ c.payload = append(c.payload, []byte("ok"))
+ return nil
+}
+
+func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return nil, errors.New("closed")
+ }
+ return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
+}
+
+func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSFakeConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+type openAIWSPingFailConn struct{}
+
+func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error {
+ return nil
+}
+
+func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil
+}
+
+func (c *openAIWSPingFailConn) Ping(context.Context) error {
+ return errors.New("ping failed")
+}
+
+func (c *openAIWSPingFailConn) Close() error {
+ return nil
+}
+
+// openAIWSDelayedPingFailConn 是带可控延迟的 Ping 失败连接,
+// 用于模拟"Ping 执行期间连接被重建"的竞态场景。
+type openAIWSDelayedPingFailConn struct {
+ delay time.Duration
+ pingDone chan struct{} // Ping 开始执行时关闭,通知测试可以进行下一步
+ mu sync.Mutex
+ closed bool
+}
+
+func newOpenAIWSDelayedPingFailConn(delay time.Duration) *openAIWSDelayedPingFailConn {
+ return &openAIWSDelayedPingFailConn{
+ delay: delay,
+ pingDone: make(chan struct{}),
+ }
+}
+
+func (c *openAIWSDelayedPingFailConn) WriteJSON(context.Context, any) error { return nil }
+func (c *openAIWSDelayedPingFailConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_delayed_ping"}}`), nil
+}
+
+func (c *openAIWSDelayedPingFailConn) Ping(ctx context.Context) error {
+ // 通知测试 Ping 已开始
+ select {
+ case <-c.pingDone:
+ default:
+ close(c.pingDone)
+ }
+ // 等待延迟或上下文取消
+ timer := time.NewTimer(c.delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ }
+ return errors.New("ping failed after delay")
+}
+
+func (c *openAIWSDelayedPingFailConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func (c *openAIWSDelayedPingFailConn) Closed() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.closed
+}
diff --git a/backend/internal/service/openai_ws_turn.go b/backend/internal/service/openai_ws_turn.go
new file mode 100644
index 000000000..d1854b095
--- /dev/null
+++ b/backend/internal/service/openai_ws_turn.go
@@ -0,0 +1,1601 @@
+package service
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "sort"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.uber.org/zap"
+)
+
+// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
+type OpenAIWSIngressHooks struct {
+ BeforeTurn func(turn int) error
+ AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
+}
+
+const (
+ openAIWSCtxFirstModelKey = "openai_ws_first_model"
+ openAIWSCtxFirstPreviousResponseIDKey = "openai_ws_first_previous_response_id"
+ openAIWSCtxFirstPreviousResponseIDKindKey = "openai_ws_first_previous_response_id_kind"
+)
+
+func SetOpenAIWSFirstMessageMeta(c *gin.Context, model, previousResponseID, previousResponseIDKind string) {
+ if c == nil {
+ return
+ }
+ c.Set(openAIWSCtxFirstModelKey, strings.TrimSpace(model))
+ c.Set(openAIWSCtxFirstPreviousResponseIDKey, strings.TrimSpace(previousResponseID))
+ c.Set(openAIWSCtxFirstPreviousResponseIDKindKey, strings.TrimSpace(previousResponseIDKind))
+}
+
+func ResolveOpenAIWSFirstMessageMeta(
+ c *gin.Context,
+ firstClientMessage []byte,
+) (model string, previousResponseID string, previousResponseIDKind string) {
+ if c != nil {
+ if v, ok := c.Get(openAIWSCtxFirstModelKey); ok {
+ if text, okText := v.(string); okText {
+ model = strings.TrimSpace(text)
+ }
+ }
+ if v, ok := c.Get(openAIWSCtxFirstPreviousResponseIDKey); ok {
+ if text, okText := v.(string); okText {
+ previousResponseID = strings.TrimSpace(text)
+ }
+ }
+ if v, ok := c.Get(openAIWSCtxFirstPreviousResponseIDKindKey); ok {
+ if text, okText := v.(string); okText {
+ previousResponseIDKind = strings.TrimSpace(text)
+ }
+ }
+ }
+ if model == "" || previousResponseID == "" {
+ values := gjson.GetManyBytes(firstClientMessage, "model", "previous_response_id")
+ if model == "" {
+ model = strings.TrimSpace(values[0].String())
+ }
+ if previousResponseID == "" {
+ previousResponseID = strings.TrimSpace(values[1].String())
+ }
+ }
+ if previousResponseIDKind == "" {
+ previousResponseIDKind = ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
+ }
+ return model, previousResponseID, previousResponseIDKind
+}
+
+func normalizeOpenAIWSLogValue(value string) string {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ return "-"
+ }
+ return openAIWSLogValueReplacer.Replace(trimmed)
+}
+
+func truncateOpenAIWSLogValue(value string, maxLen int) string {
+ normalized := normalizeOpenAIWSLogValue(value)
+ if normalized == "-" || maxLen <= 0 {
+ return normalized
+ }
+ if len(normalized) <= maxLen {
+ return normalized
+ }
+ return normalized[:maxLen] + "..."
+}
+
+func openAIWSHeaderValueForLog(headers http.Header, key string) string {
+ if headers == nil {
+ return "-"
+ }
+ return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+func hasOpenAIWSHeader(headers http.Header, key string) bool {
+ if headers == nil {
+ return false
+ }
+ return strings.TrimSpace(headers.Get(key)) != ""
+}
+
+type openAIWSSessionHeaderResolution struct {
+ SessionID string
+ ConversationID string
+ SessionSource string
+ ConversationSource string
+}
+
+func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
+ resolution := openAIWSSessionHeaderResolution{
+ SessionSource: "none",
+ ConversationSource: "none",
+ }
+ if c != nil && c.Request != nil {
+ if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
+ resolution.SessionID = sessionID
+ resolution.SessionSource = "header_session_id"
+ }
+ if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
+ resolution.ConversationID = conversationID
+ resolution.ConversationSource = "header_conversation_id"
+ if resolution.SessionID == "" {
+ resolution.SessionID = conversationID
+ resolution.SessionSource = "header_conversation_id"
+ }
+ }
+ }
+
+ cacheKey := strings.TrimSpace(promptCacheKey)
+ if cacheKey != "" {
+ if resolution.SessionID == "" {
+ resolution.SessionID = cacheKey
+ resolution.SessionSource = "prompt_cache_key"
+ }
+ }
+ return resolution
+}
+
+func openAIWSIngressSessionScopeFromContext(c *gin.Context) string {
+ if c == nil {
+ return ""
+ }
+ value, exists := c.Get("api_key")
+ if !exists || value == nil {
+ return ""
+ }
+ apiKey, ok := value.(*APIKey)
+ if !ok || apiKey == nil {
+ return ""
+ }
+ userID := apiKey.UserID
+ if userID <= 0 && apiKey.User != nil {
+ userID = apiKey.User.ID
+ }
+ apiKeyID := apiKey.ID
+ if userID <= 0 && apiKeyID <= 0 {
+ return ""
+ }
+ return fmt.Sprintf("u%d:k%d", userID, apiKeyID)
+}
+
+func openAIWSApplySessionScope(sessionHash, scope string) string {
+ hash := strings.TrimSpace(sessionHash)
+ if hash == "" {
+ return ""
+ }
+ scope = strings.TrimSpace(scope)
+ if scope == "" {
+ return hash
+ }
+ return scope + "|" + hash
+}
+
+func openAIWSHostPathForLogFromURL(rawURL string) (host string, path string) {
+ trimmed := strings.TrimSpace(rawURL)
+ if trimmed == "" {
+ return "-", "-"
+ }
+
+ withoutScheme := trimmed
+ if schemeIdx := strings.Index(withoutScheme, "://"); schemeIdx >= 0 {
+ withoutScheme = withoutScheme[schemeIdx+3:]
+ }
+ withoutScheme = strings.TrimPrefix(withoutScheme, "//")
+
+ rawHost := withoutScheme
+ rawPath := ""
+ if slashIdx := strings.IndexByte(withoutScheme, '/'); slashIdx >= 0 {
+ rawHost = withoutScheme[:slashIdx]
+ rawPath = withoutScheme[slashIdx:]
+ }
+ if queryIdx := strings.IndexByte(rawPath, '?'); queryIdx >= 0 {
+ rawPath = rawPath[:queryIdx]
+ }
+
+ host = normalizeOpenAIWSLogValue(rawHost)
+ path = normalizeOpenAIWSLogValue(rawPath)
+ return host, path
+}
+
+func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
+ if idx <= openAIWSEventLogHeadLimit {
+ return true
+ }
+ if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
+ return true
+ }
+ if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
+ return true
+ }
+ return false
+}
+
+func shouldLogOpenAIWSBufferedEvent(idx int) bool {
+ if idx <= openAIWSBufferLogHeadLimit {
+ return true
+ }
+ if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
+ return true
+ }
+ return false
+}
+
+func openAIWSEventMayContainModel(eventType string) bool {
+ switch eventType {
+ case "response.created",
+ "response.in_progress",
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled":
+ return true
+ default:
+ trimmed := strings.TrimSpace(eventType)
+ if trimmed == eventType {
+ return false
+ }
+ switch trimmed {
+ case "response.created",
+ "response.in_progress",
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled":
+ return true
+ default:
+ return false
+ }
+ }
+}
+
+func openAIWSEventMayContainToolCalls(eventType string) bool {
+ if eventType == "" {
+ return false
+ }
+ if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") {
+ return true
+ }
+ switch eventType {
+ case "response.output_item.added", "response.output_item.done", "response.completed", "response.done":
+ return true
+ default:
+ return false
+ }
+}
+
+// openAIWSEventShouldParseUsage 判断是否应解析 usage。
+// 调用方需确保 eventType 已经过 TrimSpace(如 parseOpenAIWSEventType 的返回值)。
+func openAIWSEventShouldParseUsage(eventType string) bool {
+ switch eventType {
+ case "response.completed", "response.done", "response.failed":
+ return true
+ default:
+ return false
+ }
+}
+
+// parseOpenAIWSEventType extracts only the event type and response ID from a WS message.
+// Use this lightweight version on hot paths where the full response body is not needed.
+func parseOpenAIWSEventType(message []byte) (eventType string, responseID string) {
+ if len(message) == 0 {
+ return "", ""
+ }
+ values := gjson.GetManyBytes(message, "type", "response.id", "id")
+ eventType = strings.TrimSpace(values[0].String())
+ if id := strings.TrimSpace(values[1].String()); id != "" {
+ responseID = id
+ } else {
+ responseID = strings.TrimSpace(values[2].String())
+ }
+ return eventType, responseID
+}
+
+func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
+ if len(message) == 0 {
+ return "", "", gjson.Result{}
+ }
+ values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
+ eventType = strings.TrimSpace(values[0].String())
+ if id := strings.TrimSpace(values[1].String()); id != "" {
+ responseID = id
+ } else {
+ responseID = strings.TrimSpace(values[2].String())
+ }
+ return eventType, responseID, values[3]
+}
+
+func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
+ if len(message) == 0 {
+ return false
+ }
+ return bytes.Contains(message, []byte(`"tool_calls"`)) ||
+ bytes.Contains(message, []byte(`"tool_call"`)) ||
+ bytes.Contains(message, []byte(`"function_call"`))
+}
+
+func openAIWSCollectPendingFunctionCallIDsFromJSONResult(result gjson.Result, callIDSet map[string]struct{}, depth int) {
+ if !result.Exists() || callIDSet == nil || depth > 8 || result.Type != gjson.JSON {
+ return
+ }
+ itemType := strings.TrimSpace(result.Get("type").String())
+ if itemType == "function_call" || itemType == "tool_call" {
+ callID := strings.TrimSpace(result.Get("call_id").String())
+ if callID == "" {
+ fallbackID := strings.TrimSpace(result.Get("id").String())
+ if strings.HasPrefix(fallbackID, "call_") {
+ callID = fallbackID
+ }
+ }
+ if callID != "" {
+ callIDSet[callID] = struct{}{}
+ }
+ }
+ result.ForEach(func(_, child gjson.Result) bool {
+ openAIWSCollectPendingFunctionCallIDsFromJSONResult(child, callIDSet, depth+1)
+ return true
+ })
+}
+
// openAIWSExtractPendingFunctionCallIDsFromEvent returns the sorted,
// de-duplicated set of function/tool call ids found anywhere in the event
// message, or nil when none are present.
func openAIWSExtractPendingFunctionCallIDsFromEvent(message []byte) []string {
	if len(message) == 0 {
		return nil
	}
	callIDSet := make(map[string]struct{}, 4)
	openAIWSCollectPendingFunctionCallIDsFromJSONResult(gjson.ParseBytes(message), callIDSet, 0)
	if len(callIDSet) == 0 {
		return nil
	}
	callIDs := make([]string, 0, len(callIDSet))
	for callID := range callIDSet {
		callIDs = append(callIDs, callID)
	}
	// Sorted for deterministic comparison/logging.
	sort.Strings(callIDs)
	return callIDs
}
+
// parseOpenAIWSResponseUsageFromCompletedEvent populates usage from the
// "response.usage" block of a terminal event. Missing fields decode as 0.
func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
	if usage == nil || len(message) == 0 {
		return
	}
	values := gjson.GetManyBytes(
		message,
		"response.usage.input_tokens",
		"response.usage.output_tokens",
		"response.usage.input_tokens_details.cached_tokens",
	)
	usage.InputTokens = int(values[0].Int())
	usage.OutputTokens = int(values[1].Int())
	usage.CacheReadInputTokens = int(values[2].Int())
}
+
// parseOpenAIWSErrorEventFields extracts the trimmed error code/type/message
// from an error event; absent fields come back as "".
func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
	if len(message) == 0 {
		return "", "", ""
	}
	values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
	return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
}
+
+func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
+ code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
+ errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
+ errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
+ return code, errType, errMessage
+}
+
// summarizeOpenAIWSErrorEventFields parses and truncates the error fields of
// an event for logging; an empty message yields "-" placeholders.
func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
	if len(message) == 0 {
		return "-", "-", "-"
	}
	return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
}
+
// summarizeOpenAIWSPayloadKeySizes renders a "key:size,..." summary of the
// estimated byte size of each top-level payload value, sorted largest first
// (key name breaks ties), limited to the topN entries. Returns "-" for an
// empty payload; a size of -1 means "too large/deep to estimate".
func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
	if len(payload) == 0 {
		return "-"
	}
	type keySize struct {
		Key  string
		Size int
	}
	sizes := make([]keySize, 0, len(payload))
	for key, value := range payload {
		size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
		sizes = append(sizes, keySize{Key: key, Size: size})
	}
	// Largest first; deterministic order via key tiebreak.
	sort.Slice(sizes, func(i, j int) bool {
		if sizes[i].Size == sizes[j].Size {
			return sizes[i].Key < sizes[j].Key
		}
		return sizes[i].Size > sizes[j].Size
	})

	// topN <= 0 (or out of range) means "all keys".
	if topN <= 0 || topN > len(sizes) {
		topN = len(sizes)
	}
	parts := make([]string, 0, topN)
	for idx := 0; idx < topN; idx++ {
		item := sizes[idx]
		parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
	}
	return strings.Join(parts, ",")
}
+
// estimateOpenAIWSPayloadValueSize approximates the JSON-encoded byte size of
// value without marshaling the common shapes. It returns -1 when the value is
// too deep, has too many items, or would exceed the size cap — callers treat
// -1 as "unbounded/unknown". Scalars use rough fixed costs; only unrecognized
// types fall back to an actual json.Marshal.
func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
	if depth <= 0 {
		return -1
	}
	switch v := value.(type) {
	case nil:
		return 0
	case string:
		return len(v)
	case []byte:
		return len(v)
	case bool:
		return 1
	case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
		return 8
	case float32, float64:
		return 8
	case map[string]any:
		if len(v) == 0 {
			return 2 // "{}"
		}
		total := 2
		count := 0
		for key, item := range v {
			count++
			if count > openAIWSPayloadSizeEstimateMaxItems {
				return -1
			}
			itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
			if itemSize < 0 {
				return -1
			}
			// +3 covers the quotes/colon/comma around each entry.
			total += len(key) + itemSize + 3
			if total > openAIWSPayloadSizeEstimateMaxBytes {
				return -1
			}
		}
		return total
	case []any:
		if len(v) == 0 {
			return 2 // "[]"
		}
		total := 2
		limit := len(v)
		if limit > openAIWSPayloadSizeEstimateMaxItems {
			return -1
		}
		for i := 0; i < limit; i++ {
			itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
			if itemSize < 0 {
				return -1
			}
			total += itemSize + 1 // +1 for the comma
			if total > openAIWSPayloadSizeEstimateMaxBytes {
				return -1
			}
		}
		return total
	default:
		// Unknown type: pay the marshal cost once to get an exact size.
		raw, err := json.Marshal(v)
		if err != nil {
			return -1
		}
		if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
			return -1
		}
		return len(raw)
	}
}
+
// openAIWSPayloadString returns the trimmed string stored under key,
// accepting string and []byte values. A missing key, nil value, or any
// other type yields "".
func openAIWSPayloadString(payload map[string]any, key string) string {
	if len(payload) == 0 {
		return ""
	}
	value, present := payload[key]
	if !present {
		return ""
	}
	if s, ok := value.(string); ok {
		return strings.TrimSpace(s)
	}
	if b, ok := value.([]byte); ok {
		return strings.TrimSpace(string(b))
	}
	return ""
}
+
// openAIWSPayloadStringFromRaw reads key from raw JSON bytes and returns the
// trimmed string form of the value ("" when payload/key is empty or missing).
func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
	if len(payload) == 0 || strings.TrimSpace(key) == "" {
		return ""
	}
	return strings.TrimSpace(gjson.GetBytes(payload, key).String())
}
+
// openAIWSPayloadBoolFromRaw reads a strict JSON boolean at key from raw
// bytes; missing keys or non-boolean values (e.g. "true" as a string) return
// defaultValue.
func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
	if len(payload) == 0 || strings.TrimSpace(key) == "" {
		return defaultValue
	}
	value := gjson.GetBytes(payload, key)
	if !value.Exists() {
		return defaultValue
	}
	// Only accept literal true/false; anything else keeps the default.
	if value.Type != gjson.True && value.Type != gjson.False {
		return defaultValue
	}
	return value.Bool()
}
+
// openAIWSSessionHashesFromID is a thin alias over deriveOpenAISessionHashes,
// kept so WS-mode call sites read consistently.
func openAIWSSessionHashesFromID(sessionID string) (string, string) {
	return deriveOpenAISessionHashes(sessionID)
}
+
// extractOpenAIWSImageURL pulls an image URL out of either a bare string or
// an object of the form {"url": "..."}. Any other shape yields "".
func extractOpenAIWSImageURL(value any) string {
	if s, ok := value.(string); ok {
		return strings.TrimSpace(s)
	}
	if obj, ok := value.(map[string]any); ok {
		if url, ok := obj["url"].(string); ok {
			return strings.TrimSpace(url)
		}
	}
	return ""
}
+
+func summarizeOpenAIWSInput(input any) string {
+ items, ok := input.([]any)
+ if !ok || len(items) == 0 {
+ return "-"
+ }
+
+ itemCount := len(items)
+ textChars := 0
+ imageDataURLs := 0
+ imageDataURLChars := 0
+ imageRemoteURLs := 0
+
+ handleContentItem := func(contentItem map[string]any) {
+ contentType, _ := contentItem["type"].(string)
+ switch strings.TrimSpace(contentType) {
+ case "input_text", "output_text", "text":
+ if text, ok := contentItem["text"].(string); ok {
+ textChars += len(text)
+ }
+ case "input_image":
+ imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
+ if imageURL == "" {
+ return
+ }
+ if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
+ imageDataURLs++
+ imageDataURLChars += len(imageURL)
+ return
+ }
+ imageRemoteURLs++
+ }
+ }
+
+ handleInputItem := func(inputItem map[string]any) {
+ if content, ok := inputItem["content"].([]any); ok {
+ for _, rawContent := range content {
+ contentItem, ok := rawContent.(map[string]any)
+ if !ok {
+ continue
+ }
+ handleContentItem(contentItem)
+ }
+ return
+ }
+
+ itemType, _ := inputItem["type"].(string)
+ switch strings.TrimSpace(itemType) {
+ case "input_text", "output_text", "text":
+ if text, ok := inputItem["text"].(string); ok {
+ textChars += len(text)
+ }
+ case "input_image":
+ imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
+ if imageURL == "" {
+ return
+ }
+ if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
+ imageDataURLs++
+ imageDataURLChars += len(imageURL)
+ return
+ }
+ imageRemoteURLs++
+ }
+ }
+
+ for _, rawItem := range items {
+ inputItem, ok := rawItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ handleInputItem(inputItem)
+ }
+
+ return fmt.Sprintf(
+ "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
+ itemCount,
+ textChars,
+ imageDataURLs,
+ imageDataURLChars,
+ imageRemoteURLs,
+ )
+}
+
// dropOpenAIWSPayloadKey removes key from payload and records it in the
// removed sink when the key was actually present. It no-ops on an empty
// payload, a blank key, a missing key, or a nil removed pointer — the nil
// guard prevents the original failure mode of deleting the key and then
// panicking on the nil-pointer append (a partial, unrecorded mutation).
func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
	if len(payload) == 0 || strings.TrimSpace(key) == "" || removed == nil {
		return
	}
	if _, exists := payload[key]; !exists {
		return
	}
	delete(payload, key)
	*removed = append(*removed, key)
}
+
// applyOpenAIWSRetryPayloadStrategy removes only semantics-free fields from
// the payload when WS attempts keep failing, so a retry that succeeds never
// changes the meaning of the original request.
// Note: prompt_cache_key must NOT be removed on retry; it commonly serves as
// a stable session identifier (session_id fallback).
// Returns the strategy label applied and the sorted list of removed keys.
func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
	if len(payload) == 0 {
		return "empty", nil
	}
	// First attempt always sends the payload untouched.
	if attempt <= 1 {
		return "full", nil
	}

	removed := make([]string, 0, 2)
	if attempt >= 2 {
		// "include" only requests extra response fields; dropping it cannot
		// change request semantics.
		dropOpenAIWSPayloadKey(payload, "include", &removed)
	}

	if len(removed) == 0 {
		return "full", nil
	}
	sort.Strings(removed)
	return "trim_optional_fields", removed
}
+
// logOpenAIWSModeInfo emits an info-level gateway log line tagged with the
// WS-mode marker so WS traffic is greppable in mixed logs.
func logOpenAIWSModeInfo(format string, args ...any) {
	logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}
+
// isOpenAIWSModeDebugEnabled reports whether the global logger currently
// accepts debug-level entries (used to skip formatting work when it doesn't).
func isOpenAIWSModeDebugEnabled() bool {
	return logger.L().Core().Enabled(zap.DebugLevel)
}
+
// logOpenAIWSModeDebug emits a debug-tagged WS-mode log line; the level
// check up front avoids formatting args when debug logging is off.
func logOpenAIWSModeDebug(format string, args ...any) {
	if !isOpenAIWSModeDebugEnabled() {
		return
	}
	logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}
+
// logOpenAIWSBindResponseAccountWarn logs a structured warning when binding
// a response id to an account fails; a nil err is a no-op.
func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
	if err == nil {
		return
	}
	logger.L().Warn(
		"openai.ws_bind_response_account_failed",
		zap.Int64("group_id", groupID),
		zap.Int64("account_id", accountID),
		zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
		zap.Error(err),
	)
}
+
// logOpenAIWSIngressTurnAbort records an aborted ingress WS turn with the
// abort reason, whether the abort was expected, and a truncated cause.
func logOpenAIWSIngressTurnAbort(
	accountID int64,
	turn int,
	connID string,
	reason openAIWSIngressTurnAbortReason,
	expected bool,
	cause error,
) {
	// "-" is the conventional placeholder for an absent cause in these logs.
	causeValue := "-"
	if cause != nil {
		causeValue = truncateOpenAIWSLogValue(cause.Error(), openAIWSLogValueMaxLen)
	}
	logOpenAIWSModeInfo(
		"ingress_ws_turn_aborted account_id=%d turn=%d conn_id=%s reason=%s expected=%v cause=%s",
		accountID,
		turn,
		truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
		normalizeOpenAIWSLogValue(string(reason)),
		expected,
		causeValue,
	)
}
+
// sortedKeys returns the map's keys in ascending order, or nil for an
// empty/nil map. Sorting keeps downstream log output deterministic.
func sortedKeys(m map[string]any) []string {
	if len(m) == 0 {
		return nil
	}
	out := make([]string, len(m))
	i := 0
	for k := range m {
		out[i] = k
		i++
	}
	sort.Strings(out)
	return out
}
+
// dropPreviousResponseIDFromRawPayload removes previous_response_id from raw
// JSON using the default sjson delete; see the WithDeleteFn variant.
func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
	return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
}
+
// dropPreviousResponseIDFromRawPayloadWithDeleteFn deletes every occurrence
// of previous_response_id from raw JSON (duplicate keys require multiple
// delete passes, capped by openAIWSMaxPrevResponseIDDeletePasses). It
// returns the updated payload, whether the key is fully gone, and any
// delete error (in which case the original payload is returned untouched).
// deleteFn is injectable for tests; nil falls back to sjson.DeleteBytes.
func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
	payload []byte,
	deleteFn func([]byte, string) ([]byte, error),
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	if !gjson.GetBytes(payload, "previous_response_id").Exists() {
		return payload, false, nil
	}
	if deleteFn == nil {
		deleteFn = sjson.DeleteBytes
	}

	updated := payload
	// Each pass removes one occurrence; loop until gone or pass cap reached.
	for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
		gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
		next, err := deleteFn(updated, "previous_response_id")
		if err != nil {
			return payload, false, err
		}
		updated = next
	}
	return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
}
+
// setPreviousResponseIDToRawPayload writes previousResponseID into raw JSON.
// No-ops on empty payload/id or when the value already matches. Uses sjson
// in-place set on the fast path; on sjson failure it falls back to a full
// unmarshal/mutate/marshal round trip (which loses key order but preserves
// semantics).
func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
	normalizedPrevID := strings.TrimSpace(previousResponseID)
	if len(payload) == 0 || normalizedPrevID == "" {
		return payload, nil
	}
	if current := openAIWSPayloadStringFromRaw(payload, "previous_response_id"); current == normalizedPrevID {
		return payload, nil
	}
	updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
	if err == nil {
		return updated, nil
	}

	// Slow path: rebuild via map; if the payload isn't even valid JSON,
	// surface the original sjson error.
	var reqBody map[string]any
	if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
		return nil, err
	}
	reqBody["previous_response_id"] = normalizedPrevID
	rebuilt, marshalErr := json.Marshal(reqBody)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return rebuilt, nil
}
+
// shouldInferIngressFunctionCallOutputPreviousResponseID reports whether the
// gateway should fill in previous_response_id on behalf of the client: only
// when store is disabled, we're past the first turn, the request carries a
// function_call_output, the client did not set the id itself, and we have an
// expected id to inject.
func shouldInferIngressFunctionCallOutputPreviousResponseID(
	storeDisabled bool,
	turn int,
	hasFunctionCallOutput bool,
	currentPreviousResponseID string,
	expectedPreviousResponseID string,
) bool {
	switch {
	case !storeDisabled, turn <= 0, !hasFunctionCallOutput:
		return false
	case strings.TrimSpace(currentPreviousResponseID) != "":
		return false
	default:
		return strings.TrimSpace(expectedPreviousResponseID) != ""
	}
}
+
// alignStoreDisabledPreviousResponseID rewrites a non-empty, mismatching
// previous_response_id in the raw payload to the expected id. Returns the
// (possibly updated) payload, whether it changed, and any rewrite error.
// Payloads with no/empty/matching current id are returned unchanged.
func alignStoreDisabledPreviousResponseID(
	payload []byte,
	expectedPreviousResponseID string,
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	expected := strings.TrimSpace(expectedPreviousResponseID)
	if expected == "" {
		return payload, false, nil
	}
	current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
	if current == "" || current == expected {
		return payload, false, nil
	}

	// Common path (no duplicate key): set directly, avoiding the two-pass
	// delete-then-set handling. Only when duplicate keys are detected do we
	// take the drop+set slow path, which guarantees consistent final
	// semantics.
	if bytes.Count(payload, []byte(`"previous_response_id"`)) <= 1 {
		updated, setErr := setPreviousResponseIDToRawPayload(payload, expected)
		if setErr != nil {
			return payload, false, setErr
		}
		return updated, !bytes.Equal(updated, payload), nil
	}

	withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
	if dropErr != nil {
		return payload, false, dropErr
	}
	if !removed {
		return payload, false, nil
	}
	updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
	if setErr != nil {
		return payload, false, setErr
	}
	return updated, true, nil
}
+
// cloneOpenAIWSPayloadBytes returns an independent copy of payload, mapping
// empty/nil input to nil so callers can rely on len checks.
func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
	if len(payload) == 0 {
		return nil
	}
	return append([]byte(nil), payload...)
}
+
+func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
+ if items == nil {
+ return nil
+ }
+ cloned := make([]json.RawMessage, 0, len(items))
+ for idx := range items {
+ cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
+ }
+ return cloned
+}
+
// cloneOpenAIWSJSONRawString copies a raw JSON string into a fresh byte
// slice; whitespace-only input yields nil. The original (untrimmed) bytes
// are preserved in the copy.
func cloneOpenAIWSJSONRawString(raw string) []byte {
	if strings.TrimSpace(raw) == "" {
		return nil
	}
	// A string->[]byte conversion always allocates a copy.
	return []byte(raw)
}
+
// normalizeOpenAIWSJSONForCompare re-encodes raw JSON into Go's canonical
// json.Marshal form (sorted object keys, minimal whitespace) so two
// payloads can be compared byte-for-byte. Errors on empty or invalid JSON.
func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
	compact := bytes.TrimSpace(raw)
	if len(compact) == 0 {
		return nil, errors.New("json is empty")
	}
	var parsed any
	if err := json.Unmarshal(compact, &parsed); err != nil {
		return nil, err
	}
	return json.Marshal(parsed)
}
+
+func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
+ normalized, err := normalizeOpenAIWSJSONForCompare(raw)
+ if err != nil {
+ return bytes.TrimSpace(raw)
+ }
+ return normalized
+}
+
// normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID canonicalizes a
// request payload with its "input" and "previous_response_id" fields
// stripped, yielding a comparable fingerprint of the non-conversational
// request settings. Errors on empty or invalid JSON.
func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
	if len(payload) == 0 {
		return nil, errors.New("payload is empty")
	}
	var fields map[string]any
	if err := json.Unmarshal(payload, &fields); err != nil {
		return nil, err
	}
	for _, key := range []string{"input", "previous_response_id"} {
		delete(fields, key)
	}
	return json.Marshal(fields)
}
+
// openAIWSExtractNormalizedInputSequence reads the "input" field of a raw
// payload and normalizes it into a slice of raw JSON items. The bool result
// reports whether "input" existed at all. Arrays decode element-wise; a
// single object, string (re-encoded as JSON), or scalar becomes a one-item
// slice.
func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
	if len(payload) == 0 {
		return nil, false, nil
	}
	inputValue := gjson.GetBytes(payload, "input")
	if !inputValue.Exists() {
		return nil, false, nil
	}
	if inputValue.Type == gjson.JSON {
		raw := strings.TrimSpace(inputValue.Raw)
		if strings.HasPrefix(raw, "[") {
			var items []json.RawMessage
			if err := json.Unmarshal([]byte(raw), &items); err != nil {
				return nil, true, err
			}
			return items, true, nil
		}
		// Single object input.
		return []json.RawMessage{json.RawMessage(raw)}, true, nil
	}
	if inputValue.Type == gjson.String {
		// Re-encode so the item is valid standalone JSON.
		encoded, _ := json.Marshal(inputValue.String())
		return []json.RawMessage{encoded}, true, nil
	}
	// Numbers/booleans/null: keep the raw token as a single item.
	return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
}
+
// openAIWSInputIsPrefixExtended reports whether currentPayload's input
// sequence extends previousPayload's input as a prefix (item-by-item, after
// JSON canonicalization). Both inputs absent counts as true; one side absent
// is true only if the other side is empty.
func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
	previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
	if prevErr != nil {
		return false, prevErr
	}
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return false, currentErr
	}
	if !previousExists && !currentExists {
		return true, nil
	}
	if !previousExists {
		return len(currentItems) == 0, nil
	}
	if !currentExists {
		return len(previousItems) == 0, nil
	}
	if len(currentItems) < len(previousItems) {
		return false, nil
	}

	// Compare the shared prefix item-by-item in canonical form.
	for idx := range previousItems {
		previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
		currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
		if !bytes.Equal(previousNormalized, currentNormalized) {
			return false, nil
		}
	}
	return true, nil
}
+
// openAIWSRawItemsHasPrefix reports whether items starts with prefix,
// comparing each element in canonical JSON form. An empty prefix always
// matches.
func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
	if len(prefix) == 0 {
		return true
	}
	if len(items) < len(prefix) {
		return false
	}
	for idx := range prefix {
		previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
		currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
		if !bytes.Equal(previousNormalized, currentNormalized) {
			return false
		}
	}
	return true
}
+
// limitOpenAIWSReplayInputSequenceByBytes returns a cloned suffix of items
// whose approximate encoded size ("[]" plus items plus commas) fits within
// maxBytes, always keeping at least the newest (last) item. maxBytes <= 0
// disables the limit.
func limitOpenAIWSReplayInputSequenceByBytes(items []json.RawMessage, maxBytes int) []json.RawMessage {
	if len(items) == 0 {
		return nil
	}
	if maxBytes <= 0 {
		return cloneOpenAIWSRawMessages(items)
	}

	// Walk from the newest item backwards, accumulating sizes until the
	// budget is exhausted; start ends up at the first retained index.
	start := len(items)
	total := 2 // "[]"
	for idx := len(items) - 1; idx >= 0; idx-- {
		itemBytes := len(items[idx])
		if start != len(items) {
			itemBytes++ // comma
		}
		if total+itemBytes > maxBytes {
			// Keep at least the newest item to avoid creating an empty replay input.
			if start == len(items) {
				start = idx
			}
			break
		}
		total += itemBytes
		start = idx
	}
	// Defensive clamp; start should always be in range here.
	if start < 0 || start > len(items) {
		start = len(items) - 1
	}
	return cloneOpenAIWSRawMessages(items[start:])
}
+
// buildOpenAIWSReplayInputSequence constructs the input sequence to replay a
// turn without upstream history. Without a previous_response_id (or prior
// full input), the current input is used as-is; otherwise the previous full
// input is reused (when the current turn has none), kept (when the current
// input already extends it as a prefix), or prepended (delta append). The
// result is size-limited by openAIWSIngressReplayInputMaxBytes. The bool
// reports whether an input sequence exists at all.
func buildOpenAIWSReplayInputSequence(
	previousFullInput []json.RawMessage,
	previousFullInputExists bool,
	currentPayload []byte,
	hasPreviousResponseID bool,
) ([]json.RawMessage, bool, error) {
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return nil, false, currentErr
	}
	candidate := []json.RawMessage(nil)
	exists := false
	// No previous response id: the current turn stands alone.
	if !hasPreviousResponseID {
		candidate = cloneOpenAIWSRawMessages(currentItems)
		exists = currentExists
		if !exists {
			return candidate, false, nil
		}
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil
	}
	// No recorded history: nothing to merge with.
	if !previousFullInputExists {
		candidate = cloneOpenAIWSRawMessages(currentItems)
		exists = currentExists
		if !exists {
			return candidate, false, nil
		}
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil
	}
	// Current turn carries no input: replay the previous full input.
	if !currentExists || len(currentItems) == 0 {
		candidate = cloneOpenAIWSRawMessages(previousFullInput)
		exists = true
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
	}
	// Current input is already a superset (full snapshot): use it directly.
	if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
		candidate = cloneOpenAIWSRawMessages(currentItems)
		exists = true
		return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
	}
	// Delta append: previous history followed by the current items.
	merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
	merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
	merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
	candidate = merged
	exists = true
	return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
}
+
// openAIWSInputAppearsEditedFromPreviousFullInput heuristically detects a
// client that rewrote earlier conversation items instead of appending: the
// current input is at least as long as the recorded history yet does not
// extend it as a prefix. Conservative by design — short histories and
// shorter-than-history inputs never flag, to avoid false positives.
func openAIWSInputAppearsEditedFromPreviousFullInput(
	previousFullInput []json.RawMessage,
	previousFullInputExists bool,
	currentPayload []byte,
	hasPreviousResponseID bool,
) (bool, error) {
	if !hasPreviousResponseID || !previousFullInputExists {
		return false, nil
	}
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return false, currentErr
	}
	if !currentExists || len(currentItems) == 0 {
		return false, nil
	}
	if len(previousFullInput) < 2 {
		// Single-item turns are ambiguous (could be a normal incremental replace), avoid false positives.
		return false, nil
	}
	if len(currentItems) < len(previousFullInput) {
		// Most delta appends only send the latest one/few items.
		return false, nil
	}
	if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
		// Full snapshot append or unchanged snapshot.
		return false, nil
	}
	return true, nil
}
+
// setOpenAIWSPayloadInputSequence writes fullInput back into the payload's
// "input" field as a raw JSON array. When fullInputExists is false the
// payload is returned untouched.
func setOpenAIWSPayloadInputSequence(
	payload []byte,
	fullInput []json.RawMessage,
	fullInputExists bool,
) ([]byte, error) {
	if !fullInputExists {
		return payload, nil
	}
	// Preserve [] vs null semantics when input exists but is empty.
	inputForMarshal := fullInput
	if inputForMarshal == nil {
		inputForMarshal = []json.RawMessage{}
	}
	inputRaw, marshalErr := json.Marshal(inputForMarshal)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return sjson.SetRawBytes(payload, "input", inputRaw)
}
+
// openAIWSNormalizeCallIDs trims, de-duplicates, and sorts call ids,
// dropping blanks. A nil/empty input yields nil.
func openAIWSNormalizeCallIDs(callIDs []string) []string {
	if len(callIDs) == 0 {
		return nil
	}
	unique := make(map[string]struct{}, len(callIDs))
	for _, raw := range callIDs {
		if id := strings.TrimSpace(raw); id != "" {
			unique[id] = struct{}{}
		}
	}
	normalized := make([]string, 0, len(unique))
	for id := range unique {
		normalized = append(normalized, id)
	}
	sort.Strings(normalized)
	return normalized
}
+
// openAIWSExtractFunctionCallOutputCallIDsFromPayload returns the sorted,
// de-duplicated call_ids of function_call_output items in the payload's
// "input" (array or single object), or nil when there are none.
func openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload []byte) []string {
	if len(payload) == 0 {
		return nil
	}
	input := gjson.GetBytes(payload, "input")
	if !input.Exists() {
		return nil
	}
	callIDSet := make(map[string]struct{}, 4)
	// collect records the call_id of one function_call_output object.
	collect := func(item gjson.Result) {
		if item.Type != gjson.JSON {
			return
		}
		if strings.TrimSpace(item.Get("type").String()) != "function_call_output" {
			return
		}
		callID := strings.TrimSpace(item.Get("call_id").String())
		if callID == "" {
			return
		}
		callIDSet[callID] = struct{}{}
	}
	if input.IsArray() {
		input.ForEach(func(_, item gjson.Result) bool {
			collect(item)
			return true
		})
	} else {
		collect(input)
	}
	if len(callIDSet) == 0 {
		return nil
	}
	callIDs := make([]string, 0, len(callIDSet))
	for callID := range callIDSet {
		callIDs = append(callIDs, callID)
	}
	sort.Strings(callIDs)
	return callIDs
}
+
// openAIWSHasToolCallContextInPayload reports whether the payload's "input"
// contains at least one tool_call/function_call item with a non-empty
// call_id — i.e. the client supplied the tool-call context inline rather
// than relying on upstream history.
func openAIWSHasToolCallContextInPayload(payload []byte) bool {
	if len(payload) == 0 {
		return false
	}
	input := gjson.GetBytes(payload, "input")
	if !input.Exists() {
		return false
	}

	hasContext := false
	collect := func(item gjson.Result) {
		if hasContext || item.Type != gjson.JSON {
			return
		}
		itemType := strings.TrimSpace(item.Get("type").String())
		if itemType != "tool_call" && itemType != "function_call" {
			return
		}
		if strings.TrimSpace(item.Get("call_id").String()) == "" {
			return
		}
		hasContext = true
	}
	if input.IsArray() {
		// Stop iterating as soon as one matching item is found.
		input.ForEach(func(_, item gjson.Result) bool {
			collect(item)
			return !hasContext
		})
		return hasContext
	}
	collect(input)
	return hasContext
}
+
// openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload reports whether
// every required function_call_output call_id is covered by an
// item_reference (matching non-empty "id") in the payload's "input".
// Returns false when payload or the required id set is empty.
func openAIWSHasItemReferenceForAllFunctionCallOutputsInPayload(payload []byte, functionCallOutputCallIDs []string) bool {
	requiredCallIDs := openAIWSNormalizeCallIDs(functionCallOutputCallIDs)
	if len(payload) == 0 || len(requiredCallIDs) == 0 {
		return false
	}
	input := gjson.GetBytes(payload, "input")
	if !input.Exists() {
		return false
	}

	// Gather all item_reference ids present in the input.
	referenceIDSet := make(map[string]struct{}, len(requiredCallIDs))
	collect := func(item gjson.Result) {
		if item.Type != gjson.JSON {
			return
		}
		if strings.TrimSpace(item.Get("type").String()) != "item_reference" {
			return
		}
		referenceID := strings.TrimSpace(item.Get("id").String())
		if referenceID == "" {
			return
		}
		referenceIDSet[referenceID] = struct{}{}
	}
	if input.IsArray() {
		input.ForEach(func(_, item gjson.Result) bool {
			collect(item)
			return true
		})
	} else {
		collect(input)
	}

	if len(referenceIDSet) == 0 {
		return false
	}
	// Every required call id must have a matching reference.
	for _, callID := range requiredCallIDs {
		if _, ok := referenceIDSet[callID]; !ok {
			return false
		}
	}
	return true
}
+
// shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID reports
// whether a store-disabled request carrying function_call_output should be
// rejected up front: it has neither a previous_response_id nor inline
// tool-call context, so the upstream would fail it anyway.
func shouldProactivelyRejectIngressToolOutputWithoutPreviousResponseID(
	storeDisabled bool,
	hasFunctionCallOutput bool,
	previousResponseID string,
	hasToolOutputContext bool,
) bool {
	if !storeDisabled || !hasFunctionCallOutput || hasToolOutputContext {
		return false
	}
	return strings.TrimSpace(previousResponseID) == ""
}
+
// openAIWSFindMissingCallIDs returns the required call ids (normalized) that
// are absent from actualCallIDs; nil when nothing is required.
func openAIWSFindMissingCallIDs(requiredCallIDs []string, actualCallIDs []string) []string {
	required := openAIWSNormalizeCallIDs(requiredCallIDs)
	if len(required) == 0 {
		return nil
	}
	actualSet := make(map[string]struct{}, len(actualCallIDs))
	for _, callID := range actualCallIDs {
		id := strings.TrimSpace(callID)
		if id == "" {
			continue
		}
		actualSet[id] = struct{}{}
	}
	// required is already sorted, so missing comes out sorted too.
	missing := make([]string, 0, len(required))
	for _, callID := range required {
		if _, ok := actualSet[callID]; ok {
			continue
		}
		missing = append(missing, callID)
	}
	return missing
}
+
// openAIWSInjectFunctionCallOutputItems appends one function_call_output
// item per (normalized) call id to the payload's input, each with the given
// outputValue. Returns the updated payload and the number of items injected;
// no ids means no change.
func openAIWSInjectFunctionCallOutputItems(payload []byte, callIDs []string, outputValue string) ([]byte, int, error) {
	normalizedCallIDs := openAIWSNormalizeCallIDs(callIDs)
	if len(normalizedCallIDs) == 0 {
		return payload, 0, nil
	}
	inputItems, inputExists, inputErr := openAIWSExtractNormalizedInputSequence(payload)
	if inputErr != nil {
		return nil, 0, inputErr
	}
	if !inputExists {
		inputItems = []json.RawMessage{}
	}
	// Copy existing items, then append one synthetic output per call id.
	updatedInput := make([]json.RawMessage, 0, len(inputItems)+len(normalizedCallIDs))
	updatedInput = append(updatedInput, cloneOpenAIWSRawMessages(inputItems)...)
	for _, callID := range normalizedCallIDs {
		rawItem, marshalErr := json.Marshal(map[string]any{
			"type":    "function_call_output",
			"call_id": callID,
			"output":  outputValue,
		})
		if marshalErr != nil {
			return nil, 0, marshalErr
		}
		updatedInput = append(updatedInput, json.RawMessage(rawItem))
	}
	updatedPayload, setErr := setOpenAIWSPayloadInputSequence(payload, updatedInput, true)
	if setErr != nil {
		return nil, 0, setErr
	}
	return updatedPayload, len(normalizedCallIDs), nil
}
+
// shouldKeepIngressPreviousResponseID decides whether the client-supplied
// previous_response_id can be trusted for this turn, returning the decision
// plus a reason label for logging. Function-call-output turns are accepted
// when their call ids cover the expected pending ids; otherwise the id must
// match the last turn's response id AND the payload (ignoring input and
// previous_response_id) must be unchanged from the previous turn.
func shouldKeepIngressPreviousResponseID(
	previousPayload []byte,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
	expectedPendingCallIDs []string,
	functionCallOutputCallIDs []string,
) (bool, string, error) {
	if hasFunctionCallOutput {
		// No expectation recorded: accept the tool output as-is.
		if len(expectedPendingCallIDs) == 0 {
			return true, "has_function_call_output", nil
		}
		if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 {
			return false, "function_call_output_call_id_mismatch", nil
		}
		return true, "function_call_output_call_id_match", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if len(previousPayload) == 0 {
		return false, "missing_previous_turn_payload", nil
	}

	// Strict check: everything except input/previous_response_id must be
	// identical between the two turns.
	previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
	if previousComparableErr != nil {
		return false, "non_input_compare_error", previousComparableErr
	}
	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true, "strict_incremental_ok", nil
}
+
// openAIWSIngressPreviousTurnStrictState caches the previous turn's
// canonical non-input payload so repeated strict comparisons don't re-parse
// the previous payload each turn.
type openAIWSIngressPreviousTurnStrictState struct {
	// nonInputComparable is the payload canonicalized with input and
	// previous_response_id stripped.
	nonInputComparable []byte
}

// buildOpenAIWSIngressPreviousTurnStrictState precomputes the strict-compare
// state for a turn payload; empty payloads yield (nil, nil).
func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
	if len(payload) == 0 {
		return nil, nil
	}
	nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
	if nonInputErr != nil {
		return nil, nonInputErr
	}
	return &openAIWSIngressPreviousTurnStrictState{
		nonInputComparable: nonInputComparable,
	}, nil
}
+
// shouldKeepIngressPreviousResponseIDWithStrictState mirrors
// shouldKeepIngressPreviousResponseID but takes the precomputed previous-turn
// strict state instead of the raw previous payload, avoiding re-normalizing
// it on every turn.
func shouldKeepIngressPreviousResponseIDWithStrictState(
	previousState *openAIWSIngressPreviousTurnStrictState,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
	expectedPendingCallIDs []string,
	functionCallOutputCallIDs []string,
) (bool, string, error) {
	if hasFunctionCallOutput {
		// No expectation recorded: accept the tool output as-is.
		if len(expectedPendingCallIDs) == 0 {
			return true, "has_function_call_output", nil
		}
		if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 {
			return false, "function_call_output_call_id_mismatch", nil
		}
		return true, "function_call_output_call_id_match", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if previousState == nil {
		return false, "missing_previous_turn_payload", nil
	}

	// Strict check against the cached canonical non-input payload.
	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true, "strict_incremental_ok", nil
}
+
// payloadAsJSON renders the payload as a JSON string ("{}" on error/empty).
func payloadAsJSON(payload map[string]any) string {
	return string(payloadAsJSONBytes(payload))
}
+
// normalizeOpenAIWSPreferredConnID validates a stored connection id: only
// trimmed ids carrying one of the known prefixes (ctx-style or legacy) are
// usable for connection affinity; everything else is rejected.
func normalizeOpenAIWSPreferredConnID(connID string) (string, bool) {
	trimmed := strings.TrimSpace(connID)
	if trimmed == "" {
		return "", false
	}
	if strings.HasPrefix(trimmed, openAIWSConnIDPrefixCtx) {
		return trimmed, true
	}
	if strings.HasPrefix(trimmed, openAIWSConnIDPrefixLegacy) {
		return trimmed, true
	}
	return "", false
}
+
// openAIWSPreferredConnIDFromResponse looks up the connection previously
// bound to responseID in the state store and returns it only if it passes
// normalizeOpenAIWSPreferredConnID; "" means no usable affinity.
func openAIWSPreferredConnIDFromResponse(stateStore OpenAIWSStateStore, responseID string) string {
	if stateStore == nil {
		return ""
	}
	normalizedResponseID := strings.TrimSpace(responseID)
	if normalizedResponseID == "" {
		return ""
	}
	connID, ok := stateStore.GetResponseConn(normalizedResponseID)
	if !ok {
		return ""
	}
	normalizedConnID, ok := normalizeOpenAIWSPreferredConnID(connID)
	if !ok {
		return ""
	}
	return normalizedConnID
}
+
// payloadAsJSONBytes marshals the payload, degrading to the literal "{}"
// for an empty map or a marshal failure so callers always get valid JSON.
func payloadAsJSONBytes(payload map[string]any) []byte {
	if len(payload) == 0 {
		return []byte("{}")
	}
	if body, err := json.Marshal(payload); err == nil {
		return body
	}
	return []byte("{}")
}
+
// isOpenAIWSTerminalEvent reports whether the event type ends a response
// stream (completion, failure, incompleteness, or either cancellation
// spelling).
func isOpenAIWSTerminalEvent(eventType string) bool {
	switch eventType {
	case "response.completed",
		"response.done",
		"response.failed",
		"response.incomplete",
		"response.cancelled",
		"response.canceled":
		return true
	}
	return false
}
+
// shouldPersistOpenAIWSLastResponseID reports whether a terminal event type
// represents a successful completion whose response id is worth persisting
// for incremental follow-up turns (failures/cancellations are not).
func shouldPersistOpenAIWSLastResponseID(terminalEventType string) bool {
	return terminalEventType == "response.completed" ||
		terminalEventType == "response.done"
}
+
// isOpenAIWSTokenEvent reports whether an event represents billable/visible
// output activity: any ".delta" event, any "response.output*" event, or a
// successful terminal event. Pure lifecycle events (created, in_progress,
// output_item added/done) are excluded.
func isOpenAIWSTokenEvent(eventType string) bool {
	if eventType == "" {
		return false
	}
	switch eventType {
	case "response.created", "response.in_progress",
		"response.output_item.added", "response.output_item.done":
		return false
	case "response.completed", "response.done":
		return true
	}
	return strings.Contains(eventType, ".delta") ||
		strings.HasPrefix(eventType, "response.output")
}
+
// replaceOpenAIWSMessageModel rewrites "model" and "response.model" fields
// from fromModel to toModel, used to map upstream model names back to the
// client-requested alias. Cheap byte probes skip messages that cannot match;
// only exact string-value matches are rewritten, and sjson failures leave
// the corresponding field untouched.
func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
	if len(message) == 0 {
		return message
	}
	if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
		return message
	}
	// Fast reject: the message must mention both the key and the old value.
	if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
		return message
	}
	modelValues := gjson.GetManyBytes(message, "model", "response.model")
	replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
	replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
	if !replaceModel && !replaceResponseModel {
		return message
	}
	updated := message
	if replaceModel {
		if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
			updated = next
		}
	}
	if replaceResponseModel {
		if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
			updated = next
		}
	}
	return updated
}
+
// populateOpenAIUsageFromResponseJSON fills usage from a top-level "usage"
// block of a response body (the non-WS counterpart of
// parseOpenAIWSResponseUsageFromCompletedEvent). Missing fields decode as 0.
func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
	if usage == nil || len(body) == 0 {
		return
	}
	values := gjson.GetManyBytes(
		body,
		"usage.input_tokens",
		"usage.output_tokens",
		"usage.input_tokens_details.cached_tokens",
	)
	usage.InputTokens = int(values[0].Int())
	usage.OutputTokens = int(values[1].Int())
	usage.CacheReadInputTokens = int(values[2].Int())
}
+
// getOpenAIGroupIDFromContext reads the authenticated APIKey from the gin
// context's "api_key" value and returns its group id; any missing link in
// the chain (nil context, absent key, wrong type, nil GroupID) yields 0.
func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
	if c == nil {
		return 0
	}
	value, exists := c.Get("api_key")
	if !exists {
		return 0
	}
	apiKey, ok := value.(*APIKey)
	if !ok || apiKey == nil || apiKey.GroupID == nil {
		return 0
	}
	return *apiKey.GroupID
}
+
// openAIWSIngressFallbackSessionSeedFromContext derives a deterministic
// fallback session seed ("openai_ws_ingress:<group>:<user>:<key>") from the
// authenticated APIKey in the gin context, used when the client provides no
// session identifier. Returns "" when no APIKey is attached; missing group
// or user degrade to 0 rather than failing.
func openAIWSIngressFallbackSessionSeedFromContext(c *gin.Context) string {
	if c == nil {
		return ""
	}
	value, exists := c.Get("api_key")
	if !exists {
		return ""
	}
	apiKey, ok := value.(*APIKey)
	if !ok || apiKey == nil {
		return ""
	}
	gid := int64(0)
	if apiKey.GroupID != nil {
		gid = *apiKey.GroupID
	}
	userID := int64(0)
	if apiKey.User != nil {
		userID = apiKey.User.ID
	}
	return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKey.ID)
}
diff --git a/backend/internal/service/openai_ws_upstream_pump_test.go b/backend/internal/service/openai_ws_upstream_pump_test.go
new file mode 100644
index 000000000..26a7439dc
--- /dev/null
+++ b/backend/internal/service/openai_ws_upstream_pump_test.go
@@ -0,0 +1,1894 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Helpers: build upstream event JSON payloads for tests
+// ---------------------------------------------------------------------------
+
+// pumpTestEvent returns a minimal JSON event of the given type.
+func pumpTestEvent(eventType string) []byte {
+	m := map[string]any{"type": eventType}
+	b, _ := json.Marshal(m)
+	return b
+}
+
+// pumpTestEventWithResponseID returns a JSON event carrying a response.id.
+func pumpTestEventWithResponseID(eventType, responseID string) []byte {
+	m := map[string]any{"type": eventType, "response": map[string]any{"id": responseID}}
+	b, _ := json.Marshal(m)
+	return b
+}
+
+// ---------------------------------------------------------------------------
+// Helper: fake upstream connection (scripted events, delays, error injection)
+// ---------------------------------------------------------------------------
+
+// pumpTestConn is a scripted in-memory stand-in for an upstream WS connection.
+type pumpTestConn struct {
+	mu         sync.Mutex
+	events     []pumpTestConnEvent // remaining scripted events, consumed in order
+	readCount  int                 // successful dequeues so far
+	closed     bool
+	closedCh   chan struct{} // closed exactly once by Close
+	ignoreCtx  bool          // when set, an empty queue blocks on closedCh instead of ctx
+	pingErr    error         // injected result for Ping
+	writeErr   error         // injected result for WriteJSON
+	writeCount int
+}
+
+// pumpTestConnEvent is one scripted ReadMessage result: payload and/or error,
+// optionally delivered after a delay.
+type pumpTestConnEvent struct {
+	data  []byte
+	err   error
+	delay time.Duration
+}
+
+// newPumpTestConn builds a connection that replays the given events in order.
+func newPumpTestConn(events ...pumpTestConnEvent) *pumpTestConn {
+	return &pumpTestConn{
+		events:   events,
+		closedCh: make(chan struct{}),
+	}
+}
+
+// WriteJSON records the write attempt and returns the injected writeErr.
+func (c *pumpTestConn) WriteJSON(_ context.Context, _ any) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.writeCount++
+	return c.writeErr
+}
+
+// ReadMessage dequeues the next scripted event, honouring its optional delay.
+// With no events left it blocks until Close (ignoreCtx) or ctx cancellation;
+// on a closed connection it returns errOpenAIWSConnClosed.
+func (c *pumpTestConn) ReadMessage(ctx context.Context) ([]byte, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	c.mu.Lock()
+	if c.closed {
+		c.mu.Unlock()
+		return nil, errOpenAIWSConnClosed
+	}
+	if len(c.events) == 0 {
+		c.mu.Unlock()
+		if c.ignoreCtx {
+			<-c.closedCh
+			return nil, io.EOF
+		}
+		// Block until the context is cancelled, simulating an upstream with no further events.
+		<-ctx.Done()
+		return nil, ctx.Err()
+	}
+	evt := c.events[0]
+	c.events = c.events[1:]
+	c.readCount++
+	c.mu.Unlock()
+
+	if evt.delay > 0 {
+		timer := time.NewTimer(evt.delay)
+		defer timer.Stop()
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-timer.C:
+		}
+	}
+	return evt.data, evt.err
+}
+
+// Ping returns the injected pingErr.
+func (c *pumpTestConn) Ping(_ context.Context) error { return c.pingErr }
+
+// Close marks the connection closed and signals closedCh exactly once;
+// repeated calls are safe no-ops.
+func (c *pumpTestConn) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if !c.closed {
+		c.closed = true
+		close(c.closedCh)
+	}
+	return nil
+}
+
+// ---------------------------------------------------------------------------
+// Helper: fake lease (only the read/broken methods the pump tests need)
+// ---------------------------------------------------------------------------
+
+// pumpTestLease wraps a pumpTestConn and tracks the broken flag atomically.
+type pumpTestLease struct {
+	conn   *pumpTestConn
+	broken atomic.Bool
+}
+
+// ReadMessageWithContextTimeout reads one message under a per-call deadline
+// derived from ctx.
+func (l *pumpTestLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
+	readCtx, cancel := context.WithTimeout(ctx, timeout)
+	defer cancel()
+	return l.conn.ReadMessage(readCtx)
+}
+
+// MarkBroken flags the lease as broken and closes the underlying connection.
+func (l *pumpTestLease) MarkBroken() {
+	l.broken.Store(true)
+	if l.conn != nil {
+		_ = l.conn.Close()
+	}
+}
+
+// IsBroken reports whether MarkBroken has been called.
+func (l *pumpTestLease) IsBroken() bool { return l.broken.Load() }
+
+// ---------------------------------------------------------------------------
+// Helpers: run the pump goroutine and collect everything it emits
+// ---------------------------------------------------------------------------
+
+// startPump mimics the pump goroutine inside sendAndRelay: it forwards each
+// upstream read into a buffered channel and stops on read error, terminal
+// event, "error" event, or cancellation. Returns the event channel (closed on
+// exit) and the cancel function for the pump's context.
+func startPump(ctx context.Context, lease *pumpTestLease, readTimeout time.Duration) (chan openAIWSUpstreamPumpEvent, context.CancelFunc) {
+	pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize)
+	pumpCtx, pumpCancel := context.WithCancel(ctx)
+	go func() {
+		defer close(pumpEventCh)
+		for {
+			msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, readTimeout)
+			select {
+			case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}:
+			case <-pumpCtx.Done():
+				return
+			}
+			if readErr != nil {
+				return
+			}
+			evtType, _ := parseOpenAIWSEventType(msg)
+			if isOpenAIWSTerminalEvent(evtType) || evtType == "error" {
+				return
+			}
+		}
+	}()
+	return pumpEventCh, pumpCancel
+}
+
+// collectAll reads every event from the channel until it is closed.
+func collectAll(ch chan openAIWSUpstreamPumpEvent) []openAIWSUpstreamPumpEvent {
+	var result []openAIWSUpstreamPumpEvent
+	for evt := range ch {
+		result = append(result, evt)
+	}
+	return result
+}
+
+// ---------------------------------------------------------------------------
+// Tests: the openAIWSUpstreamPumpEvent struct
+// ---------------------------------------------------------------------------
+
+// TestOpenAIWSUpstreamPumpEvent_Fields checks the three legal field combos:
+// message only, error only, and both (a partial message with an error).
+func TestOpenAIWSUpstreamPumpEvent_Fields(t *testing.T) {
+	t.Parallel()
+
+	t.Run("message_only", func(t *testing.T) {
+		evt := openAIWSUpstreamPumpEvent{message: []byte("hello")}
+		assert.Equal(t, []byte("hello"), evt.message)
+		assert.NoError(t, evt.err)
+	})
+
+	t.Run("error_only", func(t *testing.T) {
+		evt := openAIWSUpstreamPumpEvent{err: io.EOF}
+		assert.Nil(t, evt.message)
+		assert.ErrorIs(t, evt.err, io.EOF)
+	})
+
+	t.Run("both_fields", func(t *testing.T) {
+		evt := openAIWSUpstreamPumpEvent{message: []byte("partial"), err: io.ErrUnexpectedEOF}
+		assert.Equal(t, []byte("partial"), evt.message)
+		assert.ErrorIs(t, evt.err, io.ErrUnexpectedEOF)
+	})
+}
+
+// TestOpenAIWSUpstreamPumpBufferSize pins the pump channel buffer size at 16.
+func TestOpenAIWSUpstreamPumpBufferSize(t *testing.T) {
+	t.Parallel()
+	assert.Equal(t, 16, openAIWSUpstreamPumpBufferSize, "缓冲大小应为 16")
+}
+
+// ---------------------------------------------------------------------------
+// Test: pump goroutine, normal event flow
+// ---------------------------------------------------------------------------
+
+func TestPump_NormalEventFlow(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+
+	require.Len(t, events, 4)
+	for _, evt := range events {
+		assert.NoError(t, evt.err)
+		assert.NotEmpty(t, evt.message)
+	}
+	// Verify the final event is a terminal event.
+	lastType, _ := parseOpenAIWSEventType(events[3].message)
+	assert.True(t, isOpenAIWSTerminalEvent(lastType))
+}
+
+// TestPump_TerminalEventStopsPump verifies every terminal event type stops
+// the pump before any further upstream reads happen.
+func TestPump_TerminalEventStopsPump(t *testing.T) {
+	t.Parallel()
+	terminalTypes := []string{
+		"response.completed",
+		"response.done",
+		"response.failed",
+		"response.incomplete",
+		"response.cancelled",
+		"response.canceled",
+	}
+	for _, tt := range terminalTypes {
+		tt := tt
+		t.Run(tt, func(t *testing.T) {
+			t.Parallel()
+			conn := newPumpTestConn(
+				pumpTestConnEvent{data: pumpTestEvent("response.created")},
+				pumpTestConnEvent{data: pumpTestEvent(tt)},
+				// The following event must never be read.
+				pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+			)
+			lease := &pumpTestLease{conn: conn}
+			ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+			defer cancel()
+
+			events := collectAll(ch)
+			require.Len(t, events, 2, "终端事件 %s 后泵应停止", tt)
+			assert.NoError(t, events[0].err)
+			assert.NoError(t, events[1].err)
+		})
+	}
+}
+
+// TestPump_ErrorEventStopsPump verifies an "error" event stops the pump.
+func TestPump_ErrorEventStopsPump(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("error")},
+		// Must never be read.
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 2, "error 事件后泵应停止")
+	evtType, _ := parseOpenAIWSEventType(events[1].message)
+	assert.Equal(t, "error", evtType)
+}
+
+// ---------------------------------------------------------------------------
+// Test: pump goroutine propagates read errors
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadErrorPropagated(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{err: io.ErrUnexpectedEOF},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 2)
+	assert.NoError(t, events[0].err)
+	assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF)
+}
+
+// TestPump_ReadErrorOnFirstEvent covers a read error on the very first read.
+func TestPump_ReadErrorOnFirstEvent(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{err: errors.New("connection refused")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 1)
+	assert.Error(t, events[0].err)
+	assert.Contains(t, events[0].err.Error(), "connection refused")
+}
+
+// TestPump_EOFError checks EOF after some events is delivered as the last one.
+func TestPump_EOFError(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{err: io.EOF},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 3)
+	assert.ErrorIs(t, events[2].err, io.EOF)
+}
+
+// ---------------------------------------------------------------------------
+// Test: context cancellation terminates the pump
+// ---------------------------------------------------------------------------
+
+func TestPump_ContextCancellationStopsPump(t *testing.T) {
+	t.Parallel()
+	// The connection blocks forever on the second read.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// No further events: ReadMessage blocks until the ctx is cancelled.
+	)
+	lease := &pumpTestLease{conn: conn}
+	ctx, ctxCancel := context.WithCancel(context.Background())
+	ch, pumpCancel := startPump(ctx, lease, 30*time.Second)
+	defer pumpCancel()
+
+	// Read the first event.
+	evt := <-ch
+	assert.NoError(t, evt.err)
+
+	// Cancel the context.
+	ctxCancel()
+
+	// The pump should exit and the channel should close.
+	events := collectAll(ch)
+	// A single context.Canceled error event may still arrive.
+	for _, e := range events {
+		assert.Error(t, e.err)
+	}
+}
+
+// TestPump_PumpCancelStopsPump verifies the returned cancel func stops the pump.
+func TestPump_PumpCancelStopsPump(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+	evt := <-ch
+	assert.NoError(t, evt.err)
+
+	// Calling pumpCancel should terminate the pump.
+	pumpCancel()
+
+	// The channel should be closed.
+	events := collectAll(ch)
+	for _, e := range events {
+		assert.Error(t, e.err)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Tests: buffering behavior
+// ---------------------------------------------------------------------------
+
+func TestPump_BufferAllowsConcurrentReadWrite(t *testing.T) {
+	t.Parallel()
+	// Produce more events than the buffer size; verify no deadlock occurs.
+	numEvents := openAIWSUpstreamPumpBufferSize + 5
+	connEvents := make([]pumpTestConnEvent, 0, numEvents)
+	for i := 0; i < numEvents-1; i++ {
+		connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")})
+	}
+	connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")})
+
+	conn := newPumpTestConn(connEvents...)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, numEvents)
+	for _, evt := range events {
+		assert.NoError(t, evt.err)
+	}
+}
+
+// TestPump_SlowConsumerDoesNotBlock checks a slow consumer still receives all
+// events without stalling the pump.
+func TestPump_SlowConsumerDoesNotBlock(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	// Simulate a slow consumer.
+	var events []openAIWSUpstreamPumpEvent
+	for evt := range ch {
+		events = append(events, evt)
+		time.Sleep(10 * time.Millisecond) // consume slowly
+	}
+	require.Len(t, events, 4)
+}
+
+// ---------------------------------------------------------------------------
+// Tests: drain-timer mechanism
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainTimerCancelsPump(t *testing.T) {
+	t.Parallel()
+	// Scenario: after the client disconnects, an expiring drain timer cancels the pump.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// The second read blocks (upstream still generating, no event emitted yet).
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+	// Read the first event.
+	evt := <-ch
+	assert.NoError(t, evt.err)
+
+	// Simulated drain timer: cancel the pump after 50ms (production uses 5s).
+	drainTimer := time.AfterFunc(50*time.Millisecond, pumpCancel)
+	defer drainTimer.Stop()
+
+	// Wait for the channel to close.
+	start := time.Now()
+	remaining := collectAll(ch)
+	elapsed := time.Since(start)
+
+	// Should exit around 50ms, not 30 seconds.
+	assert.Less(t, elapsed, 2*time.Second, "排水定时器应在约 50ms 后终止泵")
+
+	// A context.Canceled error event may still arrive.
+	for _, e := range remaining {
+		assert.Error(t, e.err)
+	}
+}
+
+// TestPump_DrainDeadlineCheckInMainLoop reproduces the main loop's drain
+// deadline check and verifies a slow upstream triggers the drain timeout.
+func TestPump_DrainDeadlineCheckInMainLoop(t *testing.T) {
+	t.Parallel()
+	// Replicates the drain-timeout check performed in the relay's main loop.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// Add a delay to simulate a slow upstream.
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 80 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	clientDisconnected := false
+	drainDeadline := time.Time{}
+	var eventsBeforeDrain []openAIWSUpstreamPumpEvent
+	drainTriggered := false
+
+	for evt := range ch {
+		// Check the drain deadline.
+		if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+			pumpCancel()
+			drainTriggered = true
+			break
+		}
+		if evt.err != nil {
+			break
+		}
+		eventsBeforeDrain = append(eventsBeforeDrain, evt)
+
+		// Simulate: client disconnects after the first event; set a very short drain deadline.
+		if !clientDisconnected && len(eventsBeforeDrain) == 1 {
+			clientDisconnected = true
+			drainDeadline = time.Now().Add(30 * time.Millisecond)
+		}
+	}
+
+	// The drain deadline is 30ms and the second event is delayed 80ms, so the timeout must fire.
+	assert.True(t, drainTriggered, "排水超时应被触发")
+	assert.Len(t, eventsBeforeDrain, 1, "排水前应只有 1 个事件")
+}
+
+// ---------------------------------------------------------------------------
+// Test: concurrency with delayed upstream events
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadDelayDoesNotBlockPreviousEvents(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 100 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	// The first event should be available immediately.
+	start := time.Now()
+	evt := <-ch
+	assert.NoError(t, evt.err)
+	assert.Less(t, time.Since(start), 50*time.Millisecond, "第一个事件应立即到达")
+
+	events := collectAll(ch)
+	require.Len(t, events, 2)
+}
+
+// ---------------------------------------------------------------------------
+// Test: empty event stream
+// ---------------------------------------------------------------------------
+
+func TestPump_EmptyStreamContextCancel(t *testing.T) {
+	t.Parallel()
+	// No events at all: the connection blocks and only ctx cancellation frees it.
+	conn := newPumpTestConn() // no events
+	lease := &pumpTestLease{conn: conn}
+	ctx, ctxCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer ctxCancel()
+	ch, pumpCancel := startPump(ctx, lease, 30*time.Second)
+	defer pumpCancel()
+
+	events := collectAll(ch)
+	// After cancellation the pump's select may take the pumpCtx.Done() branch and
+	// exit directly (0 events), or first push the error event (1 event); both are valid.
+	assert.LessOrEqual(t, len(events), 1, "最多应收到 1 个事件")
+	for _, evt := range events {
+		assert.Error(t, evt.err)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Test: non-terminal / non-error events do not stop the pump
+// ---------------------------------------------------------------------------
+
+func TestPump_NonTerminalEventsDoNotStopPump(t *testing.T) {
+	t.Parallel()
+	nonTerminalTypes := []string{
+		"response.created",
+		"response.in_progress",
+		"response.output_text.delta",
+		"response.content_part.added",
+		"response.output_item.added",
+		"response.reasoning_summary_text.delta",
+	}
+	connEvents := make([]pumpTestConnEvent, 0, len(nonTerminalTypes)+1)
+	for _, et := range nonTerminalTypes {
+		connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent(et)})
+	}
+	connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")})
+
+	conn := newPumpTestConn(connEvents...)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, len(nonTerminalTypes)+1, "所有非终端事件 + 终端事件都应被传递")
+}
+
+// ---------------------------------------------------------------------------
+// Test: calling pumpCancel multiple times is safe (idempotent)
+// ---------------------------------------------------------------------------
+
+func TestPump_MultipleCancelSafe(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	events := collectAll(ch)
+	require.Len(t, events, 1)
+
+	// Repeated pumpCancel calls must not panic.
+	assert.NotPanics(t, func() {
+		pumpCancel()
+		pumpCancel()
+		pumpCancel()
+	})
+}
+
+// ---------------------------------------------------------------------------
+// Test: pump + main-loop integration — full relay consumption pattern
+// ---------------------------------------------------------------------------
+
+func TestPump_IntegrationRelayPattern(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_abc123")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEventWithResponseID("response.completed", "resp_abc123")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	// Simulate main-loop processing.
+	var responseID string
+	eventCount := 0
+	tokenEventCount := 0
+	var terminalEventType string
+	clientWriteCount := 0
+
+	for evt := range ch {
+		if evt.err != nil {
+			t.Fatalf("unexpected error: %v", evt.err)
+		}
+		eventType, evtRespID := parseOpenAIWSEventType(evt.message)
+		if responseID == "" && evtRespID != "" {
+			responseID = evtRespID
+		}
+		eventCount++
+		if isOpenAIWSTokenEvent(eventType) {
+			tokenEventCount++
+		}
+		// Simulate writing to the client.
+		clientWriteCount++
+
+		if isOpenAIWSTerminalEvent(eventType) {
+			terminalEventType = eventType
+			break
+		}
+	}
+
+	assert.Equal(t, "resp_abc123", responseID)
+	assert.Equal(t, 5, eventCount)
+	assert.GreaterOrEqual(t, tokenEventCount, 3, "至少 3 个 delta 事件应被计为 token 事件")
+	assert.Equal(t, 5, clientWriteCount)
+	assert.Equal(t, "response.completed", terminalEventType)
+}
+
+// ---------------------------------------------------------------------------
+// Test: pump behavior when the channel is full and the context is cancelled
+// ---------------------------------------------------------------------------
+
+func TestPump_ChannelFullThenCancel(t *testing.T) {
+	t.Parallel()
+	// Produce many events without consuming; verify pumpCancel still stops the pump.
+	numEvents := openAIWSUpstreamPumpBufferSize * 3
+	connEvents := make([]pumpTestConnEvent, numEvents)
+	for i := range connEvents {
+		connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+	}
+	conn := newPumpTestConn(connEvents...)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	// Wait for the buffer to fill.
+	time.Sleep(50 * time.Millisecond)
+
+	// Cancel the pump.
+	pumpCancel()
+
+	// Drain the channel.
+	events := collectAll(ch)
+	// Expect roughly bufferSize to bufferSize+1 events (the pump may be parked in select when full).
+	assert.LessOrEqual(t, len(events), numEvents, "不应收到超过总事件数的事件")
+	assert.GreaterOrEqual(t, len(events), 1, "至少应收到一些事件")
+}
+
+// ---------------------------------------------------------------------------
+// Test: read-timeout mechanism
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadTimeoutTriggersError(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// The second read's delay exceeds the timeout.
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 500 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	// Read timeout of 50ms, far below the 500ms delay.
+	ch, cancel := startPump(context.Background(), lease, 50*time.Millisecond)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 2)
+	assert.NoError(t, events[0].err)
+	assert.Error(t, events[1].err, "第二次读取应超时")
+}
+
+// ---------------------------------------------------------------------------
+// Test: the pump stops after response.done (an alternate terminal event)
+// ---------------------------------------------------------------------------
+
+func TestPump_ResponseDoneStopsPump(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.done")},
+		pumpTestConnEvent{data: pumpTestEvent("should_not_reach")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 3)
+	lastType, _ := parseOpenAIWSEventType(events[2].message)
+	assert.Equal(t, "response.done", lastType)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵在读取到 error event 后不继续读取更多事件
+// ---------------------------------------------------------------------------
+
+func TestPump_ErrorEventStopsReading(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("error")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, // 不应被读取
+ )
+ // 重写以追踪读取次数
+ origEvents := conn.events
+ conn.events = nil
+ var wrappedConn pumpTestConn
+ wrappedConn.closedCh = make(chan struct{})
+ wrappedConn.events = origEvents
+ wrappedLease := &pumpTestLease{conn: &wrappedConn}
+
+ ch, cancel := startPump(context.Background(), wrappedLease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 1, "error 事件后不应再读取更多事件")
+ evtType, _ := parseOpenAIWSEventType(events[0].message)
+ assert.Equal(t, "error", evtType)
+}
+
+// ---------------------------------------------------------------------------
+// Test: event ordering is preserved
+// ---------------------------------------------------------------------------
+
+func TestPump_EventOrderPreserved(t *testing.T) {
+	t.Parallel()
+	expectedTypes := []string{
+		"response.created",
+		"response.in_progress",
+		"response.output_item.added",
+		"response.content_part.added",
+		"response.output_text.delta",
+		"response.output_text.delta",
+		"response.output_text.delta",
+		"response.output_text.done",
+		"response.content_part.done",
+		"response.output_item.done",
+		"response.completed",
+	}
+	connEvents := make([]pumpTestConnEvent, len(expectedTypes))
+	for i, et := range expectedTypes {
+		connEvents[i] = pumpTestConnEvent{data: pumpTestEvent(et)}
+	}
+	conn := newPumpTestConn(connEvents...)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, len(expectedTypes))
+	for i, evt := range events {
+		evtType, _ := parseOpenAIWSEventType(evt.message)
+		assert.Equal(t, expectedTypes[i], evtType, "事件 %d 类型不匹配", i)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Test: invalid JSON messages do not disrupt the pump
+// ---------------------------------------------------------------------------
+
+func TestPump_InvalidJSONDoesNotStopPump(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: []byte("not json")},
+		pumpTestConnEvent{data: []byte("{invalid")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 3, "无效 JSON 不应终止泵")
+}
+
+// ---------------------------------------------------------------------------
+// Test: concurrency safety — simultaneous consume and cancel must not panic
+// ---------------------------------------------------------------------------
+
+func TestPump_ConcurrentConsumeAndCancel(t *testing.T) {
+	t.Parallel()
+	connEvents := make([]pumpTestConnEvent, 100)
+	for i := range connEvents {
+		connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: time.Millisecond}
+	}
+	conn := newPumpTestConn(connEvents...)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	// Consume and cancel at the same time; must not panic.
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		for range ch {
+			// consume
+		}
+	}()
+	go func() {
+		defer wg.Done()
+		time.Sleep(20 * time.Millisecond)
+		pumpCancel()
+	}()
+
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+		// success
+	case <-time.After(5 * time.Second):
+		t.Fatal("超时:并发消费和取消场景死锁")
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Test: race between the drain timer and a normal terminal event
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainTimerRaceWithTerminalEvent(t *testing.T) {
+	t.Parallel()
+	// The terminal event arrives before the drain timer expires.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	// Long drain timer (200ms); the terminal event should arrive after ~10ms.
+	drainTimer := time.AfterFunc(200*time.Millisecond, pumpCancel)
+	defer drainTimer.Stop()
+
+	events := collectAll(ch)
+	// The terminal event should win the race.
+	require.Len(t, events, 2)
+	lastType, _ := parseOpenAIWSEventType(events[1].message)
+	assert.Equal(t, "response.completed", lastType)
+	assert.NoError(t, events[1].err)
+
+	pumpCancel() // cleanup
+}
+
+// ---------------------------------------------------------------------------
+// Test: high-volume throughput (pump must not add abnormal overhead)
+// ---------------------------------------------------------------------------
+
+func TestPump_HighThroughput(t *testing.T) {
+	t.Parallel()
+	numEvents := 1000
+	connEvents := make([]pumpTestConnEvent, numEvents)
+	for i := range connEvents {
+		if i == numEvents-1 {
+			connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.completed")}
+		} else {
+			connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+		}
+	}
+	conn := newPumpTestConn(connEvents...)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	start := time.Now()
+	events := collectAll(ch)
+	elapsed := time.Since(start)
+
+	require.Len(t, events, numEvents)
+	assert.Less(t, elapsed, 2*time.Second, "1000 个事件应在 2 秒内完成")
+	for _, evt := range events {
+		assert.NoError(t, evt.err)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Test: empty (zero-byte) messages do not stop the pump
+// ---------------------------------------------------------------------------
+
+func TestPump_EmptyMessageDoesNotStopPump(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: []byte{}},
+		pumpTestConnEvent{data: nil},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	events := collectAll(ch)
+	require.Len(t, events, 3, "空消息不应终止泵")
+}
+
+// ===========================================================================
+// Supplementary tests for the new message-pump code paths
+// ===========================================================================
+
+// ---------------------------------------------------------------------------
+// Test: pump channel closes without a terminal event (abnormal upstream close)
+// ---------------------------------------------------------------------------
+
+func TestPump_UnexpectedCloseDetectedByConsumer(t *testing.T) {
+	t.Parallel()
+	// Scenario: the upstream emits only non-terminal events then disconnects (ReadMessage returns EOF).
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{err: io.EOF}, // upstream disconnect
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	// Simulate main-loop consumption: check whether a terminal event arrived.
+	receivedTerminal := false
+	var lastErr error
+	for evt := range ch {
+		if evt.err != nil {
+			lastErr = evt.err
+			break
+		}
+		evtType, _ := parseOpenAIWSEventType(evt.message)
+		if isOpenAIWSTerminalEvent(evtType) {
+			receivedTerminal = true
+			break
+		}
+	}
+	// No terminal event but an EOF error: the consumer must treat this as an abnormal upstream close.
+	assert.False(t, receivedTerminal, "不应收到终端事件")
+	assert.ErrorIs(t, lastErr, io.EOF, "应收到 EOF 错误标识上游断连")
+}
+
+// TestPump_ChannelCloseWithoutTerminalOrError covers the edge case where the
+// pump is cancelled externally and its channel closes with neither a terminal
+// nor an error event.
+func TestPump_ChannelCloseWithoutTerminalOrError(t *testing.T) {
+	t.Parallel()
+	// Edge case: the pump is cancelled externally (pumpCancel); the channel closes
+	// with neither a terminal event nor an error event. Cancel mid-stream here.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		// No further events: ReadMessage will block.
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+	// Consume the first two events.
+	evt1 := <-ch
+	assert.NoError(t, evt1.err)
+	evt2 := <-ch
+	assert.NoError(t, evt2.err)
+
+	// Cancel the pump externally.
+	pumpCancel()
+
+	// The for-range should exit, reproducing "pump channel closed without a terminal event".
+	receivedTerminal := false
+	for evt := range ch {
+		if evt.err == nil {
+			evtType, _ := parseOpenAIWSEventType(evt.message)
+			if isOpenAIWSTerminalEvent(evtType) {
+				receivedTerminal = true
+			}
+		}
+	}
+	assert.False(t, receivedTerminal, "泵被取消后不应再收到终端事件")
+}
+
+// ---------------------------------------------------------------------------
+// Tests: lease.MarkBroken scenarios
+// ---------------------------------------------------------------------------
+
+func TestPump_LeaseMarkedBrokenOnUnexpectedClose(t *testing.T) {
+	t.Parallel()
+	// Simulated main loop: when the pump closes without a terminal event, the lease is marked broken.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{err: io.ErrUnexpectedEOF},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	receivedTerminal := false
+	for evt := range ch {
+		if evt.err != nil {
+			// Mirror the error-handling path of the production code.
+			lease.MarkBroken()
+			break
+		}
+		evtType, _ := parseOpenAIWSEventType(evt.message)
+		if isOpenAIWSTerminalEvent(evtType) {
+			receivedTerminal = true
+			break
+		}
+	}
+	// If the for-range exits normally without a terminal event, also mark broken.
+	if !receivedTerminal {
+		lease.MarkBroken()
+	}
+
+	assert.True(t, lease.IsBroken(), "上游异常断连应标记 lease 为 broken")
+}
+
+// TestPump_LeaseNotBrokenOnNormalTerminal verifies a normal terminal event
+// leaves the lease unbroken.
+func TestPump_LeaseNotBrokenOnNormalTerminal(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	for evt := range ch {
+		if evt.err != nil {
+			lease.MarkBroken()
+			break
+		}
+	}
+
+	assert.False(t, lease.IsBroken(), "正常终端事件不应标记 lease 为 broken")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水定时器只创建一次
+// ---------------------------------------------------------------------------
+
+// TestPump_DrainTimerCreatedOnlyOnce verifies that repeated client-disconnect
+// signals create the drain timer only on the first occurrence.
+func TestPump_DrainTimerCreatedOnlyOnce(t *testing.T) {
+	t.Parallel()
+	// Simulate repeated "client disconnected" signals; the drain timer must be
+	// created exactly once.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	drainTimerCount := 0
+	clientDisconnected := false
+	drainDeadline := time.Time{}
+	var drainTimer *time.Timer
+
+	for evt := range ch {
+		if evt.err != nil {
+			break
+		}
+		// Pretend the client disconnect is "detected" after every event.
+		if !clientDisconnected {
+			clientDisconnected = true
+		}
+		// The drain timer is created only on the first disconnect.
+		if clientDisconnected && drainDeadline.IsZero() {
+			drainDeadline = time.Now().Add(500 * time.Millisecond)
+			drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel)
+			drainTimerCount++
+		}
+	}
+	if drainTimer != nil {
+		drainTimer.Stop()
+	}
+
+	assert.Equal(t, 1, drainTimerCount, "排水定时器应只创建一次")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水定时器在正常完成前被 Stop
+// ---------------------------------------------------------------------------
+
+// TestPump_DrainTimerStoppedOnNormalCompletion verifies the drain timer can
+// be stopped before firing when the stream completes normally.
+func TestPump_DrainTimerStoppedOnNormalCompletion(t *testing.T) {
+	t.Parallel()
+	// The terminal event arrives before the drain timer fires; the timer must
+	// be stopped successfully.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	// Arm a long-lived drain timer.
+	drainTimer := time.AfterFunc(10*time.Second, pumpCancel)
+
+	var events []openAIWSUpstreamPumpEvent
+	for evt := range ch {
+		events = append(events, evt)
+	}
+
+	// Stop the drain timer after normal completion (simulates defer drainTimer.Stop()).
+	stopped := drainTimer.Stop()
+	pumpCancel() // cleanup
+
+	assert.True(t, stopped, "定时器应尚未触发,Stop() 返回 true")
+	require.Len(t, events, 2)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水期间读取错误处理
+// ---------------------------------------------------------------------------
+
+// TestPump_ReadErrorDuringDrainTreatedAsDrainTimeout verifies that once the
+// client has disconnected, any read error received while draining is
+// surfaced to the consumer (not only context.DeadlineExceeded).
+func TestPump_ReadErrorDuringDrainTreatedAsDrainTimeout(t *testing.T) {
+	t.Parallel()
+	// New behavior: while the client is disconnected, every read error is
+	// treated as a drain timeout, not just DeadlineExceeded.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{err: io.ErrUnexpectedEOF, delay: 20 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	clientDisconnected := false
+	var drainError error
+
+	for evt := range ch {
+		if !clientDisconnected {
+			// Simulate the client disconnecting after the first event.
+			clientDisconnected = true
+			continue
+		}
+		if evt.err != nil && clientDisconnected {
+			// New code path: a read error arrived during draining.
+			drainError = evt.err
+			break
+		}
+	}
+
+	assert.Error(t, drainError, "排水期间应收到读取错误")
+	assert.ErrorIs(t, drainError, io.ErrUnexpectedEOF)
+}
+
+// TestPump_ReadErrorDuringDrain_EOF verifies that an EOF received while
+// draining counts as exactly one drain error (upstream closed).
+func TestPump_ReadErrorDuringDrain_EOF(t *testing.T) {
+	t.Parallel()
+	// During draining, EOF is equivalent to the upstream closing.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{err: io.EOF},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	clientDisconnected := false
+	drainErrorCount := 0
+
+	for evt := range ch {
+		if evt.err != nil {
+			if clientDisconnected {
+				drainErrorCount++
+			}
+			break
+		}
+		// Simulate the client disconnecting after the first event.
+		if !clientDisconnected {
+			clientDisconnected = true
+		}
+	}
+
+	assert.Equal(t, 1, drainErrorCount, "排水期间 EOF 应被计为一次排水错误")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水截止时间检查——在事件间隙中过期
+// ---------------------------------------------------------------------------
+
+// TestPump_DrainDeadlineExpiresBetweenEvents verifies that the drain deadline
+// expiring in the gap between two upstream events triggers a drain timeout.
+func TestPump_DrainDeadlineExpiresBetweenEvents(t *testing.T) {
+	t.Parallel()
+	// The drain deadline expires between two upstream events.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 60 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 60 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	clientDisconnected := false
+	drainDeadline := time.Time{}
+	drainExpired := false
+	eventsProcessed := 0
+
+	for evt := range ch {
+		// Drain-timeout check before handling the event (mirrors production code).
+		if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+			pumpCancel()
+			drainExpired = true
+			break
+		}
+		if evt.err != nil {
+			break
+		}
+		eventsProcessed++
+
+		// Disconnect after the first event; set a 30ms drain deadline.
+		if !clientDisconnected {
+			clientDisconnected = true
+			drainDeadline = time.Now().Add(30 * time.Millisecond)
+		}
+	}
+
+	// The second event is delayed 60ms > the 30ms drain deadline, so the
+	// drain timeout must fire.
+	assert.True(t, drainExpired, "排水截止时间应在事件间隙中过期")
+	assert.Equal(t, 1, eventsProcessed, "过期前应只处理了 1 个事件")
+}
+
+// TestPump_DrainDeadlineNotYetExpiredAllowsProcessing verifies that a
+// generous drain deadline lets every event be processed normally.
+func TestPump_DrainDeadlineNotYetExpiredAllowsProcessing(t *testing.T) {
+	t.Parallel()
+	// The drain deadline is long enough to allow processing all events.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+	defer pumpCancel()
+
+	clientDisconnected := false
+	drainDeadline := time.Time{}
+	drainExpired := false
+	eventsProcessed := 0
+
+	for evt := range ch {
+		if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+			pumpCancel()
+			drainExpired = true
+			break
+		}
+		if evt.err != nil {
+			break
+		}
+		eventsProcessed++
+		if !clientDisconnected {
+			clientDisconnected = true
+			drainDeadline = time.Now().Add(500 * time.Millisecond) // long enough
+		}
+	}
+
+	assert.False(t, drainExpired, "排水截止时间未过期,不应触发排水超时")
+	assert.Equal(t, 3, eventsProcessed, "所有事件都应被处理")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:goroutine 清理和资源释放
+// ---------------------------------------------------------------------------
+
+// TestPump_DeferPumpCancelAndDrainTimerCleanup exercises the full deferred
+// cleanup path (pumpCancel plus drain-timer Stop) and checks that running the
+// cleanup a second time does not panic.
+func TestPump_DeferPumpCancelAndDrainTimerCleanup(t *testing.T) {
+	t.Parallel()
+	// Mirror the complete defer-based cleanup path of the production code.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)
+	lease := &pumpTestLease{conn: conn}
+	pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize)
+	pumpCtx, pumpCancel := context.WithCancel(context.Background())
+	// Simulate defer pumpCancel().
+	defer pumpCancel()
+
+	// Inline reproduction of the pump goroutine: forward every read result,
+	// stop on read error, terminal event, or "error" event, or cancellation.
+	go func() {
+		defer close(pumpEventCh)
+		for {
+			msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, 5*time.Second)
+			select {
+			case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}:
+			case <-pumpCtx.Done():
+				return
+			}
+			if readErr != nil {
+				return
+			}
+			evtType, _ := parseOpenAIWSEventType(msg)
+			if isOpenAIWSTerminalEvent(evtType) || evtType == "error" {
+				return
+			}
+		}
+	}()
+
+	// Simulated drain timer with deferred cleanup.
+	var drainTimer *time.Timer
+	defer func() {
+		if drainTimer != nil {
+			drainTimer.Stop()
+		}
+	}()
+	drainTimer = time.AfterFunc(10*time.Second, pumpCancel)
+
+	events := collectAll(pumpEventCh)
+	require.Len(t, events, 2)
+
+	// Running the cleanup again after the defers must not panic.
+	assert.NotPanics(t, func() {
+		pumpCancel()
+		if drainTimer != nil {
+			drainTimer.Stop()
+		}
+	})
+}
+
+// ---------------------------------------------------------------------------
+// 测试:连接在泵运行期间被关闭
+// ---------------------------------------------------------------------------
+
+// TestPump_ConnectionClosedDuringPump verifies that closing the connection
+// while the pump is blocked reading eventually surfaces an error via the
+// short read timeout.
+func TestPump_ConnectionClosedDuringPump(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// Subsequent reads block.
+	)
+	lease := &pumpTestLease{conn: conn}
+	// Use a short read timeout: conn.Close() does not unblock a pending
+	// ReadMessage (it waits on <-ctx.Done()).
+	ch, pumpCancel := startPump(context.Background(), lease, 100*time.Millisecond)
+	defer pumpCancel()
+
+	// Read the first event.
+	evt := <-ch
+	assert.NoError(t, evt.err)
+
+	// Close the connection — ReadMessage is still waiting on ctx.Done(), but
+	// the 100ms read timeout yields context.DeadlineExceeded, and the next
+	// ReadMessage call observes the closed state.
+	_ = conn.Close()
+
+	// The pump should detect the closed connection after the read timeout.
+	events := collectAll(ch)
+	require.GreaterOrEqual(t, len(events), 1, "应收到错误")
+	// At least one event must carry an error.
+	hasError := false
+	for _, e := range events {
+		if e.err != nil {
+			hasError = true
+		}
+	}
+	assert.True(t, hasError, "应收到连接关闭或超时错误")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:大消息(KB 级别 JSON)不影响泵传递
+// ---------------------------------------------------------------------------
+
+// TestPump_LargeMessages ensures multi-kilobyte JSON payloads pass through
+// the pump unmodified and in order.
+func TestPump_LargeMessages(t *testing.T) {
+	t.Parallel()
+	// Build a payload of roughly 10KB.
+	filler := make([]byte, 10*1024)
+	for i := 0; i < len(filler); i++ {
+		filler[i] = 'x'
+	}
+	bigMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(filler) + `"}`)
+	lease := &pumpTestLease{conn: newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: bigMsg},
+		pumpTestConnEvent{data: bigMsg},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	got := collectAll(ch)
+	require.Len(t, got, 4)
+	// The oversized frames must arrive intact.
+	for _, idx := range []int{1, 2} {
+		assert.Len(t, got[idx].message, len(bigMsg))
+	}
+}
+
+// ---------------------------------------------------------------------------
+// 测试:多轮泵会话(同一 lease 上依次创建多个泵)
+// ---------------------------------------------------------------------------
+
+// TestPump_SequentialSessions runs two independent pump sessions back to
+// back and checks that they do not interfere with each other.
+func TestPump_SequentialSessions(t *testing.T) {
+	t.Parallel()
+
+	// runSession builds a connection feeding the given event types, runs a
+	// full pump session, and returns the lease plus collected events.
+	runSession := func(types ...string) (*pumpTestLease, []openAIWSUpstreamPumpEvent) {
+		feed := make([]pumpTestConnEvent, 0, len(types))
+		for _, typ := range types {
+			feed = append(feed, pumpTestConnEvent{data: pumpTestEvent(typ)})
+		}
+		lease := &pumpTestLease{conn: newPumpTestConn(feed...)}
+		ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+		events := collectAll(ch)
+		cancel()
+		return lease, events
+	}
+
+	// First round.
+	lease1, events1 := runSession("response.created", "response.completed")
+	require.Len(t, events1, 2)
+
+	// Second round (fresh connection, fresh pump).
+	lease2, events2 := runSession("response.created", "response.output_text.delta", "response.completed")
+	require.Len(t, events2, 3)
+
+	// The two rounds must not affect each other.
+	assert.False(t, lease1.IsBroken())
+	assert.False(t, lease2.IsBroken())
+}
+
+// ---------------------------------------------------------------------------
+// 测试:完整 relay 模式集成——包含客户端断连和排水
+// ---------------------------------------------------------------------------
+
+// TestPump_IntegrationRelayWithClientDisconnectAndDrain simulates a slow
+// upstream where the client disconnects mid-stream and the drain timeout
+// fires before the stream completes.
+func TestPump_IntegrationRelayWithClientDisconnectAndDrain(t *testing.T) {
+	t.Parallel()
+	// Full scenario: slow upstream, client disconnects midway, drain timer
+	// expiry terminates the relay.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain1")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+		// Remaining upstream events are delayed beyond the drain timeout.
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 200 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 200 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	clientDisconnected := false
+	drainDeadline := time.Time{}
+	var drainTimer *time.Timer
+	defer func() {
+		if drainTimer != nil {
+			drainTimer.Stop()
+		}
+	}()
+
+	eventsProcessed := 0
+	drainTriggered := false
+
+	for evt := range ch {
+		// Drain-timeout check.
+		if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+			pumpCancel()
+			drainTriggered = true
+			lease.MarkBroken()
+			break
+		}
+		if evt.err != nil {
+			if clientDisconnected {
+				// Read error while draining (pumpCancel causes context.Canceled).
+				drainTriggered = true
+				lease.MarkBroken()
+			}
+			break
+		}
+		eventsProcessed++
+
+		// Simulate: the client disconnects after the 2nd event.
+		if eventsProcessed == 2 && !clientDisconnected {
+			clientDisconnected = true
+			drainDeadline = time.Now().Add(50 * time.Millisecond) // 50ms drain timeout
+			drainTimer = time.AfterFunc(50*time.Millisecond, pumpCancel)
+		}
+	}
+	// After for-range exits: if the channel closed due to pumpCancel and the
+	// drain deadline has passed, count it as a drain timeout too (the pump's
+	// select may pick pumpCtx.Done() without emitting an error event).
+	if !drainTriggered && clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+		drainTriggered = true
+		lease.MarkBroken()
+	}
+
+	pumpCancel() // final cleanup
+
+	// The drain timeout should fire (50ms drain vs 200ms event delay).
+	assert.True(t, drainTriggered, "排水超时应被触发")
+	assert.GreaterOrEqual(t, eventsProcessed, 2, "至少应处理 2 个事件")
+	assert.LessOrEqual(t, eventsProcessed, 4, "不应处理所有 5 个事件")
+}
+
+// TestPump_IntegrationRelayWithSuccessfulDrain verifies that when the
+// upstream completes quickly after a client disconnect, draining finishes
+// normally before the drain timeout.
+func TestPump_IntegrationRelayWithSuccessfulDrain(t *testing.T) {
+	t.Parallel()
+	// After the client disconnects the upstream completes quickly, ending
+	// before the drain timeout.
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain2")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	clientDisconnected := false
+	drainDeadline := time.Time{}
+	var drainTimer *time.Timer
+	defer func() {
+		if drainTimer != nil {
+			drainTimer.Stop()
+		}
+	}()
+
+	eventsProcessed := 0
+	receivedTerminal := false
+	drainTriggered := false
+
+	for evt := range ch {
+		if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+			pumpCancel()
+			drainTriggered = true
+			break
+		}
+		if evt.err != nil {
+			break
+		}
+		eventsProcessed++
+
+		evtType, _ := parseOpenAIWSEventType(evt.message)
+		if isOpenAIWSTerminalEvent(evtType) {
+			receivedTerminal = true
+			break
+		}
+
+		// The client disconnects after the first event.
+		if eventsProcessed == 1 && !clientDisconnected {
+			clientDisconnected = true
+			drainDeadline = time.Now().Add(500 * time.Millisecond) // generous drain timeout
+			drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel)
+		}
+	}
+
+	pumpCancel()
+
+	assert.False(t, drainTriggered, "排水超时不应触发")
+	assert.True(t, receivedTerminal, "应正常收到终端事件")
+	assert.Equal(t, 4, eventsProcessed, "所有 4 个事件都应被处理")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵事件错误携带部分消息数据
+// ---------------------------------------------------------------------------
+
+// TestPump_ErrorEventWithPartialMessage verifies that a pump event can carry
+// both partial payload bytes and the read error that interrupted it.
+func TestPump_ErrorEventWithPartialMessage(t *testing.T) {
+	t.Parallel()
+	// Simulate the upstream returning partial data together with an error.
+	partial := []byte(`{"type":"response.output_text`)
+	lease := &pumpTestLease{conn: newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: partial, err: io.ErrUnexpectedEOF},
+	)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	got := collectAll(ch)
+	require.Len(t, got, 2)
+	// The second event holds both the message bytes and the error.
+	assert.Equal(t, partial, got[1].message)
+	assert.ErrorIs(t, got[1].err, io.ErrUnexpectedEOF)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:零超时读取
+// ---------------------------------------------------------------------------
+
+// TestPump_ZeroReadTimeout verifies that a near-zero (1ns) read timeout
+// produces context.DeadlineExceeded for any read that takes time.
+func TestPump_ZeroReadTimeout(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// The second read takes time, while the timeout is effectively zero.
+		pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+	)
+	lease := &pumpTestLease{conn: conn}
+	// Use an extremely short timeout (1 nanosecond ~ immediate expiry).
+	ch, cancel := startPump(context.Background(), lease, time.Nanosecond)
+	defer cancel()
+
+	events := collectAll(ch)
+	// At least the first event (no delay) is read; the second very likely
+	// times out.
+	require.GreaterOrEqual(t, len(events), 1)
+	// Look for a timeout error among the collected events.
+	hasTimeout := false
+	for _, evt := range events {
+		if evt.err != nil && errors.Is(evt.err, context.DeadlineExceeded) {
+			hasTimeout = true
+		}
+	}
+	assert.True(t, hasTimeout, "极短超时应产生 DeadlineExceeded 错误")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:并发多个排水定时器取消(防止重复调用 pumpCancel)
+// ---------------------------------------------------------------------------
+
+// TestPump_ConcurrentDrainTimerAndExternalCancel verifies that a drain timer
+// firing pumpCancel concurrently with an external pumpCancel call neither
+// deadlocks nor panics (context cancellation is idempotent).
+func TestPump_ConcurrentDrainTimerAndExternalCancel(t *testing.T) {
+	t.Parallel()
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		// Subsequent reads block.
+	)
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+	// Read the first event.
+	evt := <-ch
+	assert.NoError(t, evt.err)
+
+	// Arm the drain timer and an external cancel at nearly the same time.
+	done := make(chan struct{})
+	drainTimer := time.AfterFunc(30*time.Millisecond, pumpCancel)
+	defer drainTimer.Stop()
+
+	go func() {
+		time.Sleep(20 * time.Millisecond)
+		pumpCancel() // external cancel slightly before the timer
+		close(done)
+	}()
+
+	// Must not deadlock or panic.
+	events := collectAll(ch)
+	<-done
+
+	// The pump has terminated. The connection holds no further data frames,
+	// so any trailing events can only be error events. (The previous version
+	// asserted `assert.Error` only when `e.err != nil`, which was a
+	// tautology and verified nothing.)
+	for _, e := range events {
+		assert.Error(t, e.err, "取消后只应收到错误事件")
+	}
+}
+
+// TestPump_DrainTimerMarkBrokenUnblocksIgnoreContextRead verifies that even
+// when the connection ignores context cancellation, the drain timer marking
+// the lease broken still lets the pump stop promptly.
+func TestPump_DrainTimerMarkBrokenUnblocksIgnoreContextRead(t *testing.T) {
+	t.Parallel()
+
+	conn := newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+	)
+	conn.ignoreCtx = true
+	lease := &pumpTestLease{conn: conn}
+	ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+	defer pumpCancel()
+
+	first := <-ch
+	require.NoError(t, first.err)
+	require.NotEmpty(t, first.message)
+
+	done := make(chan struct{})
+	drainTimer := time.AfterFunc(30*time.Millisecond, func() {
+		lease.MarkBroken()
+		pumpCancel()
+	})
+	defer drainTimer.Stop()
+
+	go func() {
+		_ = collectAll(ch)
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-time.After(500 * time.Millisecond):
+		t.Fatal("pump should stop quickly when drain timer marks lease broken")
+	}
+
+	assert.True(t, lease.IsBroken(), "lease should be marked broken by drain timer")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:快速连续事件(突发模式)
+// ---------------------------------------------------------------------------
+
+func TestPump_BurstEvents(t *testing.T) {
+ t.Parallel()
+ // 50 个事件无延迟突发
+ numBurst := 50
+ connEvents := make([]pumpTestConnEvent, numBurst+1)
+ for i := 0; i < numBurst; i++ {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+ }
+ connEvents[numBurst] = pumpTestConnEvent{data: pumpTestEvent("response.completed")}
+
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, numBurst+1, "突发事件应全部被传递")
+
+ // 验证所有事件无错误
+ for i, evt := range events {
+ assert.NoError(t, evt.err, "事件 %d 不应有错误", i)
+ }
+ lastType, _ := parseOpenAIWSEventType(events[numBurst].message)
+ assert.True(t, isOpenAIWSTerminalEvent(lastType))
+}
+
+// ---------------------------------------------------------------------------
+// 测试:事件类型解析边界情况
+// ---------------------------------------------------------------------------
+
+// TestPump_EventTypeParsingEdgeCases feeds edge-case JSON shapes through the
+// pump and checks that every frame is forwarded without error.
+func TestPump_EventTypeParsingEdgeCases(t *testing.T) {
+	t.Parallel()
+	// A mix of edge-case JSON payloads.
+	frames := [][]byte{
+		[]byte(`{"type": " response.created "}`),                // whitespace around type
+		[]byte(`{"type":"response.output_text.delta"}`),         // compact
+		[]byte(`{"type":"","other":"field"}`),                   // empty type
+		[]byte(`{"no_type_field": true}`),                       // missing type field
+		[]byte(`{"type":"response.completed"}`),                 // terminal
+	}
+	feed := make([]pumpTestConnEvent, 0, len(frames))
+	for _, f := range frames {
+		feed = append(feed, pumpTestConnEvent{data: f})
+	}
+	lease := &pumpTestLease{conn: newPumpTestConn(feed...)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	got := collectAll(ch)
+	require.Len(t, got, 5, "所有格式的事件都应被传递")
+	for _, evt := range got {
+		assert.NoError(t, evt.err)
+	}
+}
+
+// ---------------------------------------------------------------------------
+// 测试:function_call_output 等非标准事件类型不终止泵
+// ---------------------------------------------------------------------------
+
+// TestPump_FunctionCallOutputDoesNotStopPump checks that function-call
+// related event types are relayed as ordinary events and do not terminate
+// the pump early.
+func TestPump_FunctionCallOutputDoesNotStopPump(t *testing.T) {
+	t.Parallel()
+	types := []string{
+		"response.created",
+		"response.function_call_arguments.delta",
+		"response.function_call_arguments.done",
+		"response.output_item.done",
+		"response.completed",
+	}
+	feed := make([]pumpTestConnEvent, 0, len(types))
+	for _, typ := range types {
+		feed = append(feed, pumpTestConnEvent{data: pumpTestEvent(typ)})
+	}
+	lease := &pumpTestLease{conn: newPumpTestConn(feed...)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	require.Len(t, collectAll(ch), 5, "function_call 相关事件不应终止泵")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:pumpCancel 在 channel 已关闭后调用不 panic
+// ---------------------------------------------------------------------------
+
+// TestPump_CancelAfterChannelClosed verifies that cancelling the pump after
+// its channel has already closed is safe, and that a further receive from
+// the closed channel yields the zero event.
+func TestPump_CancelAfterChannelClosed(t *testing.T) {
+	t.Parallel()
+	lease := &pumpTestLease{conn: newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)}
+	ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+	// Wait until the channel is closed.
+	require.Len(t, collectAll(ch), 1)
+
+	// Cancelling after the channel closed must not panic.
+	assert.NotPanics(t, func() { pumpCancel() })
+
+	// A further receive from the closed channel returns the zero value.
+	evt, open := <-ch
+	assert.False(t, open, "channel 应已关闭")
+	assert.Nil(t, evt.message)
+	assert.NoError(t, evt.err)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:混合事件大小(小消息和大消息交替)
+// ---------------------------------------------------------------------------
+
+// TestPump_MixedMessageSizes alternates small and 64KB frames and verifies
+// that the large frames are delivered intact.
+func TestPump_MixedMessageSizes(t *testing.T) {
+	t.Parallel()
+	small := pumpTestEvent("response.output_text.delta")
+	payload := make([]byte, 64*1024) // 64KB
+	for i := 0; i < len(payload); i++ {
+		payload[i] = 'A'
+	}
+	big := []byte(`{"type":"response.output_text.delta","delta":"` + string(payload) + `"}`)
+
+	lease := &pumpTestLease{conn: newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{data: small},
+		pumpTestConnEvent{data: big},
+		pumpTestConnEvent{data: small},
+		pumpTestConnEvent{data: big},
+		pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+	)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	got := collectAll(ch)
+	require.Len(t, got, 6)
+	assert.Len(t, got[2].message, len(big), "大消息应完整传递")
+	assert.Len(t, got[4].message, len(big), "大消息应完整传递")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵在 errOpenAIWSConnClosed 错误后停止
+// ---------------------------------------------------------------------------
+
+// TestPump_ConnClosedErrorStopsPump verifies that the pump forwards
+// errOpenAIWSConnClosed to the consumer and then stops.
+func TestPump_ConnClosedErrorStopsPump(t *testing.T) {
+	t.Parallel()
+	lease := &pumpTestLease{conn: newPumpTestConn(
+		pumpTestConnEvent{data: pumpTestEvent("response.created")},
+		pumpTestConnEvent{err: errOpenAIWSConnClosed},
+	)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	got := collectAll(ch)
+	require.Len(t, got, 2)
+	assert.ErrorIs(t, got[1].err, errOpenAIWSConnClosed)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:同时读取和写入——验证读写解耦
+// ---------------------------------------------------------------------------
+
+// TestPump_ReadWriteDecoupling simulates a slow consumer and verifies that
+// every event is still delivered (upstream reads are buffered by the pump).
+func TestPump_ReadWriteDecoupling(t *testing.T) {
+	t.Parallel()
+	// Simulate delayed client writes via slow channel consumption.
+	const total = 10
+	feed := make([]pumpTestConnEvent, total)
+	for i := range feed {
+		feed[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+	}
+	feed[total-1] = pumpTestConnEvent{data: pumpTestEvent("response.completed")}
+
+	lease := &pumpTestLease{conn: newPumpTestConn(feed...)}
+	ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+	defer cancel()
+
+	// Slow writer: each event takes 5ms to consume.
+	start := time.Now()
+	var got []openAIWSUpstreamPumpEvent
+	for evt := range ch {
+		got = append(got, evt)
+		time.Sleep(5 * time.Millisecond) // simulated write latency
+	}
+	elapsed := time.Since(start)
+
+	require.Len(t, got, total)
+	// With serial read/write the total would be >= total * 5ms = 50ms; with
+	// buffered concurrency the upstream reads can finish earlier. Here we
+	// only verify that every event was delivered.
+	t.Logf("处理 %d 个事件耗时: %v (慢消费模式)", total, elapsed)
+}
diff --git a/backend/internal/service/openai_ws_v2/caddy_adapter.go b/backend/internal/service/openai_ws_v2/caddy_adapter.go
new file mode 100644
index 000000000..1fecc231d
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2/caddy_adapter.go
@@ -0,0 +1,24 @@
+package openai_ws_v2
+
+import (
+ "context"
+)
+
+// runCaddyStyleRelay follows the bidirectional-tunnel idea of Caddy's
+// reverseproxy: once the connection is established, both directions are
+// copied concurrently, and either direction exiting triggers a convergent
+// shutdown. Currently this is a thin delegation to Relay.
+//
+// Reference:
+//   - Project: caddyserver/caddy (Apache-2.0)
+//   - Commit: f283062d37c50627d53ca682ebae2ce219b35515
+//   - Files:
+//   - modules/caddyhttp/reverseproxy/streaming.go
+//   - modules/caddyhttp/reverseproxy/reverseproxy.go
+func runCaddyStyleRelay(
+	ctx context.Context,
+	clientConn FrameConn,
+	upstreamConn FrameConn,
+	firstClientMessage []byte,
+	options RelayOptions,
+) (RelayResult, *RelayExit) {
+	return Relay(ctx, clientConn, upstreamConn, firstClientMessage, options)
+}
diff --git a/backend/internal/service/openai_ws_v2/entry.go b/backend/internal/service/openai_ws_v2/entry.go
new file mode 100644
index 000000000..176298fe9
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2/entry.go
@@ -0,0 +1,23 @@
+package openai_ws_v2
+
+import "context"
+
+// EntryInput bundles the entry parameters for the passthrough v2 data plane.
+type EntryInput struct {
+	Ctx                context.Context
+	ClientConn         FrameConn
+	UpstreamConn       FrameConn
+	FirstClientMessage []byte
+	Options            RelayOptions
+}
+
+// RunEntry is the single public entry point of the openai_ws_v2 package.
+// It forwards its input unchanged to the Caddy-style relay implementation.
+func RunEntry(input EntryInput) (RelayResult, *RelayExit) {
+	return runCaddyStyleRelay(
+		input.Ctx,
+		input.ClientConn,
+		input.UpstreamConn,
+		input.FirstClientMessage,
+		input.Options,
+	)
+}
diff --git a/backend/internal/service/openai_ws_v2/metrics.go b/backend/internal/service/openai_ws_v2/metrics.go
new file mode 100644
index 000000000..3708befdb
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2/metrics.go
@@ -0,0 +1,29 @@
+package openai_ws_v2
+
+import (
+ "sync/atomic"
+)
+
+// MetricsSnapshot is a lightweight snapshot of runtime metrics for the
+// OpenAI WS v2 passthrough path.
+type MetricsSnapshot struct {
+	SemanticMutationTotal  int64 `json:"semantic_mutation_total"`
+	UsageParseFailureTotal int64 `json:"usage_parse_failure_total"`
+}
+
+var (
+	// The passthrough path performs no semantic rewrites by default, so this
+	// counter is expected to stay 0 (reserved for future defensive checks).
+	passthroughSemanticMutationTotal atomic.Int64
+	// Incremented by recordUsageParseFailure whenever usage data fails to parse.
+	passthroughUsageParseFailureTotal atomic.Int64
+)
+
+// recordUsageParseFailure increments the usage-parse-failure counter.
+func recordUsageParseFailure() {
+	passthroughUsageParseFailureTotal.Add(1)
+}
+
+// SnapshotMetrics returns the current passthrough metrics snapshot.
+func SnapshotMetrics() MetricsSnapshot {
+	return MetricsSnapshot{
+		SemanticMutationTotal:  passthroughSemanticMutationTotal.Load(),
+		UsageParseFailureTotal: passthroughUsageParseFailureTotal.Load(),
+	}
+}
diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay.go b/backend/internal/service/openai_ws_v2/passthrough_relay.go
new file mode 100644
index 000000000..d147a0e18
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2/passthrough_relay.go
@@ -0,0 +1,869 @@
+package openai_ws_v2
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ coderws "github.com/coder/websocket"
+ "github.com/tidwall/gjson"
+)
+
// FrameConn is the minimal websocket-frame transport the relay needs from
// both the client-side and upstream-side connections.
type FrameConn interface {
	// ReadFrame blocks until the next frame arrives or ctx/connection fails.
	ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error)
	// WriteFrame sends one frame of the given message type.
	WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error
	// Close tears down the underlying connection.
	Close() error
}
+
// Usage accumulates token counts observed in upstream terminal response events.
type Usage struct {
	InputTokens  int
	OutputTokens int
	// CacheCreationInputTokens is never written by the parser in this file;
	// presumably reserved for other billing paths — TODO confirm.
	CacheCreationInputTokens int
	// CacheReadInputTokens mirrors usage.input_tokens_details.cached_tokens.
	CacheReadInputTokens int
}

// RelayResult summarizes one whole relay session for accounting/diagnostics.
type RelayResult struct {
	RequestModel            string        // from options or the first client frame's "model" field
	Usage                   Usage         // accumulated across all turns of the session
	RequestID               string        // last observed upstream response id
	TerminalEventType       string        // last terminal event type, e.g. "response.completed"
	FirstTokenMs            *int          // first-token latency of the last turn, if observed
	Duration                time.Duration // wall time of the whole relay
	ClientToUpstreamFrames  int64
	UpstreamToClientFrames  int64
	DroppedDownstreamFrames int64 // frames read from upstream after the client left
}

// RelayTurnResult describes a single completed response turn, as delivered to
// RelayOptions.OnTurnComplete.
type RelayTurnResult struct {
	RequestModel      string
	Usage             Usage // usage parsed from this turn's terminal event only
	RequestID         string
	TerminalEventType string
	Duration          time.Duration
	FirstTokenMs      *int
}

// RelayExit describes why the relay ended when the exit was not fully graceful.
type RelayExit struct {
	Stage           string // e.g. "read_client", "write_upstream", "idle_timeout"
	Err             error
	WroteDownstream bool // whether any frame reached the client before exit
}
+
// RelayOptions tunes timeouts, callbacks, and clock injection for Relay.
type RelayOptions struct {
	WriteTimeout         time.Duration       // per-frame write timeout; defaults to 2m unless disabled
	IdleTimeout          time.Duration       // watchdog inactivity limit; <=0 disables the watchdog
	UpstreamDrainTimeout time.Duration       // post-client-disconnect drain window; defaults to 1200ms
	FirstMessageType     coderws.MessageType // type of the pre-read first frame; defaults to text
	InitialRequestModel  string              // overrides model extraction from the first frame
	DisableWriteTimeout  bool
	OnUsageParseFailure  func(eventType string, usageRaw string) // invoked on malformed usage payloads
	OnTurnComplete       func(turn RelayTurnResult)              // invoked per terminal event with a response id
	OnTrace              func(event RelayTraceEvent)             // optional structured tracing hook
	Now                  func() time.Time                        // test clock; defaults to time.Now
}

// RelayTraceEvent is a structured trace record emitted through RelayOptions.OnTrace.
type RelayTraceEvent struct {
	Stage           string // relay lifecycle point, e.g. "relay_start", "first_exit"
	Direction       string // "client_to_upstream", "upstream_to_client", or "watchdog"
	MessageType     string // "text"/"binary" when a frame is involved
	PayloadBytes    int
	Graceful        bool
	WroteDownstream bool
	Error           string
}
+
// relayState carries observation state for a relay session. In this file it is
// only mutated from the upstream→client goroutine, so no locking is used.
type relayState struct {
	usage             Usage            // running token totals across turns
	requestModel      string           // model associated with the session
	lastResponseID    string           // response id of the most recent terminal event
	terminalEventType string           // type of the most recent terminal event
	firstTokenMs      *int             // first-token latency of the most recent turn
	currentTurnStart  time.Time        // start of the in-flight turn; zero when idle
	currentTurnToken  *int             // first-token latency of the in-flight turn
	turnTimingByID    map[string]*relayTurnTiming // per-response-id timing, lazily created
}

// relayExitSignal is what each relay goroutine sends exactly once on exitCh.
type relayExitSignal struct {
	stage           string
	err             error
	graceful        bool // true when the error looks like a normal disconnect
	wroteDownstream bool
}

// observedUpstreamEvent is the parsed view of one upstream text frame.
type observedUpstreamEvent struct {
	terminal   bool
	eventType  string
	responseID string
	usage      Usage
	duration   time.Duration
	firstToken *int
}

// relayTurnTiming tracks timing for one response id.
type relayTurnTiming struct {
	startAt      time.Time
	firstTokenMs *int
}

// ErrRelayIdleTimeout indicates relay inactivity timeout in passthrough mode.
var ErrRelayIdleTimeout = errors.New("openai ws v2 passthrough relay idle timeout")
+
// Relay runs the bidirectional passthrough loop between a client websocket and
// an upstream websocket. It forwards the pre-read first client message, then
// starts three goroutines — client→upstream pump, upstream→client pump, and an
// idle watchdog — and coordinates their exit signals. The returned RelayResult
// carries observed usage/latency/frame counters; a nil *RelayExit means a
// fully graceful shutdown.
func Relay(
	ctx context.Context,
	clientConn FrameConn,
	upstreamConn FrameConn,
	firstClientMessage []byte,
	options RelayOptions,
) (RelayResult, *RelayExit) {
	initialRequestModel := strings.TrimSpace(options.InitialRequestModel)
	if initialRequestModel == "" {
		// Fall back to the "model" field of the first client frame.
		initialRequestModel = strings.TrimSpace(gjson.GetBytes(firstClientMessage, "model").String())
	}
	result := RelayResult{RequestModel: initialRequestModel}
	if clientConn == nil || upstreamConn == nil {
		return result, &RelayExit{Stage: "relay_init", Err: errors.New("relay connection is nil")}
	}
	if ctx == nil {
		ctx = context.Background()
	}

	nowFn := options.Now
	if nowFn == nil {
		nowFn = time.Now
	}
	writeTimeout := options.WriteTimeout
	if !options.DisableWriteTimeout && writeTimeout <= 0 {
		writeTimeout = 2 * time.Minute // default per-frame write deadline
	}
	useWriteTimeout := !options.DisableWriteTimeout && writeTimeout > 0
	drainTimeout := options.UpstreamDrainTimeout
	if drainTimeout <= 0 {
		drainTimeout = 1200 * time.Millisecond
	}
	firstMessageType := options.FirstMessageType
	if firstMessageType != coderws.MessageBinary {
		firstMessageType = coderws.MessageText
	}
	startAt := nowFn()
	state := &relayState{requestModel: result.RequestModel}
	onTrace := options.OnTrace

	relayCtx, relayCancel := context.WithCancel(ctx)
	defer relayCancel()

	// lastActivity feeds the idle watchdog; both pump goroutines bump it.
	lastActivity := atomic.Int64{}
	lastActivity.Store(nowFn().UnixNano())
	markActivity := func() {
		lastActivity.Store(nowFn().UnixNano())
	}

	writeUpstream := func(msgType coderws.MessageType, payload []byte) error {
		if !useWriteTimeout {
			return upstreamConn.WriteFrame(relayCtx, msgType, payload)
		}
		writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
		defer cancel()
		return upstreamConn.WriteFrame(writeCtx, msgType, payload)
	}
	writeClient := func(msgType coderws.MessageType, payload []byte) error {
		if !useWriteTimeout {
			return clientConn.WriteFrame(relayCtx, msgType, payload)
		}
		writeCtx, cancel := context.WithTimeout(relayCtx, writeTimeout)
		defer cancel()
		return clientConn.WriteFrame(writeCtx, msgType, payload)
	}

	clientToUpstreamFrames := &atomic.Int64{}
	upstreamToClientFrames := &atomic.Int64{}
	droppedDownstreamFrames := &atomic.Int64{}
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:        "relay_start",
		PayloadBytes: len(firstClientMessage),
		MessageType:  relayMessageTypeString(firstMessageType),
	})

	// The first client frame was consumed by the caller during handshake, so
	// forward it explicitly before the pump goroutines take over.
	if err := writeUpstream(firstMessageType, firstClientMessage); err != nil {
		result.Duration = nowFn().Sub(startAt)
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:        "write_first_message_failed",
			Direction:    "client_to_upstream",
			MessageType:  relayMessageTypeString(firstMessageType),
			PayloadBytes: len(firstClientMessage),
			Error:        err.Error(),
		})
		return result, &RelayExit{Stage: "write_upstream", Err: err}
	}
	clientToUpstreamFrames.Add(1)
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:        "write_first_message_ok",
		Direction:    "client_to_upstream",
		MessageType:  relayMessageTypeString(firstMessageType),
		PayloadBytes: len(firstClientMessage),
	})
	markActivity()

	// Buffer of 3: each of the three goroutines sends at most one signal, so
	// none of them can block on exit.
	exitCh := make(chan relayExitSignal, 3)
	dropDownstreamWrites := atomic.Bool{}
	go runClientToUpstream(relayCtx, clientConn, writeUpstream, markActivity, clientToUpstreamFrames, onTrace, exitCh)
	go runUpstreamToClient(
		relayCtx,
		upstreamConn,
		writeClient,
		startAt,
		nowFn,
		state,
		options.OnUsageParseFailure,
		options.OnTurnComplete,
		&dropDownstreamWrites,
		upstreamToClientFrames,
		droppedDownstreamFrames,
		markActivity,
		onTrace,
		exitCh,
	)
	go runIdleWatchdog(relayCtx, nowFn, options.IdleTimeout, &lastActivity, onTrace, exitCh)

	firstExit := <-exitCh
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:           "first_exit",
		Direction:       relayDirectionFromStage(firstExit.stage),
		Graceful:        firstExit.graceful,
		WroteDownstream: firstExit.wroteDownstream,
		Error:           relayErrorString(firstExit.err),
	})
	combinedWroteDownstream := firstExit.wroteDownstream
	secondExit := relayExitSignal{graceful: true}
	hasSecondExit := false

	// After the client disconnects, keep reading upstream for a short
	// best-effort window to capture late usage/terminal events for billing.
	upstreamClosed := false
	closeUpstream := func() {
		if upstreamClosed {
			return
		}
		upstreamClosed = true
		_ = upstreamConn.Close()
	}
	if firstExit.stage == "read_client" && firstExit.graceful {
		// Graceful client exit: stop writing downstream but let the upstream
		// pump run until drainTimeout so terminal events are still observed.
		dropDownstreamWrites.Store(true)
		secondExit, hasSecondExit = waitRelayExit(exitCh, drainTimeout)
	} else {
		relayCancel()
		closeUpstream()
		secondExit, hasSecondExit = waitRelayExit(exitCh, 200*time.Millisecond)
	}
	if hasSecondExit {
		combinedWroteDownstream = combinedWroteDownstream || secondExit.wroteDownstream
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "second_exit",
			Direction:       relayDirectionFromStage(secondExit.stage),
			Graceful:        secondExit.graceful,
			WroteDownstream: secondExit.wroteDownstream,
			Error:           relayErrorString(secondExit.err),
		})
	}

	relayCancel()
	closeUpstream()

	enrichResult(&result, state, nowFn().Sub(startAt))
	result.ClientToUpstreamFrames = clientToUpstreamFrames.Load()
	result.UpstreamToClientFrames = upstreamToClientFrames.Load()
	result.DroppedDownstreamFrames = droppedDownstreamFrames.Load()
	if firstExit.stage == "read_client" && firstExit.graceful {
		// A client disconnect is always reported as a non-nil exit so callers
		// can distinguish it from an upstream-driven completion.
		stage := "client_disconnected"
		exitErr := firstExit.err
		if hasSecondExit && !secondExit.graceful {
			stage = secondExit.stage
			exitErr = secondExit.err
		}
		if exitErr == nil {
			exitErr = io.EOF
		}
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(exitErr),
		})
		return result, &RelayExit{
			Stage:           stage,
			Err:             exitErr,
			WroteDownstream: combinedWroteDownstream,
		}
	}
	if firstExit.graceful && (!hasSecondExit || secondExit.graceful) {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_complete",
			Graceful:        true,
			WroteDownstream: combinedWroteDownstream,
		})
		_ = clientConn.Close()
		return result, nil
	}
	if !firstExit.graceful {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(firstExit.stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(firstExit.err),
		})
		return result, &RelayExit{
			Stage:           firstExit.stage,
			Err:             firstExit.err,
			WroteDownstream: combinedWroteDownstream,
		}
	}
	if hasSecondExit && !secondExit.graceful {
		emitRelayTrace(onTrace, RelayTraceEvent{
			Stage:           "relay_exit",
			Direction:       relayDirectionFromStage(secondExit.stage),
			Graceful:        false,
			WroteDownstream: combinedWroteDownstream,
			Error:           relayErrorString(secondExit.err),
		})
		return result, &RelayExit{
			Stage:           secondExit.stage,
			Err:             secondExit.err,
			WroteDownstream: combinedWroteDownstream,
		}
	}
	emitRelayTrace(onTrace, RelayTraceEvent{
		Stage:           "relay_complete",
		Graceful:        true,
		WroteDownstream: combinedWroteDownstream,
	})
	_ = clientConn.Close()
	return result, nil
}
+
+func runClientToUpstream(
+ ctx context.Context,
+ clientConn FrameConn,
+ writeUpstream func(msgType coderws.MessageType, payload []byte) error,
+ markActivity func(),
+ forwardedFrames *atomic.Int64,
+ onTrace func(event RelayTraceEvent),
+ exitCh chan<- relayExitSignal,
+) {
+ for {
+ msgType, payload, err := clientConn.ReadFrame(ctx)
+ if err != nil {
+ graceful := isDisconnectError(err)
+ emitRelayTrace(onTrace, RelayTraceEvent{
+ Stage: "read_client_failed",
+ Direction: "client_to_upstream",
+ Error: err.Error(),
+ Graceful: graceful,
+ })
+ exitCh <- relayExitSignal{stage: "read_client", err: err, graceful: graceful}
+ return
+ }
+ if err := writeUpstream(msgType, payload); err != nil {
+ emitRelayTrace(onTrace, RelayTraceEvent{
+ Stage: "write_upstream_failed",
+ Direction: "client_to_upstream",
+ MessageType: relayMessageTypeString(msgType),
+ PayloadBytes: len(payload),
+ Error: err.Error(),
+ })
+ exitCh <- relayExitSignal{stage: "write_upstream", err: err}
+ return
+ }
+ if forwardedFrames != nil {
+ forwardedFrames.Add(1)
+ }
+ markActivity()
+ }
+}
+
// runUpstreamToClient pumps frames from upstream to the client, observing
// text frames for usage/turn accounting along the way. When
// dropDownstreamWrites is set (client already gone), frames are still read
// and observed for billing but not forwarded; the loop ends at the next
// terminal event. Sends exactly one signal on exitCh.
func runUpstreamToClient(
	ctx context.Context,
	upstreamConn FrameConn,
	writeClient func(msgType coderws.MessageType, payload []byte) error,
	startAt time.Time,
	nowFn func() time.Time,
	state *relayState,
	onUsageParseFailure func(eventType string, usageRaw string),
	onTurnComplete func(turn RelayTurnResult),
	dropDownstreamWrites *atomic.Bool,
	forwardedFrames *atomic.Int64,
	droppedFrames *atomic.Int64,
	markActivity func(),
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	wroteDownstream := false
	droppedSinceDisconnect := int64(0)
	for {
		msgType, payload, err := upstreamConn.ReadFrame(ctx)
		if err != nil {
			graceful := isDisconnectError(err)
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:           "read_upstream_failed",
				Direction:       "upstream_to_client",
				Error:           err.Error(),
				Graceful:        graceful,
				WroteDownstream: wroteDownstream,
			})
			exitCh <- relayExitSignal{
				stage:           "read_upstream",
				err:             err,
				graceful:        graceful,
				wroteDownstream: wroteDownstream,
			}
			return
		}
		observedEvent := observedUpstreamEvent{}
		switch msgType {
		case coderws.MessageText:
			observedEvent = observeUpstreamMessage(state, payload, nowFn, onUsageParseFailure)
		case coderws.MessageBinary:
			// Binary frames are passed through verbatim and skip the JSON
			// observation path (avoids pointless parse overhead).
		}
		emitTurnComplete(onTurnComplete, state, observedEvent)
		if dropDownstreamWrites != nil && dropDownstreamWrites.Load() {
			// Drain mode: account for the frame but never write it downstream.
			droppedSinceDisconnect++
			if droppedFrames != nil {
				droppedFrames.Add(1)
			}
			// Drop traces are throttled so a chatty upstream cannot flood tracing.
			if shouldTraceDroppedDownstreamFrame(droppedSinceDisconnect, observedEvent.terminal) {
				emitRelayTrace(onTrace, RelayTraceEvent{
					Stage:           "drop_downstream_frame",
					Direction:       "upstream_to_client",
					MessageType:     relayMessageTypeString(msgType),
					PayloadBytes:    len(payload),
					WroteDownstream: wroteDownstream,
				})
			}
			if observedEvent.terminal {
				exitCh <- relayExitSignal{
					stage:           "drain_terminal",
					graceful:        true,
					wroteDownstream: wroteDownstream,
				}
				return
			}
			markActivity()
			continue
		}
		if err := writeClient(msgType, payload); err != nil {
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:           "write_client_failed",
				Direction:       "upstream_to_client",
				MessageType:     relayMessageTypeString(msgType),
				PayloadBytes:    len(payload),
				WroteDownstream: wroteDownstream,
				Error:           err.Error(),
			})
			exitCh <- relayExitSignal{stage: "write_client", err: err, wroteDownstream: wroteDownstream}
			return
		}
		wroteDownstream = true
		if forwardedFrames != nil {
			forwardedFrames.Add(1)
		}
		markActivity()
	}
}
+
// runIdleWatchdog periodically checks lastActivity and signals idle_timeout
// on exitCh when no frame has moved in either direction for idleTimeout.
// A non-positive idleTimeout disables the watchdog entirely. Sends at most
// one signal, then returns.
func runIdleWatchdog(
	ctx context.Context,
	nowFn func() time.Time,
	idleTimeout time.Duration,
	lastActivity *atomic.Int64,
	onTrace func(event RelayTraceEvent),
	exitCh chan<- relayExitSignal,
) {
	if idleTimeout <= 0 {
		return
	}
	// Poll at a quarter of the timeout, clamped to [1s, 5s].
	checkInterval := minDuration(idleTimeout/4, 5*time.Second)
	if checkInterval < time.Second {
		checkInterval = time.Second
	}
	ticker := time.NewTicker(checkInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			last := time.Unix(0, lastActivity.Load())
			if nowFn().Sub(last) < idleTimeout {
				continue
			}
			emitRelayTrace(onTrace, RelayTraceEvent{
				Stage:     "idle_timeout_triggered",
				Direction: "watchdog",
				Error:     ErrRelayIdleTimeout.Error(),
			})
			exitCh <- relayExitSignal{stage: "idle_timeout", err: ErrRelayIdleTimeout}
			return
		}
	}
}
+
+func emitRelayTrace(onTrace func(event RelayTraceEvent), event RelayTraceEvent) {
+ if onTrace == nil {
+ return
+ }
+ onTrace(event)
+}
+
+func relayMessageTypeString(msgType coderws.MessageType) string {
+ switch msgType {
+ case coderws.MessageText:
+ return "text"
+ case coderws.MessageBinary:
+ return "binary"
+ default:
+ return "unknown(" + strconv.Itoa(int(msgType)) + ")"
+ }
+}
+
// relayDirectionFromStage maps an exit-signal stage name onto the data-flow
// direction used in trace events; unknown stages map to the empty string.
func relayDirectionFromStage(stage string) string {
	switch stage {
	case "read_client":
		return "client_to_upstream"
	case "write_upstream":
		return "client_to_upstream"
	case "read_upstream":
		return "upstream_to_client"
	case "write_client":
		return "upstream_to_client"
	case "drain_terminal":
		return "upstream_to_client"
	case "idle_timeout":
		return "watchdog"
	}
	return ""
}
+
+func relayErrorString(err error) string {
+ if err == nil {
+ return ""
+ }
+ return err.Error()
+}
+
// observeUpstreamMessage inspects one upstream text frame and updates relay
// state: it extracts the event type and response id, stamps first-token
// latency for the current turn, accumulates usage, and — for terminal events —
// finalizes per-turn timing. The returned observedUpstreamEvent feeds
// emitTurnComplete and the drain logic.
func observeUpstreamMessage(
	state *relayState,
	message []byte,
	nowFn func() time.Time,
	onUsageParseFailure func(eventType string, usageRaw string),
) observedUpstreamEvent {
	if state == nil || len(message) == 0 {
		return observedUpstreamEvent{}
	}
	// One multi-path lookup: type, then response id candidates in priority order.
	values := gjson.GetManyBytes(message, "type", "response.id", "response_id", "id")
	eventType := strings.TrimSpace(values[0].String())
	if eventType == "" {
		// Not an event-shaped payload; nothing to observe.
		return observedUpstreamEvent{}
	}
	responseID := strings.TrimSpace(values[1].String())
	if responseID == "" {
		responseID = strings.TrimSpace(values[2].String())
	}
	// Only terminal events fall back to the top-level id, to avoid treating an
	// event_id as a response_id and associating it with a turn.
	if responseID == "" && isTerminalEvent(eventType) {
		responseID = strings.TrimSpace(values[3].String())
	}
	now := nowFn()
	if state.currentTurnStart.IsZero() {
		// First event after the previous turn finished starts a new turn.
		state.currentTurnStart = now
		state.currentTurnToken = nil
	}

	if state.currentTurnToken == nil && isTokenEvent(eventType) {
		ms := int(now.Sub(state.currentTurnStart).Milliseconds())
		if ms >= 0 {
			state.currentTurnToken = &ms
		}
	}
	parsedUsage := parseUsageAndAccumulate(state, message, eventType, onUsageParseFailure)
	observed := observedUpstreamEvent{
		eventType:  eventType,
		responseID: responseID,
		usage:      parsedUsage,
	}
	if responseID != "" {
		// Keep an id-keyed timing entry as well, so interleaved turns can be
		// attributed even if currentTurn* state gets reused.
		turnTiming := openAIWSRelayGetOrInitTurnTiming(state, responseID, now)
		if turnTiming != nil && turnTiming.firstTokenMs == nil && isTokenEvent(eventType) {
			ms := int(now.Sub(turnTiming.startAt).Milliseconds())
			if ms >= 0 {
				turnTiming.firstTokenMs = &ms
			}
		}
	}
	if !isTerminalEvent(eventType) {
		return observed
	}
	observed.terminal = true
	state.terminalEventType = eventType
	if responseID != "" {
		state.lastResponseID = responseID
		if turnTiming, ok := openAIWSRelayDeleteTurnTiming(state, responseID); ok {
			duration := now.Sub(turnTiming.startAt)
			if duration < 0 {
				duration = 0
			}
			observed.duration = duration
			observed.firstToken = openAIWSRelayCloneIntPtr(turnTiming.firstTokenMs)
		}
	}
	// Fall back to the anonymous current-turn measurements when id-keyed
	// timing was unavailable.
	if observed.firstToken == nil {
		observed.firstToken = openAIWSRelayCloneIntPtr(state.currentTurnToken)
	}
	if observed.duration <= 0 && !state.currentTurnStart.IsZero() {
		duration := now.Sub(state.currentTurnStart)
		if duration < 0 {
			duration = 0
		}
		observed.duration = duration
	}
	state.firstTokenMs = openAIWSRelayCloneIntPtr(observed.firstToken)
	// Reset the anonymous turn window for the next response.
	state.currentTurnStart = time.Time{}
	state.currentTurnToken = nil
	return observed
}
+
+func emitTurnComplete(
+ onTurnComplete func(turn RelayTurnResult),
+ state *relayState,
+ observed observedUpstreamEvent,
+) {
+ if onTurnComplete == nil || !observed.terminal {
+ return
+ }
+ responseID := strings.TrimSpace(observed.responseID)
+ if responseID == "" {
+ return
+ }
+ requestModel := ""
+ if state != nil {
+ requestModel = state.requestModel
+ }
+ onTurnComplete(RelayTurnResult{
+ RequestModel: requestModel,
+ Usage: observed.usage,
+ RequestID: responseID,
+ TerminalEventType: observed.eventType,
+ Duration: observed.duration,
+ FirstTokenMs: openAIWSRelayCloneIntPtr(observed.firstToken),
+ })
+}
+
+func openAIWSRelayGetOrInitTurnTiming(state *relayState, responseID string, now time.Time) *relayTurnTiming {
+ if state == nil {
+ return nil
+ }
+ if state.turnTimingByID == nil {
+ state.turnTimingByID = make(map[string]*relayTurnTiming, 8)
+ }
+ timing, ok := state.turnTimingByID[responseID]
+ if !ok || timing == nil || timing.startAt.IsZero() {
+ startAt := now
+ if !state.currentTurnStart.IsZero() {
+ startAt = state.currentTurnStart
+ }
+ timing = &relayTurnTiming{startAt: startAt}
+ state.turnTimingByID[responseID] = timing
+ return timing
+ }
+ return timing
+}
+
+func openAIWSRelayDeleteTurnTiming(state *relayState, responseID string) (relayTurnTiming, bool) {
+ if state == nil || state.turnTimingByID == nil {
+ return relayTurnTiming{}, false
+ }
+ timing, ok := state.turnTimingByID[responseID]
+ if !ok || timing == nil {
+ return relayTurnTiming{}, false
+ }
+ delete(state.turnTimingByID, responseID)
+ return *timing, true
+}
+
// openAIWSRelayCloneIntPtr copies *v into fresh storage so later mutation of
// the source cannot alias the clone; nil stays nil.
func openAIWSRelayCloneIntPtr(v *int) *int {
	var out *int
	if v != nil {
		val := *v
		out = &val
	}
	return out
}
+
// parseUsageAndAccumulate extracts response.usage from a terminal-ish event
// and, when every relevant field parses cleanly, adds it to the session
// totals in state. It returns the per-event usage (zero on any failure).
// Parse failures are counted globally and reported via onParseFailure.
func parseUsageAndAccumulate(
	state *relayState,
	message []byte,
	eventType string,
	onParseFailure func(eventType string, usageRaw string),
) Usage {
	if state == nil || len(message) == 0 || !shouldParseUsage(eventType) {
		return Usage{}
	}
	usageResult := gjson.GetBytes(message, "response.usage")
	if !usageResult.Exists() {
		// Absent usage is legal (not a parse failure).
		return Usage{}
	}
	usageRaw := strings.TrimSpace(usageResult.Raw)
	if usageRaw == "" || !strings.HasPrefix(usageRaw, "{") {
		// Present but not a JSON object — count as a parse failure.
		recordUsageParseFailure()
		if onParseFailure != nil {
			onParseFailure(eventType, usageRaw)
		}
		return Usage{}
	}

	inputResult := usageResult.Get("input_tokens")
	outputResult := usageResult.Get("output_tokens")
	cachedResult := usageResult.Get("input_tokens_details.cached_tokens")

	// input/output are required; cached_tokens is optional but must be
	// numeric when present.
	inputTokens, inputOK := parseUsageIntField(inputResult, true)
	outputTokens, outputOK := parseUsageIntField(outputResult, true)
	cachedTokens, cachedOK := parseUsageIntField(cachedResult, false)
	if !inputOK || !outputOK || !cachedOK {
		recordUsageParseFailure()
		if onParseFailure != nil {
			onParseFailure(eventType, usageRaw)
		}
		// On parse failure, do not accumulate partial fields — avoids a
		// "half-valid" usage state in billing.
		return Usage{}
	}
	parsedUsage := Usage{
		InputTokens:          inputTokens,
		OutputTokens:         outputTokens,
		CacheReadInputTokens: cachedTokens,
	}

	state.usage.InputTokens += parsedUsage.InputTokens
	state.usage.OutputTokens += parsedUsage.OutputTokens
	state.usage.CacheReadInputTokens += parsedUsage.CacheReadInputTokens
	return parsedUsage
}
+
+func parseUsageIntField(value gjson.Result, required bool) (int, bool) {
+ if !value.Exists() {
+ return 0, !required
+ }
+ if value.Type != gjson.Number {
+ return 0, false
+ }
+ return int(value.Int()), true
+}
+
+func enrichResult(result *RelayResult, state *relayState, duration time.Duration) {
+ if result == nil {
+ return
+ }
+ result.Duration = duration
+ if state == nil {
+ return
+ }
+ result.RequestModel = state.requestModel
+ result.Usage = state.usage
+ result.RequestID = state.lastResponseID
+ result.TerminalEventType = state.terminalEventType
+ result.FirstTokenMs = state.firstTokenMs
+}
+
+func isDisconnectError(err error) bool {
+ if err == nil {
+ return false
+ }
+ if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
+ return true
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
+ return true
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ if message == "" {
+ return false
+ }
+ return strings.Contains(message, "failed to read frame header: eof") ||
+ strings.Contains(message, "unexpected eof") ||
+ strings.Contains(message, "use of closed network connection") ||
+ strings.Contains(message, "connection reset by peer") ||
+ strings.Contains(message, "broken pipe")
+}
+
// isTerminalEvent reports whether eventType marks the end of a response turn.
func isTerminalEvent(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed" ||
		eventType == "response.incomplete" ||
		eventType == "response.cancelled" ||
		eventType == "response.canceled"
}
+
// shouldParseUsage reports whether this event type may carry a billable
// response.usage payload worth parsing.
func shouldParseUsage(eventType string) bool {
	return eventType == "response.completed" ||
		eventType == "response.done" ||
		eventType == "response.failed"
}
+
// isTokenEvent reports whether eventType represents streamed output content,
// used to stamp first-token latency. Lifecycle/bookkeeping events are
// explicitly excluded; any "*.delta" event or "response.output*" event counts.
func isTokenEvent(eventType string) bool {
	if eventType == "" {
		return false
	}
	switch eventType {
	case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
		return false
	}
	if strings.Contains(eventType, ".delta") {
		return true
	}
	// A single "response.output" prefix check subsumes the previous separate
	// "response.output_text" check (same prefix family) — behavior unchanged.
	return strings.HasPrefix(eventType, "response.output")
}
+
+func minDuration(a, b time.Duration) time.Duration {
+ if a <= 0 {
+ return b
+ }
+ if b <= 0 {
+ return a
+ }
+ if a < b {
+ return a
+ }
+ return b
+}
+
// shouldTraceDroppedDownstreamFrame throttles drop-frame traces after a
// client disconnect: terminal frames and the first three drops always trace,
// after which only every 128th drop does.
func shouldTraceDroppedDownstreamFrame(dropCount int64, terminal bool) bool {
	switch {
	case terminal:
		return true
	case dropCount <= 3:
		return true
	default:
		return dropCount%128 == 0
	}
}
+
+func waitRelayExit(exitCh <-chan relayExitSignal, timeout time.Duration) (relayExitSignal, bool) {
+ if timeout <= 0 {
+ timeout = 200 * time.Millisecond
+ }
+ timer := time.NewTimer(timeout)
+ defer timer.Stop()
+ select {
+ case sig := <-exitCh:
+ return sig, true
+ case <-timer.C:
+ return relayExitSignal{}, false
+ }
+}
diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go
new file mode 100644
index 000000000..d0b38442e
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2/passthrough_relay_internal_test.go
@@ -0,0 +1,504 @@
+package openai_ws_v2
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ coderws "github.com/coder/websocket"
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
// TestRunEntry_DelegatesRelay verifies that RunEntry wires EntryInput into the
// relay and surfaces the terminal response id in the result.
func TestRunEntry_DelegatesRelay(t *testing.T) {
	t.Parallel()

	clientConn := newPassthroughTestFrameConn(nil, false)
	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
		{
			msgType: coderws.MessageText,
			payload: []byte(`{"type":"response.completed","response":{"id":"resp_entry","usage":{"input_tokens":1,"output_tokens":1}}}`),
		},
	}, true)

	result, relayExit := RunEntry(EntryInput{
		Ctx:                context.Background(),
		ClientConn:         clientConn,
		UpstreamConn:       upstreamConn,
		FirstClientMessage: []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`),
	})
	require.Nil(t, relayExit)
	require.Equal(t, "resp_entry", result.RequestID)
}

// TestRunClientToUpstream_ErrorPaths covers the client→upstream pump's exit
// signaling: graceful client EOF, upstream write failure, and frame counting.
func TestRunClientToUpstream_ErrorPaths(t *testing.T) {
	t.Parallel()

	t.Run("read client eof", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		runClientToUpstream(
			context.Background(),
			newPassthroughTestFrameConn(nil, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			func() {},
			nil,
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "read_client", sig.stage)
		require.True(t, sig.graceful)
	})

	t.Run("write upstream failed", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		runClientToUpstream(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return errors.New("boom") },
			func() {},
			nil,
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "write_upstream", sig.stage)
		require.False(t, sig.graceful)
	})

	t.Run("forwarded counter and trace callback", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		forwarded := &atomic.Int64{}
		traces := make([]RelayTraceEvent, 0, 2)
		runClientToUpstream(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{msgType: coderws.MessageText, payload: []byte(`{"x":1}`)},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			func() {},
			forwarded,
			func(event RelayTraceEvent) {
				traces = append(traces, event)
			},
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "read_client", sig.stage)
		require.Equal(t, int64(1), forwarded.Load())
		require.NotEmpty(t, traces)
	})
}
+
// TestRunUpstreamToClient_ErrorAndDropPaths covers the upstream→client pump:
// graceful upstream EOF, downstream write failure, and drain mode (drop frames
// after client disconnect, exit on the terminal event).
func TestRunUpstreamToClient_ErrorAndDropPaths(t *testing.T) {
	t.Parallel()

	t.Run("read upstream eof", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		drop := &atomic.Bool{}
		drop.Store(false)
		runUpstreamToClient(
			context.Background(),
			newPassthroughTestFrameConn(nil, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			time.Now(),
			time.Now,
			&relayState{},
			nil,
			nil,
			drop,
			nil,
			nil,
			func() {},
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "read_upstream", sig.stage)
		require.True(t, sig.graceful)
	})

	t.Run("write client failed", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		drop := &atomic.Bool{}
		drop.Store(false)
		runUpstreamToClient(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{msgType: coderws.MessageText, payload: []byte(`{"type":"response.output_text.delta","delta":"x"}`)},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return errors.New("write failed") },
			time.Now(),
			time.Now,
			&relayState{},
			nil,
			nil,
			drop,
			nil,
			nil,
			func() {},
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "write_client", sig.stage)
	})

	t.Run("drop downstream and stop on terminal", func(t *testing.T) {
		t.Parallel()

		exitCh := make(chan relayExitSignal, 1)
		drop := &atomic.Bool{}
		drop.Store(true)
		dropped := &atomic.Int64{}
		runUpstreamToClient(
			context.Background(),
			newPassthroughTestFrameConn([]passthroughTestFrame{
				{
					msgType: coderws.MessageText,
					payload: []byte(`{"type":"response.completed","response":{"id":"resp_drop","usage":{"input_tokens":1,"output_tokens":1}}}`),
				},
			}, true),
			func(_ coderws.MessageType, _ []byte) error { return nil },
			time.Now(),
			time.Now,
			&relayState{},
			nil,
			nil,
			drop,
			nil,
			dropped,
			func() {},
			nil,
			exitCh,
		)
		sig := <-exitCh
		require.Equal(t, "drain_terminal", sig.stage)
		require.True(t, sig.graceful)
		require.Equal(t, int64(1), dropped.Load())
	})
}

// TestRunIdleWatchdog_NoTimeoutWhenDisabled verifies a non-positive idle
// timeout disables the watchdog entirely (no signal is ever sent).
func TestRunIdleWatchdog_NoTimeoutWhenDisabled(t *testing.T) {
	t.Parallel()

	exitCh := make(chan relayExitSignal, 1)
	lastActivity := &atomic.Int64{}
	lastActivity.Store(time.Now().UnixNano())
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	go runIdleWatchdog(ctx, time.Now, 0, lastActivity, nil, exitCh)
	select {
	case <-exitCh:
		t.Fatal("unexpected idle timeout signal")
	case <-time.After(200 * time.Millisecond):
	}
}
+
// TestHelperFunctionsCoverage exercises the small pure helpers: message-type
// labels, error stringification, disconnect classification, token-event
// detection, minDuration, waitRelayExit, and usage int-field parsing.
func TestHelperFunctionsCoverage(t *testing.T) {
	t.Parallel()

	require.Equal(t, "text", relayMessageTypeString(coderws.MessageText))
	require.Equal(t, "binary", relayMessageTypeString(coderws.MessageBinary))
	require.Contains(t, relayMessageTypeString(coderws.MessageType(99)), "unknown(")

	require.Equal(t, "", relayErrorString(nil))
	require.Equal(t, "x", relayErrorString(errors.New("x")))

	require.True(t, isDisconnectError(io.EOF))
	require.True(t, isDisconnectError(net.ErrClosed))
	require.True(t, isDisconnectError(context.Canceled))
	require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusGoingAway}))
	require.True(t, isDisconnectError(errors.New("broken pipe")))
	require.False(t, isDisconnectError(errors.New("unrelated")))

	require.True(t, isTokenEvent("response.output_text.delta"))
	require.True(t, isTokenEvent("response.output_audio.delta"))
	require.False(t, isTokenEvent("response.completed"))
	require.False(t, isTokenEvent(""))
	require.False(t, isTokenEvent("response.created"))

	require.Equal(t, 2*time.Second, minDuration(2*time.Second, 5*time.Second))
	require.Equal(t, 2*time.Second, minDuration(5*time.Second, 2*time.Second))
	require.Equal(t, 5*time.Second, minDuration(0, 5*time.Second))
	require.Equal(t, 2*time.Second, minDuration(2*time.Second, 0))

	ch := make(chan relayExitSignal, 1)
	ch <- relayExitSignal{stage: "ok"}
	sig, ok := waitRelayExit(ch, 10*time.Millisecond)
	require.True(t, ok)
	require.Equal(t, "ok", sig.stage)
	ch <- relayExitSignal{stage: "ok2"}
	// A non-positive timeout falls back to the 200ms default.
	sig, ok = waitRelayExit(ch, 0)
	require.True(t, ok)
	require.Equal(t, "ok2", sig.stage)
	_, ok = waitRelayExit(ch, 10*time.Millisecond)
	require.False(t, ok)

	n, ok := parseUsageIntField(gjson.Get(`{"n":3}`, "n"), true)
	require.True(t, ok)
	require.Equal(t, 3, n)
	_, ok = parseUsageIntField(gjson.Get(`{"n":"x"}`, "n"), true)
	require.False(t, ok)
	n, ok = parseUsageIntField(gjson.Result{}, false)
	require.True(t, ok)
	require.Equal(t, 0, n)
	_, ok = parseUsageIntField(gjson.Result{}, true)
	require.False(t, ok)
}

// TestParseUsageAndEnrichCoverage verifies the all-or-nothing usage
// accumulation rules and that enrichResult copies state into RelayResult.
func TestParseUsageAndEnrichCoverage(t *testing.T) {
	t.Parallel()

	state := &relayState{}
	// Non-numeric required field: nothing accumulates.
	parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":"bad"}}}`), "response.completed", nil)
	require.Equal(t, 0, state.usage.InputTokens)

	parseUsageAndAccumulate(
		state,
		[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens":9,"output_tokens":"bad","input_tokens_details":{"cached_tokens":2}}}}`),
		"response.completed",
		nil,
	)
	require.Equal(t, 0, state.usage.InputTokens, "部分字段解析失败时不应累加 usage")
	require.Equal(t, 0, state.usage.OutputTokens)
	require.Equal(t, 0, state.usage.CacheReadInputTokens)

	parseUsageAndAccumulate(
		state,
		[]byte(`{"type":"response.completed","response":{"usage":{"input_tokens_details":{"cached_tokens":2}}}}`),
		"response.completed",
		nil,
	)
	require.Equal(t, 0, state.usage.InputTokens, "必填 usage 字段缺失时不应累加 usage")
	require.Equal(t, 0, state.usage.OutputTokens)
	require.Equal(t, 0, state.usage.CacheReadInputTokens)

	// Fully valid usage accumulates all three counters.
	parseUsageAndAccumulate(state, []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":2,"output_tokens":1,"input_tokens_details":{"cached_tokens":1}}}}`), "response.completed", nil)
	require.Equal(t, 2, state.usage.InputTokens)
	require.Equal(t, 1, state.usage.OutputTokens)
	require.Equal(t, 1, state.usage.CacheReadInputTokens)

	result := &RelayResult{}
	enrichResult(result, state, 5*time.Millisecond)
	require.Equal(t, state.usage.InputTokens, result.Usage.InputTokens)
	require.Equal(t, 5*time.Millisecond, result.Duration)
	// Non-usage-bearing event types are ignored entirely.
	parseUsageAndAccumulate(state, []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":9}}}`), "response.in_progress", nil)
	require.Equal(t, 2, state.usage.InputTokens)
	enrichResult(nil, state, 0)
}
+
+// TestEmitTurnCompleteCoverage covers the guard conditions of emitTurnComplete:
+// the callback fires only for terminal events that carry a response_id, and a
+// nil state yields an empty request model on the emitted turn.
+func TestEmitTurnCompleteCoverage(t *testing.T) {
+	t.Parallel()
+
+	// A non-terminal event must not trigger the callback.
+	called := 0
+	emitTurnComplete(func(turn RelayTurnResult) {
+		called++
+	}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
+		terminal:   false,
+		eventType:  "response.output_text.delta",
+		responseID: "resp_ignored",
+		usage:      Usage{InputTokens: 1},
+	})
+	require.Equal(t, 0, called)
+
+	// A missing response_id must not trigger the callback.
+	emitTurnComplete(func(turn RelayTurnResult) {
+		called++
+	}, &relayState{requestModel: "gpt-5"}, observedUpstreamEvent{
+		terminal:  true,
+		eventType: "response.completed",
+	})
+	require.Equal(t, 0, called)
+
+	// Terminal with a response_id present should fire; with state=nil the
+	// emitted turn's model is the empty string.
+	var got RelayTurnResult
+	emitTurnComplete(func(turn RelayTurnResult) {
+		called++
+		got = turn
+	}, nil, observedUpstreamEvent{
+		terminal:   true,
+		eventType:  "response.completed",
+		responseID: "resp_emit",
+		usage:      Usage{InputTokens: 2, OutputTokens: 3},
+	})
+	require.Equal(t, 1, called)
+	require.Equal(t, "resp_emit", got.RequestID)
+	require.Equal(t, "response.completed", got.TerminalEventType)
+	require.Equal(t, 2, got.Usage.InputTokens)
+	require.Equal(t, 3, got.Usage.OutputTokens)
+	require.Equal(t, "", got.RequestModel)
+}
+
+// TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches verifies that
+// common websocket close statuses and "connection reset" style error messages
+// are classified as disconnects, while an unrelated error is not.
+func TestIsDisconnectErrorCoverage_CloseStatusesAndMessageBranches(t *testing.T) {
+	t.Parallel()
+
+	require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNormalClosure}))
+	require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusNoStatusRcvd}))
+	require.True(t, isDisconnectError(coderws.CloseError{Code: coderws.StatusAbnormalClosure}))
+	require.True(t, isDisconnectError(errors.New("connection reset by peer")))
+	require.False(t, isDisconnectError(errors.New(" ")))
+}
+
+// TestIsTokenEventCoverageBranches pins the event-type classification used for
+// first-token timing: only output-delta-style events count as token events;
+// lifecycle events (in_progress, item.added, done, completed) do not.
+func TestIsTokenEventCoverageBranches(t *testing.T) {
+	t.Parallel()
+
+	require.False(t, isTokenEvent("response.in_progress"))
+	require.False(t, isTokenEvent("response.output_item.added"))
+	require.True(t, isTokenEvent("response.output_audio.delta"))
+	require.True(t, isTokenEvent("response.output"))
+	require.False(t, isTokenEvent("response.done"))
+	require.False(t, isTokenEvent("response.completed"))
+}
+
+// TestRelayTurnTimingHelpersCoverage exercises the per-turn timing map helpers:
+// nil-state no-ops, get-or-init idempotence, and delete of present/absent keys.
+func TestRelayTurnTimingHelpersCoverage(t *testing.T) {
+	t.Parallel()
+
+	now := time.Unix(100, 0)
+	// nil state: both helpers are safe no-ops.
+	require.Nil(t, openAIWSRelayGetOrInitTurnTiming(nil, "resp_nil", now))
+	_, ok := openAIWSRelayDeleteTurnTiming(nil, "resp_nil")
+	require.False(t, ok)
+
+	state := &relayState{}
+	timing := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now)
+	require.NotNil(t, timing)
+	require.Equal(t, now, timing.startAt)
+
+	// Fetching again returns the same timing entry (startAt is not overwritten).
+	timing2 := openAIWSRelayGetOrInitTurnTiming(state, "resp_a", now.Add(5*time.Second))
+	require.NotNil(t, timing2)
+	require.Equal(t, now, timing2.startAt)
+
+	// Deleting an existing key returns the stored timing.
+	deleted, ok := openAIWSRelayDeleteTurnTiming(state, "resp_a")
+	require.True(t, ok)
+	require.Equal(t, now, deleted.startAt)
+
+	// Deleting a missing key reports not-found.
+	_, ok = openAIWSRelayDeleteTurnTiming(state, "resp_a")
+	require.False(t, ok)
+}
+
+// TestObserveUpstreamMessage_ResponseIDFallbackPolicy pins the response-id
+// fallback policy of observeUpstreamMessage: a top-level "id" is only accepted
+// as the response_id on terminal events, never on deltas.
+func TestObserveUpstreamMessage_ResponseIDFallbackPolicy(t *testing.T) {
+	t.Parallel()
+
+	state := &relayState{requestModel: "gpt-5"}
+	startAt := time.Unix(0, 0)
+	now := startAt
+	// Deterministic clock advancing 5ms per call.
+	nowFn := func() time.Time {
+		now = now.Add(5 * time.Millisecond)
+		return now
+	}
+
+	// Non-terminal: only a top-level id is present; the event id must NOT be
+	// treated as the response_id.
+	observed := observeUpstreamMessage(
+		state,
+		[]byte(`{"type":"response.output_text.delta","id":"evt_123","delta":"hi"}`),
+		nowFn,
+		nil,
+	)
+	require.False(t, observed.terminal)
+	require.Equal(t, "", observed.responseID)
+
+	// Terminal: falling back to the top-level id is allowed (compatibility with
+	// rare field variants).
+	observed = observeUpstreamMessage(
+		state,
+		[]byte(`{"type":"response.completed","id":"resp_fallback","response":{"usage":{"input_tokens":1,"output_tokens":1}}}`),
+		nowFn,
+		nil,
+	)
+	require.True(t, observed.terminal)
+	require.Equal(t, "resp_fallback", observed.responseID)
+}
+
+// writeCtxCaptureFrameConn is a FrameConn stub that records whether the context
+// passed to WriteFrame carried a deadline; reads block until ctx is done.
+type writeCtxCaptureFrameConn struct {
+	writeHasDeadline atomic.Bool
+}
+
+// ReadFrame blocks until ctx is cancelled and returns its error.
+func (c *writeCtxCaptureFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	<-ctx.Done()
+	return coderws.MessageText, nil, ctx.Err()
+}
+
+// WriteFrame records whether ctx has a deadline and always succeeds.
+func (c *writeCtxCaptureFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	_, hasDeadline := ctx.Deadline()
+	c.writeHasDeadline.Store(hasDeadline)
+	return nil
+}
+
+// Close is a no-op.
+func (c *writeCtxCaptureFrameConn) Close() error {
+	return nil
+}
+
+// TestRelay_WriteTimeoutDeadlineToggle verifies that DisableWriteTimeout
+// controls whether upstream writes get a context deadline derived from
+// WriteTimeout.
+func TestRelay_WriteTimeoutDeadlineToggle(t *testing.T) {
+	t.Parallel()
+
+	// Run a relay with the toggle set and report whether the upstream write
+	// context carried a deadline.
+	run := func(disable bool) bool {
+		clientConn := newPassthroughTestFrameConn(nil, true)
+		upstreamConn := &writeCtxCaptureFrameConn{}
+		_, _ = Relay(context.Background(), clientConn, upstreamConn, []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`), RelayOptions{
+			WriteTimeout:         30 * time.Second,
+			DisableWriteTimeout:  disable,
+			UpstreamDrainTimeout: 20 * time.Millisecond,
+		})
+		return upstreamConn.writeHasDeadline.Load()
+	}
+
+	require.False(t, run(true))
+	require.True(t, run(false))
+}
+
+// TestRelay_InitialRequestModelOverride verifies that when the first client
+// payload has no "model" field, RelayOptions.InitialRequestModel supplies the
+// request model reported in the result.
+func TestRelay_InitialRequestModelOverride(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_override","usage":{"input_tokens":1,"output_tokens":1}}}`),
+		},
+	}, true)
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, []byte(`{"type":"response.create","input":[]}`), RelayOptions{
+		InitialRequestModel: "gpt-5.3-codex",
+	})
+	require.Nil(t, relayExit)
+	require.Equal(t, "gpt-5.3-codex", result.RequestModel)
+}
+
+// TestShouldTraceDroppedDownstreamFrame pins the sampling policy for tracing
+// dropped downstream frames: early drops and power-of-two-style milestones are
+// traced, intermediate counts are not, and forced mode always traces.
+func TestShouldTraceDroppedDownstreamFrame(t *testing.T) {
+	t.Parallel()
+
+	require.True(t, shouldTraceDroppedDownstreamFrame(1, false))
+	require.True(t, shouldTraceDroppedDownstreamFrame(3, false))
+	require.False(t, shouldTraceDroppedDownstreamFrame(4, false))
+	require.True(t, shouldTraceDroppedDownstreamFrame(128, false))
+	require.True(t, shouldTraceDroppedDownstreamFrame(999, true))
+}
diff --git a/backend/internal/service/openai_ws_v2/passthrough_relay_test.go b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go
new file mode 100644
index 000000000..6503b3d0f
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2/passthrough_relay_test.go
@@ -0,0 +1,804 @@
+package openai_ws_v2
+
+import (
+ "context"
+ "errors"
+ "io"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ coderws "github.com/coder/websocket"
+ "github.com/stretchr/testify/require"
+)
+
+// passthroughTestFrame is a single websocket frame (type + payload) used by the
+// in-memory FrameConn test doubles below.
+type passthroughTestFrame struct {
+	msgType coderws.MessageType
+	payload []byte
+}
+
+// passthroughTestFrameConn is an in-memory FrameConn: reads are served from a
+// pre-filled channel, writes are recorded for later inspection.
+type passthroughTestFrameConn struct {
+	mu      sync.Mutex
+	writes  []passthroughTestFrame
+	readCh  chan passthroughTestFrame
+	once    sync.Once
+}
+
+// delayedReadFrameConn wraps a base FrameConn and delays only the first read.
+type delayedReadFrameConn struct {
+	base       FrameConn
+	firstDelay time.Duration
+	once       sync.Once
+}
+
+// closeSpyFrameConn counts Close calls; reads block until ctx is done.
+type closeSpyFrameConn struct {
+	closeCalls atomic.Int32
+}
+
+// newPassthroughTestFrameConn builds a conn pre-loaded with frames; when
+// autoClose is true the read channel is closed so readers see io.EOF after the
+// last frame.
+func newPassthroughTestFrameConn(frames []passthroughTestFrame, autoClose bool) *passthroughTestFrameConn {
+	c := &passthroughTestFrameConn{
+		readCh: make(chan passthroughTestFrame, len(frames)+1),
+	}
+	for _, frame := range frames {
+		// Copy the payload so callers cannot mutate queued frames.
+		copied := passthroughTestFrame{msgType: frame.msgType, payload: append([]byte(nil), frame.payload...)}
+		c.readCh <- copied
+	}
+	if autoClose {
+		close(c.readCh)
+	}
+	return c
+}
+
+// ReadFrame returns the next queued frame, io.EOF on channel close, or the ctx
+// error on cancellation.
+func (c *passthroughTestFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	select {
+	case <-ctx.Done():
+		return coderws.MessageText, nil, ctx.Err()
+	case frame, ok := <-c.readCh:
+		if !ok {
+			return coderws.MessageText, nil, io.EOF
+		}
+		// Return a copy so the caller owns the bytes.
+		return frame.msgType, append([]byte(nil), frame.payload...), nil
+	}
+}
+
+// WriteFrame records the frame unless ctx is already cancelled.
+func (c *passthroughTestFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	default:
+	}
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	c.writes = append(c.writes, passthroughTestFrame{msgType: msgType, payload: append([]byte(nil), payload...)})
+	return nil
+}
+
+// Close closes the read channel exactly once; the recover guards against a
+// double close when the channel was already closed via autoClose.
+func (c *passthroughTestFrameConn) Close() error {
+	c.once.Do(func() {
+		defer func() { _ = recover() }()
+		close(c.readCh)
+	})
+	return nil
+}
+
+// Writes returns a snapshot copy of all recorded writes.
+func (c *passthroughTestFrameConn) Writes() []passthroughTestFrame {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	out := make([]passthroughTestFrame, len(c.writes))
+	copy(out, c.writes)
+	return out
+}
+
+// ReadFrame delays only the very first read by firstDelay (interruptible via
+// ctx), then delegates to the base conn.
+func (c *delayedReadFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	if c == nil || c.base == nil {
+		return coderws.MessageText, nil, io.EOF
+	}
+	c.once.Do(func() {
+		if c.firstDelay > 0 {
+			timer := time.NewTimer(c.firstDelay)
+			defer timer.Stop()
+			select {
+			case <-ctx.Done():
+			case <-timer.C:
+			}
+		}
+	})
+	return c.base.ReadFrame(ctx)
+}
+
+// WriteFrame delegates to the base conn; a nil receiver/base yields io.EOF.
+func (c *delayedReadFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+	if c == nil || c.base == nil {
+		return io.EOF
+	}
+	return c.base.WriteFrame(ctx, msgType, payload)
+}
+
+// Close delegates to the base conn; nil-safe.
+func (c *delayedReadFrameConn) Close() error {
+	if c == nil || c.base == nil {
+		return nil
+	}
+	return c.base.Close()
+}
+
+// ReadFrame blocks until ctx is done and returns its error.
+func (c *closeSpyFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	<-ctx.Done()
+	return coderws.MessageText, nil, ctx.Err()
+}
+
+// WriteFrame succeeds unless ctx is already cancelled.
+func (c *closeSpyFrameConn) WriteFrame(ctx context.Context, _ coderws.MessageType, _ []byte) error {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	default:
+		return nil
+	}
+}
+
+// Close increments the call counter; nil-safe.
+func (c *closeSpyFrameConn) Close() error {
+	if c != nil {
+		c.closeCalls.Add(1)
+	}
+	return nil
+}
+
+// CloseCalls reports how many times Close was invoked; nil-safe.
+func (c *closeSpyFrameConn) CloseCalls() int32 {
+	if c == nil {
+		return 0
+	}
+	return c.closeCalls.Load()
+}
+
+// TestRelay_BasicRelayAndUsage covers the happy path: one client payload is
+// forwarded upstream, one terminal upstream event is forwarded to the client,
+// and the result carries the parsed usage, model, response id, and frame
+// counters.
+func TestRelay_BasicRelayAndUsage(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"input_text","text":"hello"}]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+	require.Equal(t, "gpt-5.3-codex", result.RequestModel)
+	require.Equal(t, "resp_123", result.RequestID)
+	require.Equal(t, "response.completed", result.TerminalEventType)
+	require.Equal(t, 7, result.Usage.InputTokens)
+	require.Equal(t, 3, result.Usage.OutputTokens)
+	require.Equal(t, 2, result.Usage.CacheReadInputTokens)
+	// No token-delta event was seen, so no first-token latency is recorded.
+	require.Nil(t, result.FirstTokenMs)
+	require.Equal(t, int64(1), result.ClientToUpstreamFrames)
+	require.Equal(t, int64(1), result.UpstreamToClientFrames)
+	require.Equal(t, int64(0), result.DroppedDownstreamFrames)
+
+	upstreamWrites := upstreamConn.Writes()
+	require.Len(t, upstreamWrites, 1)
+	require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
+	require.JSONEq(t, string(firstPayload), string(upstreamWrites[0].payload))
+
+	clientWrites := clientConn.Writes()
+	require.Len(t, clientWrites, 1)
+	require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
+	require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_123","usage":{"input_tokens":7,"output_tokens":3,"input_tokens_details":{"cached_tokens":2}}}}`, string(clientWrites[0].payload))
+}
+
+// TestRelay_FunctionCallOutputBytesPreserved verifies the first payload
+// containing a function_call_output item is forwarded upstream byte-for-byte
+// (require.Equal on raw bytes, not JSONEq).
+func TestRelay_FunctionCallOutputBytesPreserved(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_func","usage":{"input_tokens":1,"output_tokens":1}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[{"type":"function_call_output","call_id":"call_abc123","output":"{\"ok\":true}"}]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+
+	upstreamWrites := upstreamConn.Writes()
+	require.Len(t, upstreamWrites, 1)
+	require.Equal(t, coderws.MessageText, upstreamWrites[0].msgType)
+	// Exact byte equality: the relay must not re-serialize the payload.
+	require.Equal(t, firstPayload, upstreamWrites[0].payload)
+}
+
+// TestRelay_UpstreamDisconnect verifies that an immediate upstream EOF is
+// treated as a graceful disconnect (nil relayExit).
+func TestRelay_UpstreamDisconnect(t *testing.T) {
+	t.Parallel()
+
+	// Upstream closes immediately (EOF); the client sends no extra frames.
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn(nil, true) // immediate close -> EOF
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	// Upstream EOF counts as a disconnect and is marked graceful.
+	require.Nil(t, relayExit, "上游 EOF 应被视为 graceful disconnect")
+	require.Equal(t, "gpt-4o", result.RequestModel)
+}
+
+// TestRelay_ClientDisconnect verifies that a client-side EOF surfaces as a
+// "client_disconnected" relay exit while the result still records the model.
+func TestRelay_ClientDisconnect(t *testing.T) {
+	t.Parallel()
+
+	// Client closes immediately (EOF); upstream reads block until ctx cancels.
+	clientConn := newPassthroughTestFrameConn(nil, true) // immediate close -> EOF
+	upstreamConn := newPassthroughTestFrameConn(nil, false)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.NotNil(t, relayExit, "客户端 EOF 应返回可观测的中断状态")
+	require.Equal(t, "client_disconnected", relayExit.Stage)
+	require.Equal(t, "gpt-4o", result.RequestModel)
+}
+
+// TestRelay_ClientDisconnect_DrainCapturesLateUsage verifies that after a
+// client disconnect the relay keeps draining the upstream (within
+// UpstreamDrainTimeout) so a late terminal event's usage is still captured,
+// while the undeliverable frame is counted as dropped.
+func TestRelay_ClientDisconnect_DrainCapturesLateUsage(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, true)
+	upstreamBase := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_drain","usage":{"input_tokens":6,"output_tokens":4,"input_tokens_details":{"cached_tokens":1}}}}`),
+		},
+	}, true)
+	// The first upstream read is delayed past the client disconnect but inside
+	// the drain window.
+	upstreamConn := &delayedReadFrameConn{
+		base:       upstreamBase,
+		firstDelay: 80 * time.Millisecond,
+	}
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		UpstreamDrainTimeout: 400 * time.Millisecond,
+	})
+	require.NotNil(t, relayExit)
+	require.Equal(t, "client_disconnected", relayExit.Stage)
+	require.Equal(t, "resp_drain", result.RequestID)
+	require.Equal(t, "response.completed", result.TerminalEventType)
+	require.Equal(t, 6, result.Usage.InputTokens)
+	require.Equal(t, 4, result.Usage.OutputTokens)
+	require.Equal(t, 1, result.Usage.CacheReadInputTokens)
+	require.Equal(t, int64(1), result.ClientToUpstreamFrames)
+	// Nothing was deliverable to the disconnected client; the frame is dropped.
+	require.Equal(t, int64(0), result.UpstreamToClientFrames)
+	require.Equal(t, int64(1), result.DroppedDownstreamFrames)
+}
+
+// TestRelay_IdleTimeout verifies that when neither side sends frames the relay
+// exits with stage "idle_timeout" and ErrRelayIdleTimeout (not a context
+// deadline error). A fake clock fast-forwards past the idle window.
+func TestRelay_IdleTimeout(t *testing.T) {
+	t.Parallel()
+
+	// Neither the client nor the upstream sends frames; the idle timeout fires.
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn(nil, false)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	// Fast-forward the clock to accelerate the idle timeout.
+	now := time.Now()
+	callCount := 0
+	nowFn := func() time.Time {
+		callCount++
+		// The first few calls return the real time (initialization phase), then
+		// jump forward.
+		if callCount <= 5 {
+			return now
+		}
+		return now.Add(time.Hour) // jump past the timeout
+	}
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		IdleTimeout: 2 * time.Second,
+		Now:         nowFn,
+	})
+	require.NotNil(t, relayExit, "应因 idle timeout 退出")
+	require.Equal(t, "idle_timeout", relayExit.Stage)
+	require.ErrorIs(t, relayExit.Err, ErrRelayIdleTimeout)
+	require.NotErrorIs(t, relayExit.Err, context.DeadlineExceeded)
+	require.Equal(t, "gpt-4o", result.RequestModel)
+}
+
+// TestRelay_IdleTimeoutDoesNotCloseClientOnError verifies that on the idle
+// timeout error path the relay closes only the upstream connection; the client
+// connection is left for the caller to close with an appropriate close code.
+func TestRelay_IdleTimeoutDoesNotCloseClientOnError(t *testing.T) {
+	t.Parallel()
+
+	clientConn := &closeSpyFrameConn{}
+	upstreamConn := &closeSpyFrameConn{}
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	// Fake clock: a few real-time calls during init, then a jump past the
+	// idle window.
+	now := time.Now()
+	callCount := 0
+	nowFn := func() time.Time {
+		callCount++
+		if callCount <= 5 {
+			return now
+		}
+		return now.Add(time.Hour)
+	}
+
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		IdleTimeout: 2 * time.Second,
+		Now:         nowFn,
+	})
+	require.NotNil(t, relayExit, "应因 idle timeout 退出")
+	require.Equal(t, "idle_timeout", relayExit.Stage)
+	require.ErrorIs(t, relayExit.Err, ErrRelayIdleTimeout)
+	require.NotErrorIs(t, relayExit.Err, context.DeadlineExceeded)
+	require.Zero(t, clientConn.CloseCalls(), "错误路径不应提前关闭客户端连接,交给上层决定 close code")
+	require.GreaterOrEqual(t, upstreamConn.CloseCalls(), int32(1))
+}
+
+// TestRelay_NilConnections verifies that a nil client or upstream connection
+// fails fast during relay initialization with a descriptive error.
+func TestRelay_NilConnections(t *testing.T) {
+	t.Parallel()
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx := context.Background()
+
+	t.Run("nil client conn", func(t *testing.T) {
+		upstreamConn := newPassthroughTestFrameConn(nil, true)
+		_, relayExit := Relay(ctx, nil, upstreamConn, firstPayload, RelayOptions{})
+		require.NotNil(t, relayExit)
+		require.Equal(t, "relay_init", relayExit.Stage)
+		require.Contains(t, relayExit.Err.Error(), "nil")
+	})
+
+	t.Run("nil upstream conn", func(t *testing.T) {
+		clientConn := newPassthroughTestFrameConn(nil, true)
+		_, relayExit := Relay(ctx, clientConn, nil, firstPayload, RelayOptions{})
+		require.NotNil(t, relayExit)
+		require.Equal(t, "relay_init", relayExit.Stage)
+		require.Contains(t, relayExit.Err.Error(), "nil")
+	})
+}
+
+// TestRelay_MultipleUpstreamMessages verifies multi-frame relaying and usage
+// aggregation: delta events set first-token latency and all frames (deltas +
+// terminal) are forwarded to the client.
+func TestRelay_MultipleUpstreamMessages(t *testing.T) {
+	t.Parallel()
+
+	// Upstream sends multiple events (delta + completed) to exercise
+	// multi-frame relaying and usage aggregation.
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.output_text.delta","delta":"Hello"}`),
+		},
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.output_text.delta","delta":" world"}`),
+		},
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_multi","usage":{"input_tokens":10,"output_tokens":5,"input_tokens_details":{"cached_tokens":3}}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[{"type":"input_text","text":"hi"}]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+	require.Equal(t, "resp_multi", result.RequestID)
+	require.Equal(t, "response.completed", result.TerminalEventType)
+	require.Equal(t, 10, result.Usage.InputTokens)
+	require.Equal(t, 5, result.Usage.OutputTokens)
+	require.Equal(t, 3, result.Usage.CacheReadInputTokens)
+	require.NotNil(t, result.FirstTokenMs)
+
+	// Verify all 3 upstream frames were forwarded to the client.
+	clientWrites := clientConn.Writes()
+	require.Len(t, clientWrites, 3)
+}
+
+// TestRelay_OnTurnComplete_PerTerminalEvent verifies that OnTurnComplete fires
+// once per terminal event (completed and failed alike) with per-turn usage,
+// while the final result aggregates usage across all turns.
+func TestRelay_OnTurnComplete_PerTerminalEvent(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":2,"output_tokens":1}}}`),
+		},
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":3,"output_tokens":4}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	turns := make([]RelayTurnResult, 0, 2)
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		OnTurnComplete: func(turn RelayTurnResult) {
+			turns = append(turns, turn)
+		},
+	})
+	require.Nil(t, relayExit)
+	require.Len(t, turns, 2)
+	require.Equal(t, "resp_turn_1", turns[0].RequestID)
+	require.Equal(t, "response.completed", turns[0].TerminalEventType)
+	require.Equal(t, 2, turns[0].Usage.InputTokens)
+	require.Equal(t, 1, turns[0].Usage.OutputTokens)
+	require.Equal(t, "resp_turn_2", turns[1].RequestID)
+	require.Equal(t, "response.failed", turns[1].TerminalEventType)
+	require.Equal(t, 3, turns[1].Usage.InputTokens)
+	require.Equal(t, 4, turns[1].Usage.OutputTokens)
+	// The overall result aggregates both turns: 2+3 input, 1+4 output.
+	require.Equal(t, 5, result.Usage.InputTokens)
+	require.Equal(t, 5, result.Usage.OutputTokens)
+}
+
+// TestRelay_OnTurnComplete_ProvidesTurnMetrics verifies that the per-turn
+// callback carries timing metrics (FirstTokenMs, Duration) when the delta
+// event identifies its response_id; a deterministic clock makes the numbers
+// checkable.
+func TestRelay_OnTurnComplete_ProvidesTurnMetrics(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metric","delta":"hi"}`),
+		},
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_metric","usage":{"input_tokens":2,"output_tokens":1}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	// Deterministic clock: advances 5ms per call; atomic because the relay may
+	// call it from multiple goroutines.
+	base := time.Unix(0, 0)
+	var nowTick atomic.Int64
+	nowFn := func() time.Time {
+		step := nowTick.Add(1)
+		return base.Add(time.Duration(step) * 5 * time.Millisecond)
+	}
+
+	var turn RelayTurnResult
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		Now: nowFn,
+		OnTurnComplete: func(current RelayTurnResult) {
+			turn = current
+		},
+	})
+	require.Nil(t, relayExit)
+	require.Equal(t, "resp_metric", turn.RequestID)
+	require.Equal(t, "response.completed", turn.TerminalEventType)
+	require.NotNil(t, turn.FirstTokenMs)
+	require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
+	require.Greater(t, turn.Duration.Milliseconds(), int64(0))
+	require.NotNil(t, result.FirstTokenMs)
+	require.Greater(t, result.Duration.Milliseconds(), int64(0))
+}
+
+// TestRelay_OnTurnComplete_FirstTokenFromTurnFallbackWhenDeltaHasNoResponseID
+// verifies that when a delta event carries no response_id, the first-token
+// latency still falls back onto the turn reported by the later terminal event,
+// matching the overall result's FirstTokenMs.
+func TestRelay_OnTurnComplete_FirstTokenFromTurnFallbackWhenDeltaHasNoResponseID(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.created"}`),
+		},
+		{
+			// Delta with no response_id: the fallback path under test.
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.output_text.delta","delta":"hello"}`),
+		},
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_fallback_ttft","usage":{"input_tokens":2,"output_tokens":1}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	// Deterministic clock advancing 5ms per call.
+	base := time.Unix(0, 0)
+	var nowTick atomic.Int64
+	nowFn := func() time.Time {
+		step := nowTick.Add(1)
+		return base.Add(time.Duration(step) * 5 * time.Millisecond)
+	}
+
+	var turn RelayTurnResult
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		Now: nowFn,
+		OnTurnComplete: func(current RelayTurnResult) {
+			turn = current
+		},
+	})
+	require.Nil(t, relayExit)
+	require.Equal(t, "resp_fallback_ttft", turn.RequestID)
+	require.NotNil(t, turn.FirstTokenMs)
+	require.NotNil(t, result.FirstTokenMs)
+	require.GreaterOrEqual(t, *turn.FirstTokenMs, 0)
+	require.Equal(t, *turn.FirstTokenMs, *result.FirstTokenMs)
+	require.Greater(t, turn.Duration.Milliseconds(), int64(0))
+	require.Greater(t, turn.Duration.Milliseconds(), int64(*turn.FirstTokenMs))
+}
+
+// TestRelay_BinaryFramePassthrough verifies that a binary frame is relayed to
+// the client unchanged while being excluded from usage parsing.
+func TestRelay_BinaryFramePassthrough(t *testing.T) {
+	t.Parallel()
+
+	// A binary frame must be passed through but never parsed for usage.
+	binaryPayload := []byte{0x00, 0x01, 0x02, 0x03}
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageBinary,
+			payload: binaryPayload,
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+	// Binary frames are not parsed for usage.
+	require.Equal(t, 0, result.Usage.InputTokens)
+
+	clientWrites := clientConn.Writes()
+	require.Len(t, clientWrites, 1)
+	require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
+	require.Equal(t, binaryPayload, clientWrites[0].payload)
+}
+
+// TestRelay_BinaryJSONFrameSkipsObservation verifies that even when a binary
+// frame happens to contain valid terminal-event JSON, it is relayed unchanged
+// and skipped by event observation (no usage, no request id, no terminal type).
+func TestRelay_BinaryJSONFrameSkipsObservation(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageBinary,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_binary","usage":{"input_tokens":7,"output_tokens":3}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+	require.Equal(t, 0, result.Usage.InputTokens)
+	require.Equal(t, "", result.RequestID)
+	require.Equal(t, "", result.TerminalEventType)
+
+	clientWrites := clientConn.Writes()
+	require.Len(t, clientWrites, 1)
+	require.Equal(t, coderws.MessageBinary, clientWrites[0].msgType)
+}
+
+// TestRelay_UpstreamErrorEventPassthroughRaw verifies that an upstream "error"
+// event is forwarded to the client byte-for-byte, with no rewriting.
+func TestRelay_UpstreamErrorEventPassthroughRaw(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	errorEvent := []byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool call found"}}`)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: errorEvent,
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+
+	clientWrites := clientConn.Writes()
+	require.Len(t, clientWrites, 1)
+	require.Equal(t, coderws.MessageText, clientWrites[0].msgType)
+	// Exact byte equality: the error event must not be re-serialized.
+	require.Equal(t, errorEvent, clientWrites[0].payload)
+}
+
+// TestRelay_PreservesFirstMessageType verifies that the websocket message type
+// declared via RelayOptions.FirstMessageType (here binary) is used when the
+// first payload is written upstream.
+func TestRelay_PreservesFirstMessageType(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn(nil, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		FirstMessageType: coderws.MessageBinary,
+	})
+	require.Nil(t, relayExit)
+
+	upstreamWrites := upstreamConn.Writes()
+	require.Len(t, upstreamWrites, 1)
+	require.Equal(t, coderws.MessageBinary, upstreamWrites[0].msgType)
+	require.Equal(t, firstPayload, upstreamWrites[0].payload)
+}
+
+// TestRelay_UsageParseFailureDoesNotBlockRelay verifies that an unparseable
+// "usage" field leaves usage at zero, bumps the UsageParseFailureTotal metric,
+// and does not stop frame passthrough.
+// NOTE: intentionally not t.Parallel() — it compares against a global metrics
+// snapshot and must not race with other tests mutating the same counter.
+func TestRelay_UsageParseFailureDoesNotBlockRelay(t *testing.T) {
+	baseline := SnapshotMetrics().UsageParseFailureTotal
+
+	// Upstream sends JSON with a non-object usage; passthrough must continue.
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_bad","usage":"not_an_object"}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	result, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	require.Nil(t, relayExit)
+	// Usage parse failed: values stay 0, but passthrough is unaffected.
+	require.Equal(t, 0, result.Usage.InputTokens)
+	require.Equal(t, "response.completed", result.TerminalEventType)
+
+	// The frame is still forwarded.
+	clientWrites := clientConn.Writes()
+	require.Len(t, clientWrites, 1)
+	require.GreaterOrEqual(t, SnapshotMetrics().UsageParseFailureTotal, baseline+1)
+}
+
+// TestRelay_WriteUpstreamFirstMessageFails verifies that when writing the
+// first payload to the upstream connection fails, Relay exits with stage
+// "write_upstream".
+//
+// Fix: the original test also built a passthroughTestFrameConn named
+// upstreamConn, closed it, and never passed it to Relay — dead setup that
+// misleadingly suggested the closed conn was under test. Only the
+// always-failing errorOnWriteFrameConn is relevant, so the dead code is
+// removed.
+func TestRelay_WriteUpstreamFirstMessageFails(t *testing.T) {
+	t.Parallel()
+
+	// Upstream conn whose WriteFrame always fails, so the first payload write
+	// errors immediately.
+	errConn := &errorOnWriteFrameConn{}
+	clientConn := newPassthroughTestFrameConn(nil, false)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	_, relayExit := Relay(ctx, clientConn, errConn, firstPayload, RelayOptions{})
+	require.NotNil(t, relayExit)
+	require.Equal(t, "write_upstream", relayExit.Stage)
+}
+
+// TestRelay_ContextCanceled verifies that a pre-cancelled context makes the
+// relay fail (the first-message write cannot proceed).
+func TestRelay_ContextCanceled(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn(nil, false)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+
+	// Cancel the context immediately.
+	ctx, cancel := context.WithCancel(context.Background())
+	cancel()
+
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{})
+	// Context cancellation causes the first-message write to fail.
+	require.NotNil(t, relayExit)
+}
+
+// TestRelay_TraceEvents_ContainsLifecycleStages verifies that a successful
+// relay emits trace events for the expected lifecycle stages (start, first
+// write, first exit, completion).
+func TestRelay_TraceEvents_ContainsLifecycleStages(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn([]passthroughTestFrame{
+		{
+			msgType: coderws.MessageText,
+			payload: []byte(`{"type":"response.completed","response":{"id":"resp_trace","usage":{"input_tokens":1,"output_tokens":1}}}`),
+		},
+	}, true)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	// OnTrace may be invoked from multiple goroutines; guard the slice.
+	stages := make([]string, 0, 8)
+	var stagesMu sync.Mutex
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		OnTrace: func(event RelayTraceEvent) {
+			stagesMu.Lock()
+			stages = append(stages, event.Stage)
+			stagesMu.Unlock()
+		},
+	})
+	require.Nil(t, relayExit)
+	stagesMu.Lock()
+	capturedStages := append([]string(nil), stages...)
+	stagesMu.Unlock()
+	require.Contains(t, capturedStages, "relay_start")
+	require.Contains(t, capturedStages, "write_first_message_ok")
+	require.Contains(t, capturedStages, "first_exit")
+	require.Contains(t, capturedStages, "relay_complete")
+}
+
+// TestRelay_TraceEvents_IdleTimeout verifies the idle-timeout path emits the
+// "idle_timeout_triggered" and "relay_exit" trace stages.
+func TestRelay_TraceEvents_IdleTimeout(t *testing.T) {
+	t.Parallel()
+
+	clientConn := newPassthroughTestFrameConn(nil, false)
+	upstreamConn := newPassthroughTestFrameConn(nil, false)
+
+	firstPayload := []byte(`{"type":"response.create","model":"gpt-4o","input":[]}`)
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	// Fake clock: a few real-time calls during init, then a jump past the
+	// idle window.
+	now := time.Now()
+	callCount := 0
+	nowFn := func() time.Time {
+		callCount++
+		if callCount <= 5 {
+			return now
+		}
+		return now.Add(time.Hour)
+	}
+
+	// OnTrace may be invoked from multiple goroutines; guard the slice.
+	stages := make([]string, 0, 8)
+	var stagesMu sync.Mutex
+	_, relayExit := Relay(ctx, clientConn, upstreamConn, firstPayload, RelayOptions{
+		IdleTimeout: 2 * time.Second,
+		Now:         nowFn,
+		OnTrace: func(event RelayTraceEvent) {
+			stagesMu.Lock()
+			stages = append(stages, event.Stage)
+			stagesMu.Unlock()
+		},
+	})
+	require.NotNil(t, relayExit)
+	require.Equal(t, "idle_timeout", relayExit.Stage)
+	require.ErrorIs(t, relayExit.Err, ErrRelayIdleTimeout)
+	stagesMu.Lock()
+	capturedStages := append([]string(nil), stages...)
+	stagesMu.Unlock()
+	require.Contains(t, capturedStages, "idle_timeout_triggered")
+	require.Contains(t, capturedStages, "relay_exit")
+}
+
+// errorOnWriteFrameConn is a FrameConn whose writes always fail, used to test
+// first-message write failure.
+type errorOnWriteFrameConn struct{}
+
+// ReadFrame blocks until ctx is done and returns its error.
+func (c *errorOnWriteFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+	<-ctx.Done()
+	return coderws.MessageText, nil, ctx.Err()
+}
+
+// WriteFrame always fails with a connection-refused style error.
+func (c *errorOnWriteFrameConn) WriteFrame(_ context.Context, _ coderws.MessageType, _ []byte) error {
+	return errors.New("write failed: connection refused")
+}
+
+// Close is a no-op.
+func (c *errorOnWriteFrameConn) Close() error {
+	return nil
+}
diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter.go b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
new file mode 100644
index 000000000..828c3692f
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2_passthrough_adapter.go
@@ -0,0 +1,397 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "strings"
+ "sync/atomic"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
+ openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+)
+
+type openAIWSClientFrameConn struct {
+ conn *coderws.Conn
+}
+
+const openaiWSV2PassthroughModeFields = "ws_mode=passthrough ws_router=v2"
+
+var _ openaiwsv2.FrameConn = (*openAIWSClientFrameConn)(nil)
+
+func (c *openAIWSClientFrameConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ if c == nil || c.conn == nil {
+ return coderws.MessageText, nil, errOpenAIWSConnClosed
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return c.conn.Read(ctx)
+}
+
+func (c *openAIWSClientFrameConn) WriteFrame(ctx context.Context, msgType coderws.MessageType, payload []byte) error {
+ if c == nil || c.conn == nil {
+ return errOpenAIWSConnClosed
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ return c.conn.Write(ctx, msgType, payload)
+}
+
+func (c *openAIWSClientFrameConn) Close() error {
+ if c == nil || c.conn == nil {
+ return nil
+ }
+ _ = c.conn.Close(coderws.StatusNormalClosure, "")
+ _ = c.conn.CloseNow()
+ return nil
+}
+
+func (s *OpenAIGatewayService) proxyResponsesWebSocketV2Passthrough(
+ ctx context.Context,
+ c *gin.Context,
+ clientConn *coderws.Conn,
+ account *Account,
+ token string,
+ firstClientMessageType coderws.MessageType,
+ firstClientMessage []byte,
+ hooks *OpenAIWSIngressHooks,
+ wsDecision OpenAIWSProtocolDecision,
+) error {
+ if s == nil {
+ return errors.New("service is nil")
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if clientConn == nil {
+ return errors.New("client websocket is nil")
+ }
+ if account == nil {
+ return errors.New("account is nil")
+ }
+ if strings.TrimSpace(token) == "" {
+ return errors.New("token is empty")
+ }
+ requestModel, requestPreviousResponseID, _ := ResolveOpenAIWSFirstMessageMeta(c, firstClientMessage)
+ logOpenAIWSV2Passthrough(
+ "relay_start account_id=%d model=%s previous_response_id=%s first_message_type=%s first_message_bytes=%d",
+ account.ID,
+ truncateOpenAIWSLogValue(requestModel, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(requestPreviousResponseID, openAIWSIDValueMaxLen),
+ openaiwsv2RelayMessageTypeName(firstClientMessageType),
+ len(firstClientMessage),
+ )
+
+ wsURL, err := s.buildOpenAIResponsesWSURL(account)
+ if err != nil {
+ return fmt.Errorf("build ws url: %w", err)
+ }
+ wsHost, wsPath := openAIWSHostPathForLogFromURL(wsURL)
+ logOpenAIWSV2Passthrough(
+ "relay_dial_start account_id=%d ws_host=%s ws_path=%s proxy_enabled=%v",
+ account.ID,
+ wsHost,
+ wsPath,
+ account.ProxyID != nil && account.Proxy != nil,
+ )
+
+ isCodexCLI := false
+ if c != nil {
+ isCodexCLI = openai.IsCodexCLIRequest(c.GetHeader("User-Agent"))
+ }
+ if s.cfg != nil && s.cfg.Gateway.ForceCodexCLI {
+ isCodexCLI = true
+ }
+ headers, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, "", "", "")
+ proxyURL := ""
+ if account.ProxyID != nil && account.Proxy != nil {
+ proxyURL = account.Proxy.URL()
+ }
+
+ dialer := s.getOpenAIWSPassthroughDialer()
+ if dialer == nil {
+ return errors.New("openai ws passthrough dialer is nil")
+ }
+
+ dialCtx, cancelDial := context.WithTimeout(ctx, s.openAIWSDialTimeout())
+ defer cancelDial()
+ upstreamConn, statusCode, handshakeHeaders, err := dialer.Dial(dialCtx, wsURL, headers, proxyURL)
+ if err != nil {
+ logOpenAIWSV2Passthrough(
+ "relay_dial_failed account_id=%d status_code=%d err=%s",
+ account.ID,
+ statusCode,
+ truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ )
+ return s.mapOpenAIWSPassthroughDialError(err, statusCode, handshakeHeaders)
+ }
+ defer func() {
+ _ = upstreamConn.Close()
+ }()
+ logOpenAIWSV2Passthrough(
+ "relay_dial_ok account_id=%d status_code=%d upstream_request_id=%s",
+ account.ID,
+ statusCode,
+ openAIWSHeaderValueForLog(handshakeHeaders, "x-request-id"),
+ )
+
+ upstreamFrameConn, ok := upstreamConn.(openaiwsv2.FrameConn)
+ if !ok {
+ return errors.New("openai ws passthrough upstream connection does not support frame relay")
+ }
+
+ completedTurns := atomic.Int32{}
+ var onTrace func(event openaiwsv2.RelayTraceEvent)
+ if isOpenAIWSModeDebugEnabled() {
+ onTrace = func(event openaiwsv2.RelayTraceEvent) {
+ logOpenAIWSV2PassthroughDebug(
+ "relay_trace account_id=%d stage=%s direction=%s msg_type=%s bytes=%d graceful=%v wrote_downstream=%v err=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(event.Stage, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(event.Direction, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(event.MessageType, openAIWSLogValueMaxLen),
+ event.PayloadBytes,
+ event.Graceful,
+ event.WroteDownstream,
+ truncateOpenAIWSLogValue(event.Error, openAIWSLogValueMaxLen),
+ )
+ }
+ }
+ relayResult, relayExit := openaiwsv2.RunEntry(openaiwsv2.EntryInput{
+ Ctx: ctx,
+ ClientConn: &openAIWSClientFrameConn{conn: clientConn},
+ UpstreamConn: upstreamFrameConn,
+ FirstClientMessage: firstClientMessage,
+ Options: openaiwsv2.RelayOptions{
+ WriteTimeout: s.openAIWSWriteTimeout(),
+ IdleTimeout: s.openAIWSPassthroughIdleTimeout(),
+ FirstMessageType: firstClientMessageType,
+ InitialRequestModel: requestModel,
+ // The passthrough path relays large frames at high frequency; avoid creating a timeout context/timer per frame.
+ // Cleanup is guaranteed by relayCtx cancellation plus the idle watchdog.
+ DisableWriteTimeout: true,
+ OnUsageParseFailure: func(eventType string, usageRaw string) {
+ logOpenAIWSV2Passthrough(
+ "usage_parse_failed event_type=%s usage_raw=%s",
+ truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(usageRaw, openAIWSLogValueMaxLen),
+ )
+ },
+ OnTurnComplete: func(turn openaiwsv2.RelayTurnResult) {
+ turnNo := int(completedTurns.Add(1))
+ turnResult := &OpenAIForwardResult{
+ RequestID: turn.RequestID,
+ Usage: OpenAIUsage{
+ InputTokens: turn.Usage.InputTokens,
+ OutputTokens: turn.Usage.OutputTokens,
+ CacheCreationInputTokens: turn.Usage.CacheCreationInputTokens,
+ CacheReadInputTokens: turn.Usage.CacheReadInputTokens,
+ },
+ Model: turn.RequestModel,
+ Stream: true,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModePassthrough,
+ Duration: turn.Duration,
+ FirstTokenMs: turn.FirstTokenMs,
+ TerminalEventType: turn.TerminalEventType,
+ }
+ logOpenAIWSV2Passthrough(
+ "relay_turn_completed account_id=%d turn=%d request_id=%s terminal_event=%s duration_ms=%d first_token_ms=%d input_tokens=%d output_tokens=%d cache_read_tokens=%d",
+ account.ID,
+ turnNo,
+ truncateOpenAIWSLogValue(turnResult.RequestID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(turnResult.TerminalEventType, openAIWSLogValueMaxLen),
+ turnResult.Duration.Milliseconds(),
+ openAIWSFirstTokenMsForLog(turnResult.FirstTokenMs),
+ turnResult.Usage.InputTokens,
+ turnResult.Usage.OutputTokens,
+ turnResult.Usage.CacheReadInputTokens,
+ )
+ if hooks != nil && hooks.AfterTurn != nil {
+ hooks.AfterTurn(turnNo, turnResult, nil)
+ }
+ },
+ OnTrace: onTrace,
+ },
+ })
+
+ result := &OpenAIForwardResult{
+ RequestID: relayResult.RequestID,
+ Usage: OpenAIUsage{
+ InputTokens: relayResult.Usage.InputTokens,
+ OutputTokens: relayResult.Usage.OutputTokens,
+ CacheCreationInputTokens: relayResult.Usage.CacheCreationInputTokens,
+ CacheReadInputTokens: relayResult.Usage.CacheReadInputTokens,
+ },
+ Model: relayResult.RequestModel,
+ Stream: true,
+ OpenAIWSMode: true,
+ WSIngressMode: OpenAIWSIngressModePassthrough,
+ Duration: relayResult.Duration,
+ FirstTokenMs: relayResult.FirstTokenMs,
+ TerminalEventType: relayResult.TerminalEventType,
+ }
+
+ turnCount := int(completedTurns.Load())
+ if relayExit == nil {
+ logOpenAIWSV2Passthrough(
+ "relay_completed account_id=%d request_id=%s terminal_event=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
+ account.ID,
+ truncateOpenAIWSLogValue(result.RequestID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(result.TerminalEventType, openAIWSLogValueMaxLen),
+ result.Duration.Milliseconds(),
+ relayResult.ClientToUpstreamFrames,
+ relayResult.UpstreamToClientFrames,
+ relayResult.DroppedDownstreamFrames,
+ turnCount,
+ )
+ // On the normal path, the callback has already fired per terminal event for each turn; only fall back to a single callback in the zero-turn case.
+ if turnCount == 0 && hooks != nil && hooks.AfterTurn != nil {
+ hooks.AfterTurn(1, result, nil)
+ }
+ return nil
+ }
+ logOpenAIWSV2Passthrough(
+ "relay_failed account_id=%d stage=%s wrote_downstream=%v err=%s duration_ms=%d c2u_frames=%d u2c_frames=%d dropped_frames=%d turns=%d",
+ account.ID,
+ truncateOpenAIWSLogValue(relayExit.Stage, openAIWSLogValueMaxLen),
+ relayExit.WroteDownstream,
+ truncateOpenAIWSLogValue(relayErrorText(relayExit.Err), openAIWSLogValueMaxLen),
+ result.Duration.Milliseconds(),
+ relayResult.ClientToUpstreamFrames,
+ relayResult.UpstreamToClientFrames,
+ relayResult.DroppedDownstreamFrames,
+ turnCount,
+ )
+
+ relayErr := relayExit.Err
+ if relayExit.Stage == "idle_timeout" {
+ relayErr = NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "client websocket idle timeout",
+ relayErr,
+ )
+ }
+ turnErr := error(nil)
+ if turnCount == 0 {
+ turnErr = wrapOpenAIWSIngressTurnErrorWithPartial(
+ relayExit.Stage,
+ relayErr,
+ relayExit.WroteDownstream,
+ result,
+ )
+ } else {
+ // When usage was already reported via per-turn callbacks, the error path no longer attaches a partial result, to avoid double-accounting the same request_id.
+ turnErr = wrapOpenAIWSIngressTurnError(
+ relayExit.Stage,
+ relayErr,
+ relayExit.WroteDownstream,
+ )
+ }
+ if hooks != nil && hooks.AfterTurn != nil {
+ hooks.AfterTurn(turnCount+1, nil, turnErr)
+ }
+ return turnErr
+}
+
+func (s *OpenAIGatewayService) mapOpenAIWSPassthroughDialError(
+ err error,
+ statusCode int,
+ handshakeHeaders http.Header,
+) error {
+ if err == nil {
+ return nil
+ }
+ wrappedErr := err
+ var dialErr *openAIWSDialError
+ if !errors.As(err, &dialErr) {
+ wrappedErr = &openAIWSDialError{
+ StatusCode: statusCode,
+ ResponseHeaders: cloneHeader(handshakeHeaders),
+ Err: err,
+ }
+ }
+
+ if errors.Is(err, context.Canceled) {
+ return err
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusTryAgainLater,
+ "upstream websocket connect timeout",
+ wrappedErr,
+ )
+ }
+ if statusCode == http.StatusTooManyRequests {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusTryAgainLater,
+ "upstream websocket is busy, please retry later",
+ wrappedErr,
+ )
+ }
+ if statusCode == http.StatusUnauthorized || statusCode == http.StatusForbidden {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "upstream websocket authentication failed",
+ wrappedErr,
+ )
+ }
+ if statusCode >= http.StatusBadRequest && statusCode < http.StatusInternalServerError {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "upstream websocket handshake rejected",
+ wrappedErr,
+ )
+ }
+ return fmt.Errorf("openai ws passthrough dial: %w", wrappedErr)
+}
+
+func openaiwsv2RelayMessageTypeName(msgType coderws.MessageType) string {
+ switch msgType {
+ case coderws.MessageText:
+ return "text"
+ case coderws.MessageBinary:
+ return "binary"
+ default:
+ return fmt.Sprintf("unknown(%d)", msgType)
+ }
+}
+
+func relayErrorText(err error) string {
+ if err == nil {
+ return ""
+ }
+ return err.Error()
+}
+
+func openAIWSFirstTokenMsForLog(firstTokenMs *int) int {
+ if firstTokenMs == nil {
+ return -1
+ }
+ return *firstTokenMs
+}
+
+func logOpenAIWSV2Passthrough(format string, args ...any) {
+ logger.LegacyPrintf(
+ "service.openai_ws_v2",
+ "[OpenAI WS v2 passthrough] "+openaiWSV2PassthroughModeFields+" "+format,
+ args...,
+ )
+}
+
+func logOpenAIWSV2PassthroughDebug(format string, args ...any) {
+ if !isOpenAIWSModeDebugEnabled() {
+ return
+ }
+ logger.LegacyPrintf(
+ "service.openai_ws_v2",
+ "[debug] [OpenAI WS v2 passthrough] "+openaiWSV2PassthroughModeFields+" "+format,
+ args...,
+ )
+}
diff --git a/backend/internal/service/openai_ws_v2_passthrough_adapter_test.go b/backend/internal/service/openai_ws_v2_passthrough_adapter_test.go
new file mode 100644
index 000000000..3557356c0
--- /dev/null
+++ b/backend/internal/service/openai_ws_v2_passthrough_adapter_test.go
@@ -0,0 +1,1194 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ openaiwsv2 "github.com/Wei-Shaw/sub2api/internal/service/openai_ws_v2"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+var openAIWSModeDebugLogTestMu sync.Mutex
+
+func TestOpenAIWSClientFrameConn_NilGuards(t *testing.T) {
+ t.Parallel()
+
+ var nilReceiver *openAIWSClientFrameConn
+ msgType, payload, err := nilReceiver.ReadFrame(context.Background())
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.Equal(t, coderws.MessageText, msgType)
+ require.Nil(t, payload)
+
+ err = nilReceiver.WriteFrame(context.Background(), coderws.MessageText, []byte("x"))
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.NoError(t, nilReceiver.Close())
+
+ empty := &openAIWSClientFrameConn{}
+ var nilCtx context.Context
+ _, _, err = empty.ReadFrame(nilCtx)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ err = empty.WriteFrame(nilCtx, coderws.MessageText, []byte("x"))
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+ require.NoError(t, empty.Close())
+}
+
+func TestOpenAIWSClientFrameConn_NilContextWithLiveConn(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, nil)
+ require.NoError(t, err)
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, payload, err := conn.Read(readCtx)
+ cancelRead()
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ writeCtx, cancelWrite := context.WithTimeout(r.Context(), 3*time.Second)
+ err = conn.Write(writeCtx, msgType, payload)
+ cancelWrite()
+ serverErrCh <- err
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(
+ dialCtx,
+ "ws"+strings.TrimPrefix(wsServer.URL, "http"),
+ nil,
+ )
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ frameConn := &openAIWSClientFrameConn{conn: clientConn}
+ payload := []byte(`{"type":"response.create","model":"gpt-5.3-codex"}`)
+ var nilCtx context.Context
+
+ require.NoError(t, frameConn.WriteFrame(nilCtx, coderws.MessageText, payload))
+ msgType, gotPayload, err := frameConn.ReadFrame(nilCtx)
+ require.NoError(t, err)
+ require.Equal(t, coderws.MessageText, msgType)
+ require.Equal(t, payload, gotPayload)
+ require.NoError(t, <-serverErrCh)
+}
+
+func TestOpenAIWSV2PassthroughHelpers(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, "text", openaiwsv2RelayMessageTypeName(coderws.MessageText))
+ require.Equal(t, "binary", openaiwsv2RelayMessageTypeName(coderws.MessageBinary))
+ require.Contains(t, openaiwsv2RelayMessageTypeName(coderws.MessageType(99)), "unknown(")
+
+ require.Equal(t, "", relayErrorText(nil))
+ require.Equal(t, "boom", relayErrorText(errors.New("boom")))
+ require.Equal(t, -1, openAIWSFirstTokenMsForLog(nil))
+ ms := 12
+ require.Equal(t, 12, openAIWSFirstTokenMsForLog(&ms))
+
+ require.NotPanics(t, func() {
+ logOpenAIWSV2Passthrough("helper_test account_id=%d", 1)
+ })
+}
+
+func TestLogOpenAIWSV2PassthroughDebug_Disabled(t *testing.T) {
+ openAIWSModeDebugLogTestMu.Lock()
+ defer openAIWSModeDebugLogTestMu.Unlock()
+
+ logger.InitBootstrap()
+ require.NoError(t, logger.SetLevel("info"))
+ require.False(t, isOpenAIWSModeDebugEnabled())
+
+ require.NotPanics(t, func() {
+ logOpenAIWSV2PassthroughDebug("helper_test_debug account_id=%d", 1)
+ })
+}
+
+func TestLogOpenAIWSV2PassthroughDebug_Enabled(t *testing.T) {
+ openAIWSModeDebugLogTestMu.Lock()
+ defer openAIWSModeDebugLogTestMu.Unlock()
+
+ logger.InitBootstrap()
+ require.NoError(t, logger.SetLevel("debug"))
+ require.True(t, isOpenAIWSModeDebugEnabled())
+ t.Cleanup(func() {
+ _ = logger.SetLevel("info")
+ })
+
+ require.NotPanics(t, func() {
+ logOpenAIWSV2PassthroughDebug("helper_test_debug_enabled account_id=%d", 1)
+ })
+}
+
+func TestProxyResponsesWebSocketV2Passthrough_InvalidInputs(t *testing.T) {
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+ firstMessage := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
+ var nilSvc *OpenAIGatewayService
+ err := nilSvc.proxyResponsesWebSocketV2Passthrough(
+ context.Background(),
+ nil,
+ nil,
+ account,
+ "sk-test",
+ coderws.MessageText,
+ firstMessage,
+ nil,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "service is nil")
+
+ cfg := buildIngressPolicyTestConfig()
+ svc := buildIngressPolicyTestService(cfg)
+ dummyClient := &coderws.Conn{}
+
+ err = svc.proxyResponsesWebSocketV2Passthrough(
+ context.Background(),
+ nil,
+ nil,
+ account,
+ "sk-test",
+ coderws.MessageText,
+ firstMessage,
+ nil,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "client websocket is nil")
+
+ err = svc.proxyResponsesWebSocketV2Passthrough(
+ context.Background(),
+ nil,
+ dummyClient,
+ nil,
+ "sk-test",
+ coderws.MessageText,
+ firstMessage,
+ nil,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "account is nil")
+
+ err = svc.proxyResponsesWebSocketV2Passthrough(
+ context.Background(),
+ nil,
+ dummyClient,
+ account,
+ " ",
+ coderws.MessageText,
+ firstMessage,
+ nil,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "token is empty")
+}
+
+func TestProxyResponsesWebSocketV2Passthrough_DialFailure(t *testing.T) {
+ cfg := buildIngressPolicyTestConfig()
+ svc := buildIngressPolicyTestService(cfg)
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+ dummyClient := &coderws.Conn{}
+ firstMessage := []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`)
+
+ svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+ err: errors.New("dial failed"),
+ statusCode: 503,
+ }
+ err := svc.proxyResponsesWebSocketV2Passthrough(
+ context.Background(),
+ nil,
+ dummyClient,
+ account,
+ "sk-test",
+ coderws.MessageText,
+ firstMessage,
+ nil,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "openai ws passthrough dial")
+}
+
+type passthroughDialerStub struct {
+ conn openAIWSClientConn
+ statusCode int
+ headers http.Header
+ err error
+ dialCount atomic.Int32
+}
+
+func (d *passthroughDialerStub) Dial(
+ _ context.Context,
+ _ string,
+ _ http.Header,
+ _ string,
+) (openAIWSClientConn, int, http.Header, error) {
+ d.dialCount.Add(1)
+ return d.conn, d.statusCode, d.headers, d.err
+}
+
+type passthroughUpstreamConn struct {
+ mu sync.Mutex
+ reads []struct {
+ msgType coderws.MessageType
+ payload []byte
+ }
+ readDelay time.Duration
+ readErr error
+ writes [][]byte
+ closed bool
+}
+
+func (c *passthroughUpstreamConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ if c.readDelay > 0 {
+ timer := time.NewTimer(c.readDelay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return coderws.MessageText, nil, ctx.Err()
+ case <-timer.C:
+ }
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return coderws.MessageText, nil, io.EOF
+ }
+ if len(c.reads) == 0 {
+ if c.readErr != nil {
+ return coderws.MessageText, nil, c.readErr
+ }
+ return coderws.MessageText, nil, io.EOF
+ }
+ item := c.reads[0]
+ c.reads = c.reads[1:]
+ return item.msgType, append([]byte(nil), item.payload...), nil
+}
+
+func (c *passthroughUpstreamConn) WriteFrame(_ context.Context, _ coderws.MessageType, payload []byte) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.writes = append(c.writes, append([]byte(nil), payload...))
+ return nil
+}
+
+func (c *passthroughUpstreamConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func (c *passthroughUpstreamConn) WriteJSON(context.Context, any) error { return nil }
+func (c *passthroughUpstreamConn) ReadMessage(context.Context) ([]byte, error) {
+ return nil, io.EOF
+}
+func (c *passthroughUpstreamConn) Ping(context.Context) error { return nil }
+
+type passthroughBlockingUpstreamConn struct{}
+
+func (c *passthroughBlockingUpstreamConn) ReadFrame(ctx context.Context) (coderws.MessageType, []byte, error) {
+ <-ctx.Done()
+ return coderws.MessageText, nil, ctx.Err()
+}
+func (c *passthroughBlockingUpstreamConn) WriteFrame(context.Context, coderws.MessageType, []byte) error {
+ return nil
+}
+func (c *passthroughBlockingUpstreamConn) Close() error { return nil }
+func (c *passthroughBlockingUpstreamConn) WriteJSON(context.Context, any) error {
+ return nil
+}
+func (c *passthroughBlockingUpstreamConn) ReadMessage(context.Context) ([]byte, error) {
+ return nil, io.EOF
+}
+func (c *passthroughBlockingUpstreamConn) Ping(context.Context) error { return nil }
+
+func TestProxyResponsesWebSocketV2Passthrough_Success(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstream := &passthroughUpstreamConn{
+ reads: []struct {
+ msgType coderws.MessageType
+ payload []byte
+ }{
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.completed","response":{"id":"resp_passthrough","usage":{"input_tokens":3,"output_tokens":2}}}`),
+ },
+ },
+ }
+ cfg := buildIngressPolicyTestConfig()
+ svc := buildIngressPolicyTestService(cfg)
+ svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+ conn: upstream,
+ statusCode: 101,
+ headers: http.Header{
+ "X-Request-ID": []string{"req-passthrough"},
+ },
+ }
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ var (
+ afterTurnCalled bool
+ afterTurnErr error
+ afterTurnResult *OpenAIForwardResult
+ )
+ hooks := &OpenAIWSIngressHooks{
+ AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
+ afterTurnCalled = true
+ afterTurnErr = turnErr
+ afterTurnResult = result
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, nil)
+ require.NoError(t, err)
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, err := conn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+
+ ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+ ginCtx.Request = r
+ serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+ r.Context(),
+ ginCtx,
+ conn,
+ account,
+ "sk-test",
+ msgType,
+ firstMessage,
+ hooks,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(
+ dialCtx,
+ "ws"+strings.TrimPrefix(wsServer.URL, "http"),
+ nil,
+ )
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, payload, err := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+ require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_passthrough","usage":{"input_tokens":3,"output_tokens":2}}}`, string(payload))
+
+ require.NoError(t, <-serverErrCh)
+ require.True(t, afterTurnCalled)
+ require.NoError(t, afterTurnErr)
+ require.NotNil(t, afterTurnResult)
+ require.Equal(t, "resp_passthrough", afterTurnResult.RequestID)
+}
+
+func TestProxyResponsesWebSocketV2Passthrough_NilContext(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstream := &passthroughUpstreamConn{
+ reads: []struct {
+ msgType coderws.MessageType
+ payload []byte
+ }{
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.completed","response":{"id":"resp_nil_ctx","usage":{"input_tokens":1,"output_tokens":1}}}`),
+ },
+ },
+ }
+ cfg := buildIngressPolicyTestConfig()
+ svc := buildIngressPolicyTestService(cfg)
+ svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+ conn: upstream,
+ statusCode: 101,
+ }
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, nil)
+ require.NoError(t, err)
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, err := conn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+
+ ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+ ginCtx.Request = r
+ serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+ context.TODO(),
+ ginCtx,
+ conn,
+ account,
+ "sk-test",
+ msgType,
+ firstMessage,
+ nil,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(
+ dialCtx,
+ "ws"+strings.TrimPrefix(wsServer.URL, "http"),
+ nil,
+ )
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ _, payload, err := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+ require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_nil_ctx","usage":{"input_tokens":1,"output_tokens":1}}}`, string(payload))
+
+ require.NoError(t, <-serverErrCh)
+}
+
+func TestProxyResponsesWebSocketV2Passthrough_ZeroTurnFallbackCallback(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstream := &passthroughUpstreamConn{
+ reads: []struct {
+ msgType coderws.MessageType
+ payload []byte
+ }{
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.output_text.delta","delta":"hello"}`),
+ },
+ },
+ }
+ cfg := buildIngressPolicyTestConfig()
+ svc := buildIngressPolicyTestService(cfg)
+ svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+ conn: upstream,
+ statusCode: 101,
+ }
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ var (
+ afterTurnCalled bool
+ afterTurnErr error
+ afterTurnResult *OpenAIForwardResult
+ )
+ hooks := &OpenAIWSIngressHooks{
+ AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
+ afterTurnCalled = true
+ afterTurnErr = turnErr
+ afterTurnResult = result
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, nil)
+ require.NoError(t, err)
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, err := conn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+
+ ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+ ginCtx.Request = r
+ serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+ r.Context(),
+ ginCtx,
+ conn,
+ account,
+ "sk-test",
+ msgType,
+ firstMessage,
+ hooks,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(
+ dialCtx,
+ "ws"+strings.TrimPrefix(wsServer.URL, "http"),
+ nil,
+ )
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+ msgType, payload, err := clientConn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+ require.Equal(t, coderws.MessageText, msgType)
+ require.JSONEq(t, `{"type":"response.output_text.delta","delta":"hello"}`, string(payload))
+
+ require.NoError(t, <-serverErrCh)
+ require.True(t, afterTurnCalled)
+ require.NoError(t, afterTurnErr)
+ require.NotNil(t, afterTurnResult)
+ require.Equal(t, "", afterTurnResult.RequestID)
+ require.Equal(t, 0, afterTurnResult.Usage.InputTokens)
+ require.Equal(t, "", afterTurnResult.TerminalEventType)
+}
+
+func TestProxyResponsesWebSocketV2Passthrough_TurnMetricsPropagated(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ upstream := &passthroughUpstreamConn{
+ readDelay: 15 * time.Millisecond,
+ reads: []struct {
+ msgType coderws.MessageType
+ payload []byte
+ }{
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.output_text.delta","response_id":"resp_metrics","delta":"hello"}`),
+ },
+ {
+ msgType: coderws.MessageText,
+ payload: []byte(`{"type":"response.completed","response":{"id":"resp_metrics","usage":{"input_tokens":4,"output_tokens":2}}}`),
+ },
+ },
+ }
+ cfg := buildIngressPolicyTestConfig()
+ svc := buildIngressPolicyTestService(cfg)
+ svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+ conn: upstream,
+ statusCode: 101,
+ headers: http.Header{
+ "X-Request-ID": []string{"req-passthrough"},
+ },
+ }
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+ })
+
+ var (
+ afterTurnCalled bool
+ afterTurnResult *OpenAIForwardResult
+ afterTurnErr error
+ )
+ hooks := &OpenAIWSIngressHooks{
+ AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
+ afterTurnCalled = true
+ afterTurnResult = result
+ afterTurnErr = turnErr
+ },
+ }
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, nil)
+ require.NoError(t, err)
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, firstMessage, err := conn.Read(readCtx)
+ cancelRead()
+ require.NoError(t, err)
+
+ ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+ ginCtx.Request = r
+ serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+ r.Context(),
+ ginCtx,
+ conn,
+ account,
+ "sk-test",
+ msgType,
+ firstMessage,
+ hooks,
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ )
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(
+ dialCtx,
+ "ws"+strings.TrimPrefix(wsServer.URL, "http"),
+ nil,
+ )
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+ cancelWrite()
+ require.NoError(t, err)
+
+ // 读取透传下发事件直到 terminal 到达。
+ readCtx1, cancelRead1 := context.WithTimeout(context.Background(), 3*time.Second)
+ _, _, err = clientConn.Read(readCtx1)
+ cancelRead1()
+ require.NoError(t, err)
+ readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second)
+ _, payload2, err := clientConn.Read(readCtx2)
+ cancelRead2()
+ require.NoError(t, err)
+ require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_metrics","usage":{"input_tokens":4,"output_tokens":2}}}`, string(payload2))
+
+ require.NoError(t, <-serverErrCh)
+ require.True(t, afterTurnCalled)
+ require.NoError(t, afterTurnErr)
+ require.NotNil(t, afterTurnResult)
+ require.NotNil(t, afterTurnResult.FirstTokenMs)
+ require.GreaterOrEqual(t, *afterTurnResult.FirstTokenMs, 0)
+ require.Greater(t, afterTurnResult.Duration.Milliseconds(), int64(0))
+}
+
+// TestProxyResponsesWebSocketV2Passthrough_MultiTurnAfterTurnOnTerminal verifies that
+// passthrough mode relays each upstream terminal event (response.completed and
+// response.failed) to the client verbatim, and that the AfterTurn hook fires exactly
+// once per turn carrying that turn's number, response id and token usage.
+func TestProxyResponsesWebSocketV2Passthrough_MultiTurnAfterTurnOnTerminal(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// Stub upstream serving two turns: the first ends with response.completed,
+	// the second with response.failed — both terminal events.
+	upstream := &passthroughUpstreamConn{
+		reads: []struct {
+			msgType coderws.MessageType
+			payload []byte
+		}{
+			{
+				msgType: coderws.MessageText,
+				payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":3,"output_tokens":2}}}`),
+			},
+			{
+				msgType: coderws.MessageText,
+				payload: []byte(`{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":1,"output_tokens":4}}}`),
+			},
+		},
+	}
+	cfg := buildIngressPolicyTestConfig()
+	svc := buildIngressPolicyTestService(cfg)
+	svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+		conn:       upstream,
+		statusCode: 101,
+		headers: http.Header{
+			"X-Request-ID": []string{"req-passthrough"},
+		},
+	}
+	account := buildIngressPolicyTestAccount(map[string]any{
+		"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+	})
+
+	// Record every AfterTurn invocation; the hook may run on a relay goroutine,
+	// hence the mutex.
+	type afterTurnCall struct {
+		turn   int
+		result *OpenAIForwardResult
+		err    error
+	}
+	var (
+		callsMu sync.Mutex
+		calls   = make([]afterTurnCall, 0, 3)
+	)
+	hooks := &OpenAIWSIngressHooks{
+		AfterTurn: func(turn int, result *OpenAIForwardResult, turnErr error) {
+			callsMu.Lock()
+			defer callsMu.Unlock()
+			calls = append(calls, afterTurnCall{
+				turn:   turn,
+				result: result,
+				err:    turnErr,
+			})
+		},
+	}
+
+	// In-process WebSocket server: accept the client, read the first frame, then
+	// hand the connection to the passthrough proxy under test.
+	serverErrCh := make(chan error, 1)
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := coderws.Accept(w, r, nil)
+		require.NoError(t, err)
+		defer func() {
+			_ = conn.CloseNow()
+		}()
+
+		readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+		msgType, firstMessage, err := conn.Read(readCtx)
+		cancelRead()
+		require.NoError(t, err)
+
+		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+		ginCtx.Request = r
+		serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+			r.Context(),
+			ginCtx,
+			conn,
+			account,
+			"sk-test",
+			msgType,
+			firstMessage,
+			hooks,
+			OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+		)
+	}))
+	defer wsServer.Close()
+
+	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+	clientConn, _, err := coderws.Dial(
+		dialCtx,
+		"ws"+strings.TrimPrefix(wsServer.URL, "http"),
+		nil,
+	)
+	cancelDial()
+	require.NoError(t, err)
+	defer func() {
+		_ = clientConn.CloseNow()
+	}()
+
+	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+	cancelWrite()
+	require.NoError(t, err)
+
+	// Both terminal events must reach the client unchanged and in order.
+	readCtx1, cancelRead1 := context.WithTimeout(context.Background(), 3*time.Second)
+	_, payload1, err := clientConn.Read(readCtx1)
+	cancelRead1()
+	require.NoError(t, err)
+	require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_turn_1","usage":{"input_tokens":3,"output_tokens":2}}}`, string(payload1))
+
+	readCtx2, cancelRead2 := context.WithTimeout(context.Background(), 3*time.Second)
+	_, payload2, err := clientConn.Read(readCtx2)
+	cancelRead2()
+	require.NoError(t, err)
+	require.JSONEq(t, `{"type":"response.failed","response":{"id":"resp_turn_2","usage":{"input_tokens":1,"output_tokens":4}}}`, string(payload2))
+
+	require.NoError(t, <-serverErrCh)
+
+	// Exactly one AfterTurn call per terminal event, each carrying that turn's
+	// response id and usage counters.
+	callsMu.Lock()
+	gotCalls := append([]afterTurnCall(nil), calls...)
+	callsMu.Unlock()
+	require.Len(t, gotCalls, 2)
+	require.Equal(t, 1, gotCalls[0].turn)
+	require.NoError(t, gotCalls[0].err)
+	require.NotNil(t, gotCalls[0].result)
+	require.Equal(t, "resp_turn_1", gotCalls[0].result.RequestID)
+	require.Equal(t, 3, gotCalls[0].result.Usage.InputTokens)
+	require.Equal(t, 2, gotCalls[0].result.Usage.OutputTokens)
+	require.Equal(t, 2, gotCalls[1].turn)
+	require.NoError(t, gotCalls[1].err)
+	require.NotNil(t, gotCalls[1].result)
+	require.Equal(t, "resp_turn_2", gotCalls[1].result.RequestID)
+	require.Equal(t, 1, gotCalls[1].result.Usage.InputTokens)
+	require.Equal(t, 4, gotCalls[1].result.Usage.OutputTokens)
+}
+
+// TestProxyResponsesWebSocketV2Passthrough_ErrorAfterTurn_NoPartialDuplicate verifies
+// that when the upstream breaks after one completed turn, the failed turn's AfterTurn
+// receives a nil result and the error carries no partial-result payload — i.e. the
+// already-reported turn 1 usage is not double-counted via a partial on turn 2.
+func TestProxyResponsesWebSocketV2Passthrough_ErrorAfterTurn_NoPartialDuplicate(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	// One successful terminal event, after which the stub upstream returns a
+	// read error to simulate a broken stream mid-session.
+	upstream := &passthroughUpstreamConn{
+		reads: []struct {
+			msgType coderws.MessageType
+			payload []byte
+		}{
+			{
+				msgType: coderws.MessageText,
+				payload: []byte(`{"type":"response.completed","response":{"id":"resp_turn_ok","usage":{"input_tokens":2,"output_tokens":1}}}`),
+			},
+		},
+		readErr: errors.New("upstream stream broken"),
+	}
+	cfg := buildIngressPolicyTestConfig()
+	svc := buildIngressPolicyTestService(cfg)
+	svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+		conn:       upstream,
+		statusCode: 101,
+		headers: http.Header{
+			"X-Request-ID": []string{"req-passthrough"},
+		},
+	}
+	account := buildIngressPolicyTestAccount(map[string]any{
+		"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+	})
+
+	// Capture every AfterTurn call under a mutex (hook may run on a relay goroutine).
+	type afterTurnCall struct {
+		turn   int
+		result *OpenAIForwardResult
+		err    error
+	}
+	var (
+		callsMu sync.Mutex
+		calls   = make([]afterTurnCall, 0, 3)
+	)
+	hooks := &OpenAIWSIngressHooks{
+		AfterTurn: func(turn int, result *OpenAIForwardResult, turnErr error) {
+			callsMu.Lock()
+			defer callsMu.Unlock()
+			calls = append(calls, afterTurnCall{
+				turn:   turn,
+				result: result,
+				err:    turnErr,
+			})
+		},
+	}
+
+	// In-process WebSocket server driving the passthrough proxy under test.
+	serverErrCh := make(chan error, 1)
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := coderws.Accept(w, r, nil)
+		require.NoError(t, err)
+		defer func() {
+			_ = conn.CloseNow()
+		}()
+
+		readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+		msgType, firstMessage, err := conn.Read(readCtx)
+		cancelRead()
+		require.NoError(t, err)
+
+		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+		ginCtx.Request = r
+		serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+			r.Context(),
+			ginCtx,
+			conn,
+			account,
+			"sk-test",
+			msgType,
+			firstMessage,
+			hooks,
+			OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+		)
+	}))
+	defer wsServer.Close()
+
+	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+	clientConn, _, err := coderws.Dial(
+		dialCtx,
+		"ws"+strings.TrimPrefix(wsServer.URL, "http"),
+		nil,
+	)
+	cancelDial()
+	require.NoError(t, err)
+	defer func() {
+		_ = clientConn.CloseNow()
+	}()
+
+	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+	cancelWrite()
+	require.NoError(t, err)
+
+	// The successful first turn is still relayed verbatim before the break.
+	readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
+	_, payload, err := clientConn.Read(readCtx)
+	cancelRead()
+	require.NoError(t, err)
+	require.JSONEq(t, `{"type":"response.completed","response":{"id":"resp_turn_ok","usage":{"input_tokens":2,"output_tokens":1}}}`, string(payload))
+
+	// The proxy surfaces the upstream read failure as its return error.
+	serverErr := <-serverErrCh
+	require.Error(t, serverErr)
+	require.Contains(t, serverErr.Error(), "upstream stream broken")
+
+	callsMu.Lock()
+	gotCalls := append([]afterTurnCall(nil), calls...)
+	callsMu.Unlock()
+	require.Len(t, gotCalls, 2)
+	require.Equal(t, 1, gotCalls[0].turn)
+	require.NoError(t, gotCalls[0].err)
+	require.NotNil(t, gotCalls[0].result)
+	require.Equal(t, "resp_turn_ok", gotCalls[0].result.RequestID)
+	require.Equal(t, 2, gotCalls[1].turn)
+	require.Error(t, gotCalls[1].err)
+	require.Nil(t, gotCalls[1].result)
+
+	// The error for the broken turn must NOT carry a partial result; a partial
+	// here would duplicate usage already billed for turn 1.
+	partial, ok := OpenAIWSIngressTurnPartialResult(gotCalls[1].err)
+	require.False(t, ok)
+	require.Nil(t, partial)
+}
+
+// TestProxyResponsesWebSocketV2Passthrough_UpstreamWithoutFrameConn verifies that the
+// passthrough proxy rejects an upstream connection type that cannot do raw frame
+// relay (openAIWSFakeConn) with a "does not support frame relay" error.
+func TestProxyResponsesWebSocketV2Passthrough_UpstreamWithoutFrameConn(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	cfg := buildIngressPolicyTestConfig()
+	svc := buildIngressPolicyTestService(cfg)
+	// openAIWSFakeConn deliberately lacks the frame-relay capability the
+	// passthrough path requires.
+	svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+		conn: &openAIWSFakeConn{},
+	}
+	account := buildIngressPolicyTestAccount(map[string]any{
+		"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+	})
+
+	// The proxy error is asserted inside the handler; no hooks are needed.
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := coderws.Accept(w, r, nil)
+		require.NoError(t, err)
+		defer func() {
+			_ = conn.CloseNow()
+		}()
+		readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+		msgType, firstMessage, err := conn.Read(readCtx)
+		cancelRead()
+		require.NoError(t, err)
+
+		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+		ginCtx.Request = r
+		err = svc.proxyResponsesWebSocketV2Passthrough(
+			r.Context(),
+			ginCtx,
+			conn,
+			account,
+			"sk-test",
+			msgType,
+			firstMessage,
+			nil,
+			OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+		)
+		require.Error(t, err)
+		require.Contains(t, err.Error(), "does not support frame relay")
+	}))
+	defer wsServer.Close()
+
+	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+	clientConn, _, err := coderws.Dial(
+		dialCtx,
+		"ws"+strings.TrimPrefix(wsServer.URL, "http"),
+		nil,
+	)
+	cancelDial()
+	require.NoError(t, err)
+	defer func() {
+		_ = clientConn.CloseNow()
+	}()
+
+	// Sending the first client frame is what triggers the server-side proxy call.
+	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+	cancelWrite()
+	require.NoError(t, err)
+}
+
+// TestProxyResponsesWebSocketV2Passthrough_BuildWSURLError verifies that a malformed
+// account base_url makes the proxy fail fast while building the upstream WS URL,
+// before any dialing or relaying happens (so a zero-value *coderws.Conn is safe here).
+func TestProxyResponsesWebSocketV2Passthrough_BuildWSURLError(t *testing.T) {
+	cfg := buildIngressPolicyTestConfig()
+	svc := buildIngressPolicyTestService(cfg)
+	account := buildIngressPolicyTestAccount(map[string]any{
+		"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+	})
+	// Unclosed IPv6 bracket makes URL parsing fail.
+	account.Credentials["base_url"] = "http://[::1"
+	err := svc.proxyResponsesWebSocketV2Passthrough(
+		context.Background(),
+		nil,
+		&coderws.Conn{},
+		account,
+		"sk-test",
+		coderws.MessageText,
+		[]byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`),
+		nil,
+		OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+	)
+	require.Error(t, err)
+	require.Contains(t, err.Error(), "build ws url")
+}
+
+// TestProxyResponsesWebSocketV2Passthrough_IdleTimeoutErrorPath verifies that when the
+// upstream never produces a frame, the relay trips the client-read idle timeout and
+// reports it as an OpenAIWSClientCloseError (policy violation) wrapping
+// openaiwsv2.ErrRelayIdleTimeout — explicitly NOT context.DeadlineExceeded.
+func TestProxyResponsesWebSocketV2Passthrough_IdleTimeoutErrorPath(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	cfg := buildIngressPolicyTestConfig()
+	// 1s idle timeout keeps the test fast while still exercising the timeout path.
+	cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 1
+	svc := buildIngressPolicyTestService(cfg)
+	// Blocking stub: never yields a frame, forcing the idle timeout to fire.
+	svc.openaiWSPassthroughDialer = &passthroughDialerStub{
+		conn: &passthroughBlockingUpstreamConn{},
+	}
+	account := buildIngressPolicyTestAccount(map[string]any{
+		"openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
+	})
+
+	// NOTE(review): unlike the multi-turn tests these captures are not guarded by
+	// a mutex; presumably the serverErrCh receive below happens-after the hook —
+	// confirm under -race.
+	var (
+		afterTurnCalled bool
+		afterTurnErr    error
+	)
+	hooks := &OpenAIWSIngressHooks{
+		AfterTurn: func(_ int, _ *OpenAIForwardResult, turnErr error) {
+			afterTurnCalled = true
+			afterTurnErr = turnErr
+		},
+	}
+
+	serverErrCh := make(chan error, 1)
+	wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		conn, err := coderws.Accept(w, r, nil)
+		require.NoError(t, err)
+		defer func() {
+			_ = conn.CloseNow()
+		}()
+
+		readCtx, cancelRead := context.WithTimeout(r.Context(), 3*time.Second)
+		msgType, firstMessage, err := conn.Read(readCtx)
+		cancelRead()
+		require.NoError(t, err)
+
+		ginCtx, _ := gin.CreateTestContext(httptest.NewRecorder())
+		ginCtx.Request = r
+		serverErrCh <- svc.proxyResponsesWebSocketV2Passthrough(
+			r.Context(),
+			ginCtx,
+			conn,
+			account,
+			"sk-test",
+			msgType,
+			firstMessage,
+			hooks,
+			OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+		)
+	}))
+	defer wsServer.Close()
+
+	dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+	clientConn, _, err := coderws.Dial(
+		dialCtx,
+		"ws"+strings.TrimPrefix(wsServer.URL, "http"),
+		nil,
+	)
+	cancelDial()
+	require.NoError(t, err)
+	defer func() {
+		_ = clientConn.CloseNow()
+	}()
+
+	writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+	err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.3-codex","input":[]}`))
+	cancelWrite()
+	require.NoError(t, err)
+
+	// The proxy must return the idle-timeout error, and the hook must see the
+	// same condition typed as a client close error with the expected close code.
+	serverErr := <-serverErrCh
+	require.Error(t, serverErr)
+	require.Contains(t, serverErr.Error(), "idle timeout")
+	require.True(t, afterTurnCalled)
+	require.Error(t, afterTurnErr)
+	var closeErr *OpenAIWSClientCloseError
+	require.ErrorAs(t, afterTurnErr, &closeErr)
+	require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+	require.Contains(t, closeErr.Reason(), "idle timeout")
+	require.ErrorIs(t, afterTurnErr, openaiwsv2.ErrRelayIdleTimeout)
+	require.NotErrorIs(t, afterTurnErr, context.DeadlineExceeded)
+}
+
+// TestMapOpenAIWSPassthroughDialError_StatusMapping pins the mapping from upstream
+// dial failures (HTTP status / context errors) to client-facing WebSocket close
+// semantics: 4xx → CloseError (429 retryable, others policy violation), 5xx → plain
+// wrapped error, DeadlineExceeded → retryable CloseError, Canceled → passed through.
+func TestMapOpenAIWSPassthroughDialError_StatusMapping(t *testing.T) {
+	t.Parallel()
+
+	// The mapper is stateless, so a zero-value service is sufficient.
+	svc := &OpenAIGatewayService{}
+
+	t.Run("nil error returns nil", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(nil, 0, nil)
+		require.NoError(t, err)
+	})
+
+	t.Run("401 maps to policy violation", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(errors.New("unauthorized"), 401, nil)
+		var closeErr *OpenAIWSClientCloseError
+		require.ErrorAs(t, err, &closeErr)
+		require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+	})
+
+	t.Run("403 maps to policy violation", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(errors.New("forbidden"), 403, nil)
+		var closeErr *OpenAIWSClientCloseError
+		require.ErrorAs(t, err, &closeErr)
+		require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+	})
+
+	t.Run("429 maps to try again later", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(errors.New("rate limited"), 429, nil)
+		var closeErr *OpenAIWSClientCloseError
+		require.ErrorAs(t, err, &closeErr)
+		require.Equal(t, coderws.StatusTryAgainLater, closeErr.StatusCode())
+	})
+
+	t.Run("4xx generic maps to policy violation", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(errors.New("bad request"), 400, nil)
+		var closeErr *OpenAIWSClientCloseError
+		require.ErrorAs(t, err, &closeErr)
+		require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+	})
+
+	// Server-side failures stay plain errors so retry/failover layers above can
+	// treat them differently from deliberate client closes.
+	t.Run("5xx wraps as generic error without CloseError", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(errors.New("server error"), 500, nil)
+		require.Error(t, err)
+		var closeErr *OpenAIWSClientCloseError
+		require.False(t, errors.As(err, &closeErr), "5xx should not be wrapped as a CloseError")
+		require.Contains(t, err.Error(), "openai ws passthrough dial")
+	})
+
+	t.Run("deadline exceeded maps to try again later", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(context.DeadlineExceeded, 0, nil)
+		var closeErr *OpenAIWSClientCloseError
+		require.ErrorAs(t, err, &closeErr)
+		require.Equal(t, coderws.StatusTryAgainLater, closeErr.StatusCode())
+	})
+
+	// Cancellation means the caller went away; it is surfaced untouched.
+	t.Run("context canceled returns original error", func(t *testing.T) {
+		err := svc.mapOpenAIWSPassthroughDialError(context.Canceled, 0, nil)
+		require.ErrorIs(t, err, context.Canceled)
+		var closeErr *OpenAIWSClientCloseError
+		require.False(t, errors.As(err, &closeErr), "context.Canceled should not be wrapped as a CloseError")
+	})
+}
diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go
index b22da7522..225477d52 100644
--- a/backend/internal/service/redeem_service.go
+++ b/backend/internal/service/redeem_service.go
@@ -15,11 +15,12 @@ import (
)
var (
- ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
- ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
- ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
- ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
- ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
+ ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
+ ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
+ ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
+ ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
+ ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
+ ErrBalanceCacheNotFound = errors.New("balance cache key not found")
)
const (
diff --git a/backend/internal/service/setting_bulk_edit_template.go b/backend/internal/service/setting_bulk_edit_template.go
new file mode 100644
index 000000000..dd28e1633
--- /dev/null
+++ b/backend/internal/service/setting_bulk_edit_template.go
@@ -0,0 +1,770 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+// Share scopes controlling which requesters a bulk edit template is visible to
+// (enforced by isBulkEditTemplateVisible).
+const (
+	BulkEditTemplateShareScopePrivate = "private" // visible to the creator only
+	BulkEditTemplateShareScopeTeam    = "team"    // presumably visible team-wide — confirm in isBulkEditTemplateVisible
+	BulkEditTemplateShareScopeGroups  = "groups"  // visible to members of the template's GroupIDs
+)
+
+var (
+	// ErrBulkEditTemplateNotFound is returned when no template matches the given id.
+	ErrBulkEditTemplateNotFound = infraerrors.NotFound("BULK_EDIT_TEMPLATE_NOT_FOUND", "bulk edit template not found")
+	// ErrBulkEditTemplateVersionNotFound is returned when the template exists but the
+	// requested history version does not.
+	ErrBulkEditTemplateVersionNotFound = infraerrors.NotFound(
+		"BULK_EDIT_TEMPLATE_VERSION_NOT_FOUND",
+		"bulk edit template version not found",
+	)
+	// ErrBulkEditTemplateForbidden is returned when the requester lacks permission to
+	// view or modify the template.
+	ErrBulkEditTemplateForbidden = infraerrors.Forbidden(
+		"BULK_EDIT_TEMPLATE_FORBIDDEN",
+		"no permission to modify this bulk edit template",
+	)
+	// bulkEditTemplateRandRead indirects crypto/rand.Read — presumably so tests can
+	// stub id generation; confirm against generateBulkEditTemplateID.
+	bulkEditTemplateRandRead = rand.Read
+)
+
+// BulkEditTemplate is the API-facing representation of a saved bulk edit template,
+// including its sharing configuration and the opaque editor state.
+type BulkEditTemplate struct {
+	ID            string         `json:"id"`
+	Name          string         `json:"name"`
+	ScopePlatform string         `json:"scope_platform"` // stored lowercase
+	ScopeType     string         `json:"scope_type"`     // stored lowercase
+	ShareScope    string         `json:"share_scope"`    // one of the BulkEditTemplateShareScope* constants
+	GroupIDs      []int64        `json:"group_ids"`      // only meaningful when ShareScope == "groups"
+	State         map[string]any `json:"state"`          // opaque editor payload
+	CreatedBy     int64          `json:"created_by"`
+	UpdatedBy     int64          `json:"updated_by"`
+	CreatedAt     int64          `json:"created_at"` // unix milliseconds
+	UpdatedAt     int64          `json:"updated_at"` // unix milliseconds
+}
+
+// BulkEditTemplateQuery filters ListBulkEditTemplates and carries the requester
+// identity used for visibility checks.
+type BulkEditTemplateQuery struct {
+	ScopePlatform   string  // optional; empty matches any platform
+	ScopeType       string  // optional; empty matches any type
+	ScopeGroupIDs   []int64 // groups the requester belongs to
+	RequesterUserID int64
+}
+
+// BulkEditTemplateVersion is one entry of a template's rollback history.
+type BulkEditTemplateVersion struct {
+	VersionID  string         `json:"version_id"`
+	ShareScope string         `json:"share_scope"`
+	GroupIDs   []int64        `json:"group_ids"`
+	State      map[string]any `json:"state"`
+	UpdatedBy  int64          `json:"updated_by"`
+	UpdatedAt  int64          `json:"updated_at"` // unix milliseconds
+}
+
+// BulkEditTemplateVersionQuery identifies a template whose version history is
+// being listed, plus the requester identity for the visibility check.
+type BulkEditTemplateVersionQuery struct {
+	TemplateID      string
+	ScopeGroupIDs   []int64
+	RequesterUserID int64
+}
+
+// BulkEditTemplateUpsertInput is the payload for creating or updating a template.
+// An empty ID means create (or update an existing template matched by name).
+type BulkEditTemplateUpsertInput struct {
+	ID              string
+	Name            string
+	ScopePlatform   string
+	ScopeType       string
+	ShareScope      string
+	GroupIDs        []int64
+	State           map[string]any
+	RequesterUserID int64
+}
+
+// BulkEditTemplateRollbackInput identifies the template and history version to
+// restore, plus the requester identity for the visibility check.
+type BulkEditTemplateRollbackInput struct {
+	TemplateID      string
+	VersionID       string
+	ScopeGroupIDs   []int64
+	RequesterUserID int64
+}
+
+// bulkEditTemplateLibraryStore is the persisted JSON document (stored under
+// SettingKeyBulkEditTemplateLibrary) holding every bulk edit template.
+type bulkEditTemplateLibraryStore struct {
+	Items []bulkEditTemplateStoreItem `json:"items"`
+}
+
+// bulkEditTemplateVersionStoreItem is one history snapshot of a template, captured
+// before each mutation so the template can later be rolled back to it.
+type bulkEditTemplateVersionStoreItem struct {
+	VersionID  string          `json:"version_id"`
+	ShareScope string          `json:"share_scope"`
+	GroupIDs   []int64         `json:"group_ids"`
+	State      json.RawMessage `json:"state"` // kept raw; decoded only when exposed via the API types
+	UpdatedBy  int64           `json:"updated_by"`
+	UpdatedAt  int64           `json:"updated_at"` // unix milliseconds
+}
+
+// bulkEditTemplateStoreItem is the persisted form of a single template together
+// with its version history.
+type bulkEditTemplateStoreItem struct {
+	ID            string                             `json:"id"`
+	Name          string                             `json:"name"`
+	ScopePlatform string                             `json:"scope_platform"`
+	ScopeType     string                             `json:"scope_type"`
+	ShareScope    string                             `json:"share_scope"`
+	GroupIDs      []int64                            `json:"group_ids"`
+	State         json.RawMessage                    `json:"state"`
+	Versions      []bulkEditTemplateVersionStoreItem `json:"versions"`
+	CreatedBy     int64                              `json:"created_by"`
+	UpdatedBy     int64                              `json:"updated_by"`
+	CreatedAt     int64                              `json:"created_at"` // unix milliseconds
+	UpdatedAt     int64                              `json:"updated_at"` // unix milliseconds
+}
+
+// ListBulkEditTemplates returns every template visible to the requester that matches
+// the (optional) platform/type filters, ordered most-recently-updated first.
+func (s *SettingService) ListBulkEditTemplates(ctx context.Context, query BulkEditTemplateQuery) ([]BulkEditTemplate, error) {
+	store, err := s.loadBulkEditTemplateLibrary(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// Normalize filters exactly like stored items are normalized (trimmed, lowercase),
+	// and build a set of the requester's groups for the visibility check.
+	scopePlatform := strings.TrimSpace(strings.ToLower(query.ScopePlatform))
+	scopeType := strings.TrimSpace(strings.ToLower(query.ScopeType))
+	scopeGroupIDs := normalizeBulkEditTemplateGroupIDs(query.ScopeGroupIDs)
+	scopeGroupSet := make(map[int64]struct{}, len(scopeGroupIDs))
+	for _, groupID := range scopeGroupIDs {
+		scopeGroupSet[groupID] = struct{}{}
+	}
+
+	out := make([]BulkEditTemplate, 0, len(store.Items))
+	for idx := range store.Items {
+		item := store.Items[idx]
+		if scopePlatform != "" && item.ScopePlatform != scopePlatform {
+			continue
+		}
+		if scopeType != "" && item.ScopeType != scopeType {
+			continue
+		}
+		if !isBulkEditTemplateVisible(item, query.RequesterUserID, scopeGroupSet) {
+			continue
+		}
+		out = append(out, toBulkEditTemplate(item))
+	}
+
+	// Newest first; id is the deterministic tie-breaker on equal timestamps.
+	sort.Slice(out, func(i, j int) bool {
+		if out[i].UpdatedAt == out[j].UpdatedAt {
+			return out[i].ID < out[j].ID
+		}
+		return out[i].UpdatedAt > out[j].UpdatedAt
+	})
+
+	return out, nil
+}
+
+// UpsertBulkEditTemplate creates or updates a template. Resolution order:
+//  1. explicit non-empty input.ID must match an existing template (else not found);
+//  2. otherwise an existing template in the same scope with the same name
+//     (case-insensitive) that the requester may modify is updated in place;
+//  3. otherwise a new template is created.
+// Updates snapshot the previous content into the version history first.
+func (s *SettingService) UpsertBulkEditTemplate(ctx context.Context, input BulkEditTemplateUpsertInput) (*BulkEditTemplate, error) {
+	name := strings.TrimSpace(input.Name)
+	if name == "" {
+		return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template name is required")
+	}
+	if input.RequesterUserID <= 0 {
+		return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+	}
+
+	scopePlatform := strings.TrimSpace(strings.ToLower(input.ScopePlatform))
+	scopeType := strings.TrimSpace(strings.ToLower(input.ScopeType))
+	if scopePlatform == "" || scopeType == "" {
+		return nil, infraerrors.BadRequest(
+			"BULK_EDIT_TEMPLATE_INVALID_INPUT",
+			"scope_platform and scope_type are required",
+		)
+	}
+
+	shareScope, shareScopeErr := validateBulkEditTemplateShareScope(input.ShareScope)
+	if shareScopeErr != nil {
+		return nil, shareScopeErr
+	}
+
+	// A groups-scoped template without any groups would be visible to no one.
+	groupIDs := normalizeBulkEditTemplateGroupIDs(input.GroupIDs)
+	if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 {
+		return nil, infraerrors.BadRequest(
+			"BULK_EDIT_TEMPLATE_INVALID_INPUT",
+			"group_ids is required when share_scope=groups",
+		)
+	}
+
+	// Serialize the opaque editor state; a nil/absent state is stored as "{}".
+	stateRaw, err := json.Marshal(input.State)
+	if err != nil {
+		return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid template state")
+	}
+	if len(stateRaw) == 0 || string(stateRaw) == "null" {
+		stateRaw = json.RawMessage("{}")
+	}
+
+	store, err := s.loadBulkEditTemplateLibrary(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	// Resolution step 1: explicit id must exist.
+	templateID := strings.TrimSpace(input.ID)
+	matchIndex := -1
+	if templateID != "" {
+		for idx := range store.Items {
+			if store.Items[idx].ID == templateID {
+				matchIndex = idx
+				break
+			}
+		}
+		if matchIndex < 0 {
+			return nil, ErrBulkEditTemplateNotFound
+		}
+	}
+
+	// Resolution step 2: no id given — update a same-scope, same-name template the
+	// requester may modify (so re-saving under an existing name does not duplicate).
+	if matchIndex < 0 && templateID == "" {
+		for idx := range store.Items {
+			item := store.Items[idx]
+			if item.ScopePlatform != scopePlatform || item.ScopeType != scopeType {
+				continue
+			}
+			if !strings.EqualFold(strings.TrimSpace(item.Name), name) {
+				continue
+			}
+			if !canModifyBulkEditTemplate(item, input.RequesterUserID) {
+				continue
+			}
+			matchIndex = idx
+			break
+		}
+	}
+
+	nowMS := time.Now().UnixMilli()
+	if matchIndex >= 0 {
+		item := store.Items[matchIndex]
+		if !canModifyBulkEditTemplate(item, input.RequesterUserID) {
+			return nil, ErrBulkEditTemplateForbidden
+		}
+
+		// Snapshot the pre-update content so the change can be rolled back.
+		previousVersion := snapshotBulkEditTemplateVersion(item)
+		item.Versions = append(item.Versions, previousVersion)
+		item.Name = name
+		item.ScopePlatform = scopePlatform
+		item.ScopeType = scopeType
+		item.ShareScope = shareScope
+		item.GroupIDs = groupIDs
+		item.State = cloneBulkEditTemplateStateRaw(stateRaw)
+		// Backfill creator/creation-time for legacy records missing them.
+		if item.CreatedBy <= 0 {
+			item.CreatedBy = input.RequesterUserID
+		}
+		if item.CreatedAt <= 0 {
+			item.CreatedAt = nowMS
+		}
+		item.UpdatedBy = input.RequesterUserID
+		item.UpdatedAt = nowMS
+		store.Items[matchIndex] = item
+
+		if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil {
+			return nil, err
+		}
+		output := toBulkEditTemplate(item)
+		return &output, nil
+	}
+
+	// Resolution step 3: create a fresh template.
+	if templateID == "" {
+		templateID = generateBulkEditTemplateID()
+	}
+
+	created := bulkEditTemplateStoreItem{
+		ID:            templateID,
+		Name:          name,
+		ScopePlatform: scopePlatform,
+		ScopeType:     scopeType,
+		ShareScope:    shareScope,
+		GroupIDs:      groupIDs,
+		State:         cloneBulkEditTemplateStateRaw(stateRaw),
+		Versions:      []bulkEditTemplateVersionStoreItem{},
+		CreatedBy:     input.RequesterUserID,
+		UpdatedBy:     input.RequesterUserID,
+		CreatedAt:     nowMS,
+		UpdatedAt:     nowMS,
+	}
+	store.Items = append(store.Items, created)
+
+	if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil {
+		return nil, err
+	}
+
+	output := toBulkEditTemplate(created)
+	return &output, nil
+}
+
+// DeleteBulkEditTemplate removes a template by id and persists the library.
+// NOTE(review): only private templates are restricted to their creator; team- and
+// groups-scoped templates are deletable by any authenticated requester — confirm
+// this asymmetry with canModifyBulkEditTemplate is intended.
+func (s *SettingService) DeleteBulkEditTemplate(ctx context.Context, templateID string, requesterUserID int64) error {
+	id := strings.TrimSpace(templateID)
+	if id == "" {
+		return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required")
+	}
+	if requesterUserID <= 0 {
+		return infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+	}
+
+	store, err := s.loadBulkEditTemplateLibrary(ctx)
+	if err != nil {
+		return err
+	}
+
+	idx := -1
+	for index := range store.Items {
+		if store.Items[index].ID == id {
+			idx = index
+			break
+		}
+	}
+	if idx < 0 {
+		return ErrBulkEditTemplateNotFound
+	}
+
+	// Private templates may only be deleted by their creator (legacy records with
+	// CreatedBy <= 0 are exempt from the check).
+	target := store.Items[idx]
+	if target.ShareScope == BulkEditTemplateShareScopePrivate && target.CreatedBy > 0 && target.CreatedBy != requesterUserID {
+		return ErrBulkEditTemplateForbidden
+	}
+
+	store.Items = append(store.Items[:idx], store.Items[idx+1:]...)
+	return s.persistBulkEditTemplateLibrary(ctx, store)
+}
+
+// ListBulkEditTemplateVersions returns the rollback history of one template,
+// newest first, after checking that the requester may see the template at all.
+func (s *SettingService) ListBulkEditTemplateVersions(
+	ctx context.Context,
+	query BulkEditTemplateVersionQuery,
+) ([]BulkEditTemplateVersion, error) {
+	templateID := strings.TrimSpace(query.TemplateID)
+	if templateID == "" {
+		return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required")
+	}
+	if query.RequesterUserID <= 0 {
+		return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+	}
+
+	store, err := s.loadBulkEditTemplateLibrary(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	scopeGroupSet := toBulkEditTemplateScopeGroupSet(query.ScopeGroupIDs)
+	target := findBulkEditTemplateStoreItemByID(store.Items, templateID)
+	if target == nil {
+		return nil, ErrBulkEditTemplateNotFound
+	}
+	// History is as sensitive as the template itself: invisible templates yield
+	// forbidden rather than leaking their versions.
+	if !isBulkEditTemplateVisible(*target, query.RequesterUserID, scopeGroupSet) {
+		return nil, ErrBulkEditTemplateForbidden
+	}
+
+	versions := make([]BulkEditTemplateVersion, 0, len(target.Versions))
+	for idx := range target.Versions {
+		versions = append(versions, toBulkEditTemplateVersion(target.Versions[idx]))
+	}
+
+	// Newest first; version id is the deterministic tie-breaker.
+	sort.Slice(versions, func(i, j int) bool {
+		if versions[i].UpdatedAt == versions[j].UpdatedAt {
+			return versions[i].VersionID < versions[j].VersionID
+		}
+		return versions[i].UpdatedAt > versions[j].UpdatedAt
+	})
+
+	return versions, nil
+}
+
+// RollbackBulkEditTemplate restores a template's share scope, groups and state from
+// one of its history versions. The pre-rollback content is itself snapshotted into
+// the history first, so a rollback can be undone. Name and scope are not restored.
+func (s *SettingService) RollbackBulkEditTemplate(
+	ctx context.Context,
+	input BulkEditTemplateRollbackInput,
+) (*BulkEditTemplate, error) {
+	templateID := strings.TrimSpace(input.TemplateID)
+	versionID := strings.TrimSpace(input.VersionID)
+	if templateID == "" || versionID == "" {
+		return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template_id and version_id are required")
+	}
+	if input.RequesterUserID <= 0 {
+		return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+	}
+
+	store, err := s.loadBulkEditTemplateLibrary(ctx)
+	if err != nil {
+		return nil, err
+	}
+
+	scopeGroupSet := toBulkEditTemplateScopeGroupSet(input.ScopeGroupIDs)
+	templateIndex := findBulkEditTemplateStoreItemIndexByID(store.Items, templateID)
+	if templateIndex < 0 {
+		return nil, ErrBulkEditTemplateNotFound
+	}
+
+	// NOTE(review): rollback is gated on visibility, not on modify permission —
+	// confirm whether canModifyBulkEditTemplate should apply here as in Upsert.
+	item := store.Items[templateIndex]
+	if !isBulkEditTemplateVisible(item, input.RequesterUserID, scopeGroupSet) {
+		return nil, ErrBulkEditTemplateForbidden
+	}
+
+	versionIndex := findBulkEditTemplateVersionIndexByID(item.Versions, versionID)
+	if versionIndex < 0 {
+		return nil, ErrBulkEditTemplateVersionNotFound
+	}
+
+	// Snapshot current content before overwriting so the rollback is reversible.
+	targetVersion := item.Versions[versionIndex]
+	previousVersion := snapshotBulkEditTemplateVersion(item)
+	item.Versions = append(item.Versions, previousVersion)
+	item.ShareScope = targetVersion.ShareScope
+	item.GroupIDs = append([]int64(nil), targetVersion.GroupIDs...)
+	item.State = cloneBulkEditTemplateStateRaw(targetVersion.State)
+	item.UpdatedBy = input.RequesterUserID
+	item.UpdatedAt = time.Now().UnixMilli()
+
+	store.Items[templateIndex] = item
+	if persistErr := s.persistBulkEditTemplateLibrary(ctx, store); persistErr != nil {
+		return nil, persistErr
+	}
+
+	output := toBulkEditTemplate(item)
+	return &output, nil
+}
+
+// loadBulkEditTemplateLibrary reads and normalizes the persisted template library.
+// A missing setting key or an empty value yields an empty library rather than an
+// error, so first use needs no migration step.
+func (s *SettingService) loadBulkEditTemplateLibrary(ctx context.Context) (*bulkEditTemplateLibraryStore, error) {
+	raw, err := s.settingRepo.GetValue(ctx, SettingKeyBulkEditTemplateLibrary)
+	if err != nil {
+		if errors.Is(err, ErrSettingNotFound) {
+			return &bulkEditTemplateLibraryStore{}, nil
+		}
+		return nil, fmt.Errorf("get bulk edit template library: %w", err)
+	}
+
+	raw = strings.TrimSpace(raw)
+	if raw == "" {
+		return &bulkEditTemplateLibraryStore{}, nil
+	}
+
+	store := bulkEditTemplateLibraryStore{}
+	if err := json.Unmarshal([]byte(raw), &store); err != nil {
+		return nil, fmt.Errorf("parse bulk edit template library: %w", err)
+	}
+
+	// Normalize on load so the rest of the service can rely on clean invariants
+	// (lowercase scopes, deduplicated ids, non-nil state, valid timestamps).
+	normalized := normalizeBulkEditTemplateLibraryStore(store)
+	return &normalized, nil
+}
+
+// persistBulkEditTemplateLibrary normalizes the library and writes it back as JSON
+// under the settings key. Normalizing on write keeps the stored document canonical
+// even if a caller mutated items without going through load.
+func (s *SettingService) persistBulkEditTemplateLibrary(ctx context.Context, store *bulkEditTemplateLibraryStore) error {
+	if store == nil {
+		return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template library cannot be nil")
+	}
+
+	normalized := normalizeBulkEditTemplateLibraryStore(*store)
+	data, err := json.Marshal(normalized)
+	if err != nil {
+		return fmt.Errorf("marshal bulk edit template library: %w", err)
+	}
+
+	return s.settingRepo.Set(ctx, SettingKeyBulkEditTemplateLibrary, string(data))
+}
+
+// validateBulkEditTemplateShareScope canonicalizes a share-scope string to lowercase
+// with surrounding whitespace removed. A blank value defaults to the private scope;
+// anything other than the three known scopes is rejected as invalid input.
+func validateBulkEditTemplateShareScope(scope string) (string, error) {
+	cleaned := strings.TrimSpace(strings.ToLower(scope))
+	switch cleaned {
+	case "":
+		// Callers that omit the field get the most restrictive scope.
+		return BulkEditTemplateShareScopePrivate, nil
+	case BulkEditTemplateShareScopePrivate, BulkEditTemplateShareScopeTeam, BulkEditTemplateShareScopeGroups:
+		return cleaned, nil
+	}
+	return "", infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid share_scope")
+}
+
+// normalizeBulkEditTemplateLibraryStore sanitizes a raw (possibly legacy or
+// hand-edited) template library: items missing a name or scope coordinates
+// are dropped, duplicate template IDs are collapsed (first occurrence wins),
+// missing IDs and timestamps are backfilled, and a "groups" share scope with
+// no group IDs is downgraded to private. The input store is not mutated; a
+// non-nil Items slice is always returned.
+func normalizeBulkEditTemplateLibraryStore(store bulkEditTemplateLibraryStore) bulkEditTemplateLibraryStore {
+	if len(store.Items) == 0 {
+		return bulkEditTemplateLibraryStore{Items: []bulkEditTemplateStoreItem{}}
+	}
+
+	nowMS := time.Now().UnixMilli()
+	items := make([]bulkEditTemplateStoreItem, 0, len(store.Items))
+	seenID := make(map[string]struct{}, len(store.Items))
+
+	for _, raw := range store.Items {
+		name := strings.TrimSpace(raw.Name)
+		scopePlatform := strings.TrimSpace(strings.ToLower(raw.ScopePlatform))
+		scopeType := strings.TrimSpace(strings.ToLower(raw.ScopeType))
+		// An item without a name or scope coordinates is unusable; drop it.
+		if name == "" || scopePlatform == "" || scopeType == "" {
+			continue
+		}
+
+		shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope)
+		groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs)
+		if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 {
+			// "groups" sharing with no groups is meaningless; fall back to private.
+			shareScope = BulkEditTemplateShareScopePrivate
+		}
+
+		templateID := strings.TrimSpace(raw.ID)
+		if templateID == "" {
+			templateID = generateBulkEditTemplateID()
+		}
+		if _, exists := seenID[templateID]; exists {
+			continue
+		}
+		seenID[templateID] = struct{}{}
+
+		createdAt := raw.CreatedAt
+		if createdAt <= 0 {
+			createdAt = nowMS
+		}
+		updatedAt := raw.UpdatedAt
+		if updatedAt <= 0 {
+			updatedAt = createdAt
+		}
+
+		items = append(items, bulkEditTemplateStoreItem{
+			ID:            templateID,
+			Name:          name,
+			ScopePlatform: scopePlatform,
+			ScopeType:     scopeType,
+			ShareScope:    shareScope,
+			GroupIDs:      groupIDs,
+			// cloneBulkEditTemplateStateRaw already defaults nil/"null" state
+			// to "{}", so no separate pre-check of raw.State is needed here.
+			State:     cloneBulkEditTemplateStateRaw(raw.State),
+			Versions:  normalizeBulkEditTemplateVersionStoreItems(raw.Versions),
+			CreatedBy: raw.CreatedBy,
+			UpdatedBy: raw.UpdatedBy,
+			CreatedAt: createdAt,
+			UpdatedAt: updatedAt,
+		})
+	}
+
+	return bulkEditTemplateLibraryStore{Items: items}
+}
+
+// toBulkEditTemplate converts a persisted store item into the public
+// BulkEditTemplate DTO. The raw JSON state is decoded into a generic map;
+// state that fails to decode (or decodes to JSON null) degrades to an empty
+// map instead of surfacing an error. GroupIDs are copied so the DTO never
+// aliases the store's backing slice.
+func toBulkEditTemplate(item bulkEditTemplateStoreItem) BulkEditTemplate {
+	decoded := map[string]any{}
+	if err := json.Unmarshal(item.State, &decoded); err != nil || decoded == nil {
+		decoded = map[string]any{}
+	}
+
+	out := BulkEditTemplate{
+		ID:            item.ID,
+		Name:          item.Name,
+		ScopePlatform: item.ScopePlatform,
+		ScopeType:     item.ScopeType,
+		ShareScope:    item.ShareScope,
+		GroupIDs:      append([]int64(nil), item.GroupIDs...),
+		State:         decoded,
+		CreatedBy:     item.CreatedBy,
+		UpdatedBy:     item.UpdatedBy,
+		CreatedAt:     item.CreatedAt,
+		UpdatedAt:     item.UpdatedAt,
+	}
+	return out
+}
+
+// toBulkEditTemplateVersion converts a stored version entry into its public
+// DTO. Undecodable or null state degrades to an empty map, and GroupIDs are
+// copied so the DTO does not alias the store's slice.
+func toBulkEditTemplateVersion(item bulkEditTemplateVersionStoreItem) BulkEditTemplateVersion {
+	decoded := map[string]any{}
+	if err := json.Unmarshal(item.State, &decoded); err != nil || decoded == nil {
+		decoded = map[string]any{}
+	}
+
+	out := BulkEditTemplateVersion{
+		VersionID:  item.VersionID,
+		ShareScope: item.ShareScope,
+		GroupIDs:   append([]int64(nil), item.GroupIDs...),
+		State:      decoded,
+		UpdatedBy:  item.UpdatedBy,
+		UpdatedAt:  item.UpdatedAt,
+	}
+	return out
+}
+
+// normalizeBulkEditTemplateVersionStoreItems sanitizes a template's stored
+// version history: missing version IDs are backfilled, duplicates are
+// collapsed (first occurrence wins), share scope and group IDs are
+// normalized, missing timestamps default to now, and the result is sorted
+// newest-first (ties broken by version ID for deterministic output). The
+// returned slice is never nil.
+func normalizeBulkEditTemplateVersionStoreItems(
+	rawVersions []bulkEditTemplateVersionStoreItem,
+) []bulkEditTemplateVersionStoreItem {
+	if len(rawVersions) == 0 {
+		return []bulkEditTemplateVersionStoreItem{}
+	}
+
+	nowMS := time.Now().UnixMilli()
+	seen := make(map[string]struct{}, len(rawVersions))
+	out := make([]bulkEditTemplateVersionStoreItem, 0, len(rawVersions))
+	for _, raw := range rawVersions {
+		versionID := strings.TrimSpace(raw.VersionID)
+		if versionID == "" {
+			versionID = generateBulkEditTemplateVersionID()
+		}
+		if _, exists := seen[versionID]; exists {
+			continue
+		}
+		seen[versionID] = struct{}{}
+
+		shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope)
+		groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs)
+		// A "groups" scope without any group is meaningless; downgrade to private.
+		if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 {
+			shareScope = BulkEditTemplateShareScopePrivate
+		}
+
+		updatedAt := raw.UpdatedAt
+		if updatedAt <= 0 {
+			updatedAt = nowMS
+		}
+
+		out = append(out, bulkEditTemplateVersionStoreItem{
+			VersionID:  versionID,
+			ShareScope: shareScope,
+			GroupIDs:   groupIDs,
+			State:      cloneBulkEditTemplateStateRaw(raw.State),
+			UpdatedBy:  raw.UpdatedBy,
+			UpdatedAt:  updatedAt,
+		})
+	}
+
+	// Newest first; tie-break on VersionID so ordering is stable across runs.
+	sort.Slice(out, func(i, j int) bool {
+		if out[i].UpdatedAt == out[j].UpdatedAt {
+			return out[i].VersionID < out[j].VersionID
+		}
+		return out[i].UpdatedAt > out[j].UpdatedAt
+	})
+	return out
+}
+
+// snapshotBulkEditTemplateVersion captures the template's current sharing
+// settings and state as a new version entry under a fresh version ID. A
+// non-positive UpdatedAt on the item is replaced by the current time.
+func snapshotBulkEditTemplateVersion(item bulkEditTemplateStoreItem) bulkEditTemplateVersionStoreItem {
+	snapshotAt := item.UpdatedAt
+	if snapshotAt <= 0 {
+		snapshotAt = time.Now().UnixMilli()
+	}
+	return bulkEditTemplateVersionStoreItem{
+		VersionID:  generateBulkEditTemplateVersionID(),
+		ShareScope: normalizeBulkEditTemplateShareScopeOrDefault(item.ShareScope),
+		GroupIDs:   normalizeBulkEditTemplateGroupIDs(item.GroupIDs),
+		State:      cloneBulkEditTemplateStateRaw(item.State),
+		UpdatedBy:  item.UpdatedBy,
+		UpdatedAt:  snapshotAt,
+	}
+}
+
+// cloneBulkEditTemplateStateRaw returns an independent copy of a raw state
+// payload. Nil, empty, or literal-null input is replaced by an empty JSON
+// object so downstream decoding always sees a valid document.
+func cloneBulkEditTemplateStateRaw(raw json.RawMessage) json.RawMessage {
+	if len(raw) == 0 || string(raw) == "null" {
+		return json.RawMessage("{}")
+	}
+	return append(json.RawMessage(nil), raw...)
+}
+
+// toBulkEditTemplateScopeGroupSet converts a caller-supplied group ID list
+// into a set for O(1) membership checks, after normalization (non-positive
+// IDs dropped, duplicates removed).
+func toBulkEditTemplateScopeGroupSet(raw []int64) map[int64]struct{} {
+	normalized := normalizeBulkEditTemplateGroupIDs(raw)
+	set := make(map[int64]struct{}, len(normalized))
+	for _, id := range normalized {
+		set[id] = struct{}{}
+	}
+	return set
+}
+
+// findBulkEditTemplateStoreItemByID returns a pointer to the item with the
+// given ID (so callers can mutate it in place), or nil when absent. It
+// delegates to findBulkEditTemplateStoreItemIndexByID instead of duplicating
+// the linear scan.
+func findBulkEditTemplateStoreItemByID(
+	items []bulkEditTemplateStoreItem,
+	templateID string,
+) *bulkEditTemplateStoreItem {
+	if idx := findBulkEditTemplateStoreItemIndexByID(items, templateID); idx >= 0 {
+		return &items[idx]
+	}
+	return nil
+}
+
+// findBulkEditTemplateStoreItemIndexByID returns the position of the template
+// with the given ID, or -1 when no item matches.
+func findBulkEditTemplateStoreItemIndexByID(items []bulkEditTemplateStoreItem, templateID string) int {
+	for i := 0; i < len(items); i++ {
+		if items[i].ID == templateID {
+			return i
+		}
+	}
+	return -1
+}
+
+// findBulkEditTemplateVersionIndexByID returns the position of the version
+// with the given ID, or -1 when no version matches.
+func findBulkEditTemplateVersionIndexByID(
+	versions []bulkEditTemplateVersionStoreItem,
+	versionID string,
+) int {
+	for i := 0; i < len(versions); i++ {
+		if versions[i].VersionID == versionID {
+			return i
+		}
+	}
+	return -1
+}
+
+// canModifyBulkEditTemplate reports whether the requester may edit or delete
+// the template. Shared (team/groups) templates are modifiable by any
+// authenticated user; private templates only by their creator. A template
+// with no recorded creator (CreatedBy <= 0, e.g. legacy data) is treated as
+// modifiable by anyone.
+func canModifyBulkEditTemplate(item bulkEditTemplateStoreItem, requesterUserID int64) bool {
+	switch {
+	case requesterUserID <= 0:
+		return false
+	case item.ShareScope != BulkEditTemplateShareScopePrivate:
+		return true
+	case item.CreatedBy <= 0:
+		return true
+	default:
+		return item.CreatedBy == requesterUserID
+	}
+}
+
+// isBulkEditTemplateVisible reports whether the requester can see the
+// template: team-scoped templates are visible to everyone; groups-scoped
+// templates require at least one of the template's groups to appear in the
+// requester's scope-group set; any other scope (private) requires the
+// requester to be the authenticated creator.
+func isBulkEditTemplateVisible(
+	item bulkEditTemplateStoreItem,
+	requesterUserID int64,
+	scopeGroupSet map[int64]struct{},
+) bool {
+	switch item.ShareScope {
+	case BulkEditTemplateShareScopeTeam:
+		return true
+	case BulkEditTemplateShareScopeGroups:
+		// Lookups in a nil/empty set simply never match, covering the
+		// no-scope-groups and no-template-groups cases.
+		for _, groupID := range item.GroupIDs {
+			if _, ok := scopeGroupSet[groupID]; ok {
+				return true
+			}
+		}
+		return false
+	default:
+		return requesterUserID > 0 && item.CreatedBy == requesterUserID
+	}
+}
+
+// normalizeBulkEditTemplateShareScopeOrDefault is the non-failing variant of
+// validateBulkEditTemplateShareScope: unrecognized scopes fall back to the
+// private scope instead of returning an error.
+func normalizeBulkEditTemplateShareScopeOrDefault(scope string) string {
+	if normalized, err := validateBulkEditTemplateShareScope(scope); err == nil {
+		return normalized
+	}
+	return BulkEditTemplateShareScopePrivate
+}
+
+// normalizeBulkEditTemplateGroupIDs drops non-positive and duplicate group
+// IDs and returns the remainder in ascending order. The result is always a
+// non-nil slice, even for empty input.
+func normalizeBulkEditTemplateGroupIDs(raw []int64) []int64 {
+	out := make([]int64, 0, len(raw))
+	seen := make(map[int64]struct{}, len(raw))
+	for _, id := range raw {
+		if id <= 0 {
+			continue
+		}
+		if _, dup := seen[id]; dup {
+			continue
+		}
+		seen[id] = struct{}{}
+		out = append(out, id)
+	}
+	sort.Slice(out, func(a, b int) bool { return out[a] < out[b] })
+	return out
+}
+
+// generateBulkEditTemplateID returns a random "btpl-"-prefixed identifier.
+// If the random source fails, it falls back to a nanosecond-timestamp ID so
+// ID generation never blocks template creation.
+func generateBulkEditTemplateID() string {
+	buf := make([]byte, 12)
+	if _, err := bulkEditTemplateRandRead(buf); err != nil {
+		return fmt.Sprintf("btpl-%d", time.Now().UnixNano())
+	}
+	return "btpl-" + hex.EncodeToString(buf)
+}
+
+// generateBulkEditTemplateVersionID returns a random "btplv-"-prefixed
+// identifier, falling back to a nanosecond-timestamp ID if the random
+// source fails.
+func generateBulkEditTemplateVersionID() string {
+	buf := make([]byte, 12)
+	if _, err := bulkEditTemplateRandRead(buf); err != nil {
+		return fmt.Sprintf("btplv-%d", time.Now().UnixNano())
+	}
+	return "btplv-" + hex.EncodeToString(buf)
+}
diff --git a/backend/internal/service/setting_bulk_edit_template_test.go b/backend/internal/service/setting_bulk_edit_template_test.go
new file mode 100644
index 000000000..c3a0351fa
--- /dev/null
+++ b/backend/internal/service/setting_bulk_edit_template_test.go
@@ -0,0 +1,860 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "testing"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+// bulkTemplateSettingRepoStub is an in-memory SettingRepository used by the
+// bulk-edit template tests: a plain map of key -> value with no persistence.
+// Not safe for concurrent use; tests here are single-goroutine.
+type bulkTemplateSettingRepoStub struct {
+	values map[string]string
+}
+
+// newBulkTemplateSettingRepoStub returns an empty, ready-to-use stub.
+func newBulkTemplateSettingRepoStub() *bulkTemplateSettingRepoStub {
+	return &bulkTemplateSettingRepoStub{values: map[string]string{}}
+}
+
+// Get wraps GetValue in a Setting struct, propagating ErrSettingNotFound.
+func (s *bulkTemplateSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+	value, err := s.GetValue(ctx, key)
+	if err != nil {
+		return nil, err
+	}
+	return &Setting{Key: key, Value: value}, nil
+}
+
+// GetValue returns the stored value or ErrSettingNotFound for missing keys.
+func (s *bulkTemplateSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+	value, ok := s.values[key]
+	if !ok {
+		return "", ErrSettingNotFound
+	}
+	return value, nil
+}
+
+// Set stores (or overwrites) a single key; never fails.
+func (s *bulkTemplateSettingRepoStub) Set(ctx context.Context, key, value string) error {
+	s.values[key] = value
+	return nil
+}
+
+// GetMultiple returns only the keys that exist; missing keys are silently omitted.
+func (s *bulkTemplateSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+	out := make(map[string]string, len(keys))
+	for _, key := range keys {
+		if value, ok := s.values[key]; ok {
+			out[key] = value
+		}
+	}
+	return out, nil
+}
+
+// SetMultiple stores every provided pair; never fails.
+func (s *bulkTemplateSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+	for key, value := range settings {
+		s.values[key] = value
+	}
+	return nil
+}
+
+// GetAll returns a copy of the backing map so callers cannot mutate the stub.
+func (s *bulkTemplateSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+	out := make(map[string]string, len(s.values))
+	for key, value := range s.values {
+		out[key] = value
+	}
+	return out, nil
+}
+
+// Delete removes a key; deleting a missing key is a no-op.
+func (s *bulkTemplateSettingRepoStub) Delete(ctx context.Context, key string) error {
+	delete(s.values, key)
+	return nil
+}
+
+// bulkTemplateFailingRepoStub is a SettingRepository whose every method
+// returns a generic error, used to exercise repository-failure branches.
+type bulkTemplateFailingRepoStub struct{}
+
+func (s *bulkTemplateFailingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+	return nil, errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+	return "", errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) Set(ctx context.Context, key, value string) error {
+	return errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+	return nil, errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+	return errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+	return nil, errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) Delete(ctx context.Context, key string) error {
+	return errors.New("boom")
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertAndPrivateVisibility(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "OpenAI OAuth Baseline",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"enableOpenAIWSMode": true},
+ RequesterUserID: 11,
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, created.ID)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, created.ShareScope)
+
+ listByOwner, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 11,
+ })
+ require.NoError(t, err)
+ require.Len(t, listByOwner, 1)
+ require.Equal(t, created.ID, listByOwner[0].ID)
+ require.Equal(t, true, listByOwner[0].State["enableOpenAIWSMode"])
+
+ listByOther, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 22,
+ })
+ require.NoError(t, err)
+ require.Len(t, listByOther, 0)
+}
+
+func TestSettingServiceBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Shared By Group",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{10, 20},
+ State: map[string]any{"enableBaseUrl": true},
+ RequesterUserID: 9,
+ })
+ require.NoError(t, err)
+
+ invisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ScopeGroupIDs: []int64{99},
+ RequesterUserID: 8,
+ })
+ require.NoError(t, err)
+ require.Len(t, invisible, 0)
+
+ visible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ScopeGroupIDs: []int64{20, 100},
+ RequesterUserID: 8,
+ })
+ require.NoError(t, err)
+ require.Len(t, visible, 1)
+ require.Equal(t, BulkEditTemplateShareScopeGroups, visible[0].ShareScope)
+ require.Equal(t, []int64{10, 20}, visible[0].GroupIDs)
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertByNameReplacesSameScopeRecord(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ first, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Team Baseline",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 7,
+ })
+ require.NoError(t, err)
+
+ second, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "team baseline",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 9},
+ RequesterUserID: 7,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+
+ items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 3,
+ })
+ require.NoError(t, err)
+ require.Len(t, items, 1)
+ require.EqualValues(t, 9, items[0].State["priority"])
+}
+
+func TestSettingServiceBulkEditTemplate_DeletePermissionAndNotFound(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Private Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"status": "active"},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 2)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err))
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 1)
+ require.NoError(t, err)
+
+ ownerList, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.Len(t, ownerList, 0)
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), "missing-id", 1)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+}
+
+func TestSettingServiceBulkEditTemplate_ValidatesInput(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Groups No IDs",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Invalid Scope",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: "invalid",
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+}
+
+func TestSettingServiceBulkEditTemplate_CoversInternalHelpers(t *testing.T) {
+ store := normalizeBulkEditTemplateLibraryStore(bulkEditTemplateLibraryStore{
+ Items: []bulkEditTemplateStoreItem{
+ {
+ ID: "same-id",
+ Name: " One ",
+ ScopePlatform: "OPENAI",
+ ScopeType: "OAUTH",
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{},
+ State: nil,
+ },
+ {
+ ID: "same-id",
+ Name: "Duplicate ID",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ },
+ {
+ ID: "",
+ Name: "Two",
+ ScopePlatform: "openai",
+ ScopeType: "apikey",
+ ShareScope: "invalid",
+ GroupIDs: []int64{5, 5, 1},
+ State: []byte(`{"ok":true}`),
+ },
+ {
+ ID: "invalid-entry",
+ Name: "",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ },
+ },
+ })
+ require.Len(t, store.Items, 2)
+ require.Equal(t, "same-id", store.Items[0].ID)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, store.Items[0].ShareScope)
+ require.Equal(t, []int64{1, 5}, store.Items[1].GroupIDs)
+ require.NotEmpty(t, store.Items[1].ID)
+
+ require.Equal(t, BulkEditTemplateShareScopeTeam, normalizeBulkEditTemplateShareScopeOrDefault("team"))
+ require.Equal(t, BulkEditTemplateShareScopePrivate, normalizeBulkEditTemplateShareScopeOrDefault("bad"))
+ require.Equal(t, []int64{}, normalizeBulkEditTemplateGroupIDs(nil))
+
+ scope, err := validateBulkEditTemplateShareScope("")
+ require.NoError(t, err)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, scope)
+ _, err = validateBulkEditTemplateShareScope("bad")
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ require.True(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeTeam},
+ 1,
+ map[int64]struct{}{},
+ ))
+ require.False(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}},
+ 1,
+ map[int64]struct{}{},
+ ))
+ require.True(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}},
+ 1,
+ map[int64]struct{}{2: {}},
+ ))
+ require.True(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9},
+ 9,
+ nil,
+ ))
+ require.False(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9},
+ 1,
+ nil,
+ ))
+
+ converted := toBulkEditTemplate(bulkEditTemplateStoreItem{
+ ID: "id-1",
+ Name: "Demo",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ GroupIDs: []int64{3},
+ State: []byte(`invalid-json`),
+ })
+ require.Equal(t, map[string]any{}, converted.State)
+}
+
+func TestSettingServiceBulkEditTemplate_LoadPersistBranches(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ err := svc.persistBulkEditTemplateLibrary(context.Background(), nil)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json"
+ store, err := svc.loadBulkEditTemplateLibrary(context.Background())
+ require.Error(t, err)
+ require.Nil(t, store)
+
+ delete(repo.values, SettingKeyBulkEditTemplateLibrary)
+ store, err = svc.loadBulkEditTemplateLibrary(context.Background())
+ require.NoError(t, err)
+ require.NotNil(t, store)
+ require.Empty(t, store.Items)
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertByMismatchedID(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, created.ID)
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: "another-id",
+ Name: "Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), "", 1)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+}
+
+func TestSettingServiceBulkEditTemplate_PrivateTemplateIsolationAcrossUsers(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ ownerTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Private Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 101,
+ })
+ require.NoError(t, err)
+
+ otherTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Private Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 9},
+ RequesterUserID: 202,
+ })
+ require.NoError(t, err)
+ require.NotEqual(t, ownerTemplate.ID, otherTemplate.ID, "不同用户的私有同名模板不应互相覆盖")
+
+ ownerVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 101,
+ })
+ require.NoError(t, err)
+ require.Len(t, ownerVisible, 1)
+ require.Equal(t, ownerTemplate.ID, ownerVisible[0].ID)
+ require.EqualValues(t, 1, ownerVisible[0].State["priority"])
+
+ otherVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 202,
+ })
+ require.NoError(t, err)
+ require.Len(t, otherVisible, 1)
+ require.Equal(t, otherTemplate.ID, otherVisible[0].ID)
+ require.EqualValues(t, 9, otherVisible[0].State["priority"])
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: ownerTemplate.ID,
+ Name: ownerTemplate.Name,
+ ScopePlatform: ownerTemplate.ScopePlatform,
+ ScopeType: ownerTemplate.ScopeType,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 99},
+ RequesterUserID: 202,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err), "非 owner 不允许通过 template ID 修改私有模板")
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertFailsWhenStoredLibraryCorrupted(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json"
+ svc := NewSettingService(repo, nil)
+
+ _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Should Fail",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"ok": true},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+}
+
+func TestSettingServiceBulkEditTemplate_ListFilteringAndSorting(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ store := bulkEditTemplateLibraryStore{
+ Items: []bulkEditTemplateStoreItem{
+ {
+ ID: "b",
+ Name: "Second",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: []byte(`{}`),
+ CreatedBy: 1,
+ UpdatedBy: 1,
+ CreatedAt: 1,
+ UpdatedAt: 100,
+ },
+ {
+ ID: "a",
+ Name: "First",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: []byte(`{}`),
+ CreatedBy: 1,
+ UpdatedBy: 1,
+ CreatedAt: 1,
+ UpdatedAt: 100,
+ },
+ {
+ ID: "skip-type",
+ Name: "Skip Type",
+ ScopePlatform: "openai",
+ ScopeType: "apikey",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: []byte(`{}`),
+ CreatedBy: 1,
+ UpdatedBy: 1,
+ CreatedAt: 1,
+ UpdatedAt: 999,
+ },
+ {
+ ID: "skip-private",
+ Name: "Skip Private",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: []byte(`{}`),
+ CreatedBy: 99,
+ UpdatedBy: 99,
+ CreatedAt: 1,
+ UpdatedAt: 1000,
+ },
+ },
+ }
+ require.NoError(t, svc.persistBulkEditTemplateLibrary(context.Background(), &store))
+
+ items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.Len(t, items, 2)
+ require.Equal(t, []string{"a", "b"}, []string{items[0].ID, items[1].ID})
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertByIDAndMarshalError(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "By ID",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: created.ID,
+ Name: "By ID",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 2},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.Equal(t, created.ID, updated.ID)
+ require.EqualValues(t, 2, updated.State["priority"])
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Marshal Error",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"bad": make(chan int)},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Unauthorized",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 0,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsUnauthorized(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Missing Scope",
+ ScopePlatform: "",
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+}
+
+// TestSettingServiceBulkEditTemplate_LoadErrorFromRepository verifies that a
+// repository read failure is propagated (not swallowed) and that no partial
+// store is returned alongside the error.
+func TestSettingServiceBulkEditTemplate_LoadErrorFromRepository(t *testing.T) {
+	svc := NewSettingService(&bulkTemplateFailingRepoStub{}, nil)
+	store, err := svc.loadBulkEditTemplateLibrary(context.Background())
+	require.Error(t, err)
+	require.Nil(t, store)
+}
+
+// TestGenerateBulkEditTemplateID_Fallback forces the injected random reader
+// to fail and verifies both ID generators still produce prefixed, non-empty
+// identifiers via their timestamp fallback. The original reader is restored
+// on exit so other tests keep real randomness.
+func TestGenerateBulkEditTemplateID_Fallback(t *testing.T) {
+	original := bulkEditTemplateRandRead
+	bulkEditTemplateRandRead = func(_ []byte) (int, error) {
+		return 0, errors.New("rand fail")
+	}
+	defer func() {
+		bulkEditTemplateRandRead = original
+	}()
+
+	id := generateBulkEditTemplateID()
+	require.NotEmpty(t, id)
+	require.Contains(t, id, "btpl-")
+
+	versionID := generateBulkEditTemplateVersionID()
+	require.NotEmpty(t, versionID)
+	require.Contains(t, versionID, "btplv-")
+}
+
+func TestSettingServiceBulkEditTemplate_VersionLifecycleAndRollback(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Versioned Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+
+ updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: created.ID,
+ Name: "Versioned Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 9},
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Equal(t, BulkEditTemplateShareScopeTeam, updated.ShareScope)
+
+ versions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Len(t, versions, 1)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, versions[0].ShareScope)
+ require.EqualValues(t, 1, versions[0].State["priority"])
+
+ rollbacked, err := svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: versions[0].VersionID,
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, rollbacked.ShareScope)
+ require.EqualValues(t, 1, rollbacked.State["priority"])
+
+ versionsAfterRollback, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Len(t, versionsAfterRollback, 2)
+}
+
+func TestSettingServiceBulkEditTemplate_VersionVisibilityAndErrors(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Group Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{7},
+ State: map[string]any{"enableBaseUrl": true},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: created.ID,
+ Name: "Group Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{7, 9},
+ State: map[string]any{"enableBaseUrl": false},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ ScopeGroupIDs: []int64{8},
+ RequesterUserID: 2,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err))
+
+ visibleVersions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ ScopeGroupIDs: []int64{7},
+ RequesterUserID: 2,
+ })
+ require.NoError(t, err)
+ require.Len(t, visibleVersions, 1)
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: "missing",
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ RequesterUserID: 0,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsUnauthorized(err))
+
+ _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: "missing-version",
+ ScopeGroupIDs: []int64{7},
+ RequesterUserID: 2,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+
+ _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: visibleVersions[0].VersionID,
+ ScopeGroupIDs: []int64{8},
+ RequesterUserID: 2,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err))
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: " ",
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: visibleVersions[0].VersionID,
+ RequesterUserID: 0,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsUnauthorized(err))
+}
+
+func TestSettingServiceBulkEditTemplate_VersionHelpers(t *testing.T) {
+ normalized := normalizeBulkEditTemplateVersionStoreItems([]bulkEditTemplateVersionStoreItem{
+ {
+ VersionID: "",
+ ShareScope: "groups",
+ GroupIDs: []int64{},
+ State: nil,
+ UpdatedBy: 1,
+ UpdatedAt: 0,
+ },
+ {
+ VersionID: "v-1",
+ ShareScope: "team",
+ GroupIDs: []int64{4, 4, 2},
+ State: []byte(`{"ok":true}`),
+ UpdatedBy: 2,
+ UpdatedAt: 20,
+ },
+ {
+ VersionID: "v-1",
+ ShareScope: "team",
+ State: []byte(`{}`),
+ UpdatedAt: 30,
+ },
+ })
+ require.Len(t, normalized, 2)
+ privateCount := 0
+ teamCount := 0
+ for _, item := range normalized {
+ if item.ShareScope == BulkEditTemplateShareScopePrivate {
+ privateCount++
+ }
+ if item.ShareScope == BulkEditTemplateShareScopeTeam {
+ teamCount++
+ require.Equal(t, []int64{2, 4}, item.GroupIDs)
+ }
+ }
+ require.Equal(t, 1, privateCount)
+ require.Equal(t, 1, teamCount)
+
+ item := bulkEditTemplateStoreItem{
+ ID: "tpl-1",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ GroupIDs: []int64{3},
+ State: []byte(`{"priority":3}`),
+ UpdatedBy: 10,
+ UpdatedAt: 123,
+ }
+ version := snapshotBulkEditTemplateVersion(item)
+ require.NotEmpty(t, version.VersionID)
+ require.Equal(t, BulkEditTemplateShareScopeTeam, version.ShareScope)
+ require.EqualValues(t, 123, version.UpdatedAt)
+
+ versionDTO := toBulkEditTemplateVersion(bulkEditTemplateVersionStoreItem{
+ VersionID: "ver-1",
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ GroupIDs: []int64{9},
+ State: []byte(`invalid`),
+ UpdatedBy: 1,
+ UpdatedAt: 2,
+ })
+ require.Equal(t, map[string]any{}, versionDTO.State)
+
+ require.Equal(t, -1, findBulkEditTemplateVersionIndexByID(nil, "x"))
+ require.Equal(t, -1, findBulkEditTemplateStoreItemIndexByID(nil, "x"))
+ require.Nil(t, findBulkEditTemplateStoreItemByID(nil, "x"))
+
+ scopeSet := toBulkEditTemplateScopeGroupSet([]int64{4, 4, 2, -1})
+ _, has2 := scopeSet[2]
+ _, has4 := scopeSet[4]
+ require.True(t, has2)
+ require.True(t, has4)
+
+ cloned := cloneBulkEditTemplateStateRaw(json.RawMessage(`{"x":1}`))
+ require.Equal(t, `{"x":1}`, string(cloned))
+ require.Equal(t, `{}`, string(cloneBulkEditTemplateStateRaw(nil)))
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index f7e4fb6be..776cab85a 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -7,31 +7,22 @@ import (
"encoding/json"
"errors"
"fmt"
- "log/slog"
"net/url"
+ "sort"
"strconv"
"strings"
- "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
- "golang.org/x/sync/singleflight"
+ "github.com/tidwall/gjson"
)
var (
- ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
- ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
- ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
- ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
- ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
- "DEFAULT_SUBSCRIPTION_GROUP_INVALID",
- "default subscription group must exist and be subscription type",
- )
- ErrDefaultSubGroupDuplicate = infraerrors.BadRequest(
- "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE",
- "default subscription group cannot be duplicated",
- )
+ ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
+ ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
+ ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
+ ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
)
type SettingRepository interface {
@@ -44,40 +35,13 @@ type SettingRepository interface {
Delete(ctx context.Context, key string) error
}
-// cachedMinVersion 缓存最低 Claude Code 版本号(进程内缓存,60s TTL)
-type cachedMinVersion struct {
- value string // 空字符串 = 不检查
- expiresAt int64 // unix nano
-}
-
-// minVersionCache 最低版本号进程内缓存
-var minVersionCache atomic.Value // *cachedMinVersion
-
-// minVersionSF 防止缓存过期时 thundering herd
-var minVersionSF singleflight.Group
-
-// minVersionCacheTTL 缓存有效期
-const minVersionCacheTTL = 60 * time.Second
-
-// minVersionErrorTTL DB 错误时的短缓存,快速重试
-const minVersionErrorTTL = 5 * time.Second
-
-// minVersionDBTimeout singleflight 内 DB 查询超时,独立于请求 context
-const minVersionDBTimeout = 5 * time.Second
-
-// DefaultSubscriptionGroupReader validates group references used by default subscriptions.
-type DefaultSubscriptionGroupReader interface {
- GetByID(ctx context.Context, id int64) (*Group, error)
-}
-
// SettingService 系统设置服务
type SettingService struct {
- settingRepo SettingRepository
- defaultSubGroupReader DefaultSubscriptionGroupReader
- cfg *config.Config
- onUpdate func() // Callback when settings are updated (for cache invalidation)
- onS3Update func() // Callback when Sora S3 settings are updated
- version string // Application version
+ settingRepo SettingRepository
+ cfg *config.Config
+ onUpdate func() // Callback when settings are updated (for cache invalidation)
+ onS3Update func() // Callback when Sora S3 settings are updated
+ version string // Application version
}
// NewSettingService 创建系统设置服务实例
@@ -88,11 +52,6 @@ func NewSettingService(settingRepo SettingRepository, cfg *config.Config) *Setti
}
}
-// SetDefaultSubscriptionGroupReader injects an optional group reader for default subscription validation.
-func (s *SettingService) SetDefaultSubscriptionGroupReader(reader DefaultSubscriptionGroupReader) {
- s.defaultSubGroupReader = reader
-}
-
// GetAllSettings 获取所有系统设置
func (s *SettingService) GetAllSettings(ctx context.Context) (*SystemSettings, error) {
settings, err := s.settingRepo.GetAll(ctx)
@@ -125,7 +84,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeyPurchaseSubscriptionEnabled,
SettingKeyPurchaseSubscriptionURL,
SettingKeySoraClientEnabled,
- SettingKeyCustomMenuItems,
SettingKeyLinuxDoConnectEnabled,
}
@@ -165,7 +123,6 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
- CustomMenuItems: settings[SettingKeyCustomMenuItems],
LinuxDoOAuthEnabled: linuxDoEnabled,
}, nil
}
@@ -196,28 +153,27 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
// Return a struct that matches the frontend's expected format
return &struct {
- RegistrationEnabled bool `json:"registration_enabled"`
- EmailVerifyEnabled bool `json:"email_verify_enabled"`
- PromoCodeEnabled bool `json:"promo_code_enabled"`
- PasswordResetEnabled bool `json:"password_reset_enabled"`
- InvitationCodeEnabled bool `json:"invitation_code_enabled"`
- TotpEnabled bool `json:"totp_enabled"`
- TurnstileEnabled bool `json:"turnstile_enabled"`
- TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
- SiteName string `json:"site_name"`
- SiteLogo string `json:"site_logo,omitempty"`
- SiteSubtitle string `json:"site_subtitle,omitempty"`
- APIBaseURL string `json:"api_base_url,omitempty"`
- ContactInfo string `json:"contact_info,omitempty"`
- DocURL string `json:"doc_url,omitempty"`
- HomeContent string `json:"home_content,omitempty"`
- HideCcsImportButton bool `json:"hide_ccs_import_button"`
- PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
- PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
- SoraClientEnabled bool `json:"sora_client_enabled"`
- CustomMenuItems json.RawMessage `json:"custom_menu_items"`
- LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
- Version string `json:"version,omitempty"`
+ RegistrationEnabled bool `json:"registration_enabled"`
+ EmailVerifyEnabled bool `json:"email_verify_enabled"`
+ PromoCodeEnabled bool `json:"promo_code_enabled"`
+ PasswordResetEnabled bool `json:"password_reset_enabled"`
+ InvitationCodeEnabled bool `json:"invitation_code_enabled"`
+ TotpEnabled bool `json:"totp_enabled"`
+ TurnstileEnabled bool `json:"turnstile_enabled"`
+ TurnstileSiteKey string `json:"turnstile_site_key,omitempty"`
+ SiteName string `json:"site_name"`
+ SiteLogo string `json:"site_logo,omitempty"`
+ SiteSubtitle string `json:"site_subtitle,omitempty"`
+ APIBaseURL string `json:"api_base_url,omitempty"`
+ ContactInfo string `json:"contact_info,omitempty"`
+ DocURL string `json:"doc_url,omitempty"`
+ HomeContent string `json:"home_content,omitempty"`
+ HideCcsImportButton bool `json:"hide_ccs_import_button"`
+ PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"`
+ PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"`
+ SoraClientEnabled bool `json:"sora_client_enabled"`
+ LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
+ Version string `json:"version,omitempty"`
}{
RegistrationEnabled: settings.RegistrationEnabled,
EmailVerifyEnabled: settings.EmailVerifyEnabled,
@@ -238,125 +194,65 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any
PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled,
PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL,
SoraClientEnabled: settings.SoraClientEnabled,
- CustomMenuItems: filterUserVisibleMenuItems(settings.CustomMenuItems),
LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled,
Version: s.version,
}, nil
}
-// filterUserVisibleMenuItems filters out admin-only menu items from a raw JSON
-// array string, returning only items with visibility != "admin".
-func filterUserVisibleMenuItems(raw string) json.RawMessage {
- raw = strings.TrimSpace(raw)
- if raw == "" || raw == "[]" {
- return json.RawMessage("[]")
- }
- var items []struct {
- Visibility string `json:"visibility"`
- }
- if err := json.Unmarshal([]byte(raw), &items); err != nil {
- return json.RawMessage("[]")
- }
-
- // Parse full items to preserve all fields
- var fullItems []json.RawMessage
- if err := json.Unmarshal([]byte(raw), &fullItems); err != nil {
- return json.RawMessage("[]")
- }
-
- var filtered []json.RawMessage
- for i, item := range items {
- if item.Visibility != "admin" {
- filtered = append(filtered, fullItems[i])
- }
- }
- if len(filtered) == 0 {
- return json.RawMessage("[]")
+// GetFrameSrcOrigins 提取需要注入 CSP frame-src 的外部域名来源。
+func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
+ keys := []string{
+ SettingKeyPurchaseSubscriptionURL,
+ SettingKeyHomeContent,
+ SettingKeyCustomMenuItems,
}
- result, err := json.Marshal(filtered)
+ settings, err := s.settingRepo.GetMultiple(ctx, keys)
if err != nil {
- return json.RawMessage("[]")
+ return nil, fmt.Errorf("get frame src settings: %w", err)
}
- return result
-}
-// GetFrameSrcOrigins returns deduplicated http(s) origins from purchase_subscription_url
-// and all custom_menu_items URLs. Used by the router layer for CSP frame-src injection.
-func (s *SettingService) GetFrameSrcOrigins(ctx context.Context) ([]string, error) {
- settings, err := s.GetPublicSettings(ctx)
- if err != nil {
- return nil, err
+ originsSet := make(map[string]struct{}, 8)
+ addOrigin := func(raw string) {
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return
+ }
+ u, parseErr := url.Parse(raw)
+ if parseErr != nil || u == nil {
+ return
+ }
+ if u.Scheme != "http" && u.Scheme != "https" {
+ return
+ }
+ if u.Host == "" {
+ return
+ }
+ originsSet[u.Scheme+"://"+u.Host] = struct{}{}
}
- seen := make(map[string]struct{})
- var origins []string
+ addOrigin(settings[SettingKeyPurchaseSubscriptionURL])
+ addOrigin(settings[SettingKeyHomeContent])
- addOrigin := func(rawURL string) {
- if origin := extractOriginFromURL(rawURL); origin != "" {
- if _, ok := seen[origin]; !ok {
- seen[origin] = struct{}{}
- origins = append(origins, origin)
+ customMenuRaw := strings.TrimSpace(settings[SettingKeyCustomMenuItems])
+ if customMenuRaw != "" {
+ menuItems := gjson.Parse(customMenuRaw)
+ if menuItems.IsArray() {
+ for _, item := range menuItems.Array() {
+ addOrigin(item.Get("url").String())
}
}
}
- // purchase subscription URL
- if settings.PurchaseSubscriptionEnabled {
- addOrigin(settings.PurchaseSubscriptionURL)
- }
-
- // all custom menu items (including admin-only, since CSP must allow all iframes)
- for _, item := range parseCustomMenuItemURLs(settings.CustomMenuItems) {
- addOrigin(item)
+ origins := make([]string, 0, len(originsSet))
+ for origin := range originsSet {
+ origins = append(origins, origin)
}
-
+ sort.Strings(origins)
return origins, nil
}
-// extractOriginFromURL returns the scheme+host origin from rawURL.
-// Only http and https schemes are accepted.
-func extractOriginFromURL(rawURL string) string {
- rawURL = strings.TrimSpace(rawURL)
- if rawURL == "" {
- return ""
- }
- u, err := url.Parse(rawURL)
- if err != nil || u.Host == "" {
- return ""
- }
- if u.Scheme != "http" && u.Scheme != "https" {
- return ""
- }
- return u.Scheme + "://" + u.Host
-}
-
-// parseCustomMenuItemURLs extracts URLs from a raw JSON array of custom menu items.
-func parseCustomMenuItemURLs(raw string) []string {
- raw = strings.TrimSpace(raw)
- if raw == "" || raw == "[]" {
- return nil
- }
- var items []struct {
- URL string `json:"url"`
- }
- if err := json.Unmarshal([]byte(raw), &items); err != nil {
- return nil
- }
- urls := make([]string, 0, len(items))
- for _, item := range items {
- if item.URL != "" {
- urls = append(urls, item.URL)
- }
- }
- return urls
-}
-
// UpdateSettings 更新系统设置
func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSettings) error {
- if err := s.validateDefaultSubscriptionGroups(ctx, settings.DefaultSubscriptions); err != nil {
- return err
- }
-
updates := make(map[string]string)
// 注册设置
@@ -405,16 +301,10 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled)
updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL)
updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled)
- updates[SettingKeyCustomMenuItems] = settings.CustomMenuItems
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
- defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions)
- if err != nil {
- return fmt.Errorf("marshal default subscriptions: %w", err)
- }
- updates[SettingKeyDefaultSubscriptions] = string(defaultSubsJSON)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
@@ -435,63 +325,13 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyOpsMetricsIntervalSeconds] = strconv.Itoa(settings.OpsMetricsIntervalSeconds)
}
- // Claude Code version check
- updates[SettingKeyMinClaudeCodeVersion] = settings.MinClaudeCodeVersion
-
- err = s.settingRepo.SetMultiple(ctx, updates)
- if err == nil {
- // 先使 inflight singleflight 失效,再刷新缓存,缩小旧值覆盖新值的竞态窗口
- minVersionSF.Forget("min_version")
- minVersionCache.Store(&cachedMinVersion{
- value: settings.MinClaudeCodeVersion,
- expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
- })
- if s.onUpdate != nil {
- s.onUpdate() // Invalidate cache after settings update
- }
+ err := s.settingRepo.SetMultiple(ctx, updates)
+ if err == nil && s.onUpdate != nil {
+ s.onUpdate() // Invalidate cache after settings update
}
return err
}
-func (s *SettingService) validateDefaultSubscriptionGroups(ctx context.Context, items []DefaultSubscriptionSetting) error {
- if len(items) == 0 {
- return nil
- }
-
- checked := make(map[int64]struct{}, len(items))
- for _, item := range items {
- if item.GroupID <= 0 {
- continue
- }
- if _, ok := checked[item.GroupID]; ok {
- return ErrDefaultSubGroupDuplicate.WithMetadata(map[string]string{
- "group_id": strconv.FormatInt(item.GroupID, 10),
- })
- }
- checked[item.GroupID] = struct{}{}
- if s.defaultSubGroupReader == nil {
- continue
- }
-
- group, err := s.defaultSubGroupReader.GetByID(ctx, item.GroupID)
- if err != nil {
- if errors.Is(err, ErrGroupNotFound) {
- return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
- "group_id": strconv.FormatInt(item.GroupID, 10),
- })
- }
- return fmt.Errorf("get default subscription group %d: %w", item.GroupID, err)
- }
- if !group.IsSubscriptionType() {
- return ErrDefaultSubGroupInvalid.WithMetadata(map[string]string{
- "group_id": strconv.FormatInt(item.GroupID, 10),
- })
- }
- }
-
- return nil
-}
-
// IsRegistrationEnabled 检查是否开放注册
func (s *SettingService) IsRegistrationEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyRegistrationEnabled)
@@ -591,15 +431,6 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 {
return s.cfg.Default.UserBalance
}
-// GetDefaultSubscriptions 获取新用户默认订阅配置列表。
-func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting {
- value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions)
- if err != nil {
- return nil
- }
- return parseDefaultSubscriptions(value)
-}
-
// InitializeDefaultSettings 初始化默认设置
func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
// 检查是否已有设置
@@ -622,10 +453,8 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyPurchaseSubscriptionEnabled: "false",
SettingKeyPurchaseSubscriptionURL: "",
SettingKeySoraClientEnabled: "false",
- SettingKeyCustomMenuItems: "[]",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
- SettingKeyDefaultSubscriptions: "[]",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
@@ -643,9 +472,6 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeyOpsRealtimeMonitoringEnabled: "true",
SettingKeyOpsQueryModeDefault: "auto",
SettingKeyOpsMetricsIntervalSeconds: "60",
-
- // Claude Code version check (default: empty = disabled)
- SettingKeyMinClaudeCodeVersion: "",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -681,7 +507,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true",
PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]),
SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true",
- CustomMenuItems: settings[SettingKeyCustomMenuItems],
}
// 解析整数类型
@@ -703,7 +528,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
} else {
result.DefaultBalance = s.cfg.Default.UserBalance
}
- result.DefaultSubscriptions = parseDefaultSubscriptions(settings[SettingKeyDefaultSubscriptions])
// 敏感信息直接返回,方便测试连接时使用
result.SMTPPassword = settings[SettingKeySMTPPassword]
@@ -773,9 +597,6 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
}
- // Claude Code version check
- result.MinClaudeCodeVersion = settings[SettingKeyMinClaudeCodeVersion]
-
return result
}
@@ -788,31 +609,6 @@ func isFalseSettingValue(value string) bool {
}
}
-func parseDefaultSubscriptions(raw string) []DefaultSubscriptionSetting {
- raw = strings.TrimSpace(raw)
- if raw == "" {
- return nil
- }
-
- var items []DefaultSubscriptionSetting
- if err := json.Unmarshal([]byte(raw), &items); err != nil {
- return nil
- }
-
- normalized := make([]DefaultSubscriptionSetting, 0, len(items))
- for _, item := range items {
- if item.GroupID <= 0 || item.ValidityDays <= 0 {
- continue
- }
- if item.ValidityDays > MaxValidityDays {
- item.ValidityDays = MaxValidityDays
- }
- normalized = append(normalized, item)
- }
-
- return normalized
-}
-
// getStringOrDefault 获取字符串值或默认值
func (s *SettingService) getStringOrDefault(settings map[string]string, key, defaultValue string) string {
if value, ok := settings[key]; ok && value != "" {
@@ -1098,53 +894,6 @@ func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamT
return &settings, nil
}
-// GetMinClaudeCodeVersion 获取最低 Claude Code 版本号要求
-// 使用进程内 atomic.Value 缓存,60 秒 TTL,热路径零锁开销
-// singleflight 防止缓存过期时 thundering herd
-// 返回空字符串表示不做版本检查
-func (s *SettingService) GetMinClaudeCodeVersion(ctx context.Context) string {
- if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok {
- if time.Now().UnixNano() < cached.expiresAt {
- return cached.value
- }
- }
- // singleflight: 同一时刻只有一个 goroutine 查询 DB,其余复用结果
- result, err, _ := minVersionSF.Do("min_version", func() (any, error) {
- // 二次检查,避免排队的 goroutine 重复查询
- if cached, ok := minVersionCache.Load().(*cachedMinVersion); ok {
- if time.Now().UnixNano() < cached.expiresAt {
- return cached.value, nil
- }
- }
- // 使用独立 context:断开请求取消链,避免客户端断连导致空值被长期缓存
- dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), minVersionDBTimeout)
- defer cancel()
- value, err := s.settingRepo.GetValue(dbCtx, SettingKeyMinClaudeCodeVersion)
- if err != nil {
- // fail-open: DB 错误时不阻塞请求,但记录日志并使用短 TTL 快速重试
- slog.Warn("failed to get min claude code version setting, skipping version check", "error", err)
- minVersionCache.Store(&cachedMinVersion{
- value: "",
- expiresAt: time.Now().Add(minVersionErrorTTL).UnixNano(),
- })
- return "", nil
- }
- minVersionCache.Store(&cachedMinVersion{
- value: value,
- expiresAt: time.Now().Add(minVersionCacheTTL).UnixNano(),
- })
- return value, nil
- })
- if err != nil {
- return ""
- }
- ver, ok := result.(string)
- if !ok {
- return ""
- }
- return ver
-}
-
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
deleted file mode 100644
index ec64511f2..000000000
--- a/backend/internal/service/setting_service_update_test.go
+++ /dev/null
@@ -1,182 +0,0 @@
-//go:build unit
-
-package service
-
-import (
- "context"
- "encoding/json"
- "testing"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
- "github.com/stretchr/testify/require"
-)
-
-type settingUpdateRepoStub struct {
- updates map[string]string
-}
-
-func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
- panic("unexpected Get call")
-}
-
-func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
- panic("unexpected GetValue call")
-}
-
-func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error {
- panic("unexpected Set call")
-}
-
-func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
- panic("unexpected GetMultiple call")
-}
-
-func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
- s.updates = make(map[string]string, len(settings))
- for k, v := range settings {
- s.updates[k] = v
- }
- return nil
-}
-
-func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
- panic("unexpected GetAll call")
-}
-
-func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error {
- panic("unexpected Delete call")
-}
-
-type defaultSubGroupReaderStub struct {
- byID map[int64]*Group
- errBy map[int64]error
- calls []int64
-}
-
-func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) {
- s.calls = append(s.calls, id)
- if err, ok := s.errBy[id]; ok {
- return nil, err
- }
- if g, ok := s.byID[id]; ok {
- return g, nil
- }
- return nil, ErrGroupNotFound
-}
-
-func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) {
- repo := &settingUpdateRepoStub{}
- groupReader := &defaultSubGroupReaderStub{
- byID: map[int64]*Group{
- 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
- },
- }
- svc := NewSettingService(repo, &config.Config{})
- svc.SetDefaultSubscriptionGroupReader(groupReader)
-
- err := svc.UpdateSettings(context.Background(), &SystemSettings{
- DefaultSubscriptions: []DefaultSubscriptionSetting{
- {GroupID: 11, ValidityDays: 30},
- },
- })
- require.NoError(t, err)
- require.Equal(t, []int64{11}, groupReader.calls)
-
- raw, ok := repo.updates[SettingKeyDefaultSubscriptions]
- require.True(t, ok)
-
- var got []DefaultSubscriptionSetting
- require.NoError(t, json.Unmarshal([]byte(raw), &got))
- require.Equal(t, []DefaultSubscriptionSetting{
- {GroupID: 11, ValidityDays: 30},
- }, got)
-}
-
-func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) {
- repo := &settingUpdateRepoStub{}
- groupReader := &defaultSubGroupReaderStub{
- byID: map[int64]*Group{
- 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard},
- },
- }
- svc := NewSettingService(repo, &config.Config{})
- svc.SetDefaultSubscriptionGroupReader(groupReader)
-
- err := svc.UpdateSettings(context.Background(), &SystemSettings{
- DefaultSubscriptions: []DefaultSubscriptionSetting{
- {GroupID: 12, ValidityDays: 7},
- },
- })
- require.Error(t, err)
- require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
- require.Nil(t, repo.updates)
-}
-
-func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) {
- repo := &settingUpdateRepoStub{}
- groupReader := &defaultSubGroupReaderStub{
- errBy: map[int64]error{
- 13: ErrGroupNotFound,
- },
- }
- svc := NewSettingService(repo, &config.Config{})
- svc.SetDefaultSubscriptionGroupReader(groupReader)
-
- err := svc.UpdateSettings(context.Background(), &SystemSettings{
- DefaultSubscriptions: []DefaultSubscriptionSetting{
- {GroupID: 13, ValidityDays: 7},
- },
- })
- require.Error(t, err)
- require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
- require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"])
- require.Nil(t, repo.updates)
-}
-
-func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) {
- repo := &settingUpdateRepoStub{}
- groupReader := &defaultSubGroupReaderStub{
- byID: map[int64]*Group{
- 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
- },
- }
- svc := NewSettingService(repo, &config.Config{})
- svc.SetDefaultSubscriptionGroupReader(groupReader)
-
- err := svc.UpdateSettings(context.Background(), &SystemSettings{
- DefaultSubscriptions: []DefaultSubscriptionSetting{
- {GroupID: 11, ValidityDays: 30},
- {GroupID: 11, ValidityDays: 60},
- },
- })
- require.Error(t, err)
- require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
- require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
- require.Nil(t, repo.updates)
-}
-
-func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) {
- repo := &settingUpdateRepoStub{}
- svc := NewSettingService(repo, &config.Config{})
-
- err := svc.UpdateSettings(context.Background(), &SystemSettings{
- DefaultSubscriptions: []DefaultSubscriptionSetting{
- {GroupID: 11, ValidityDays: 30},
- {GroupID: 11, ValidityDays: 60},
- },
- })
- require.Error(t, err)
- require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
- require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
- require.Nil(t, repo.updates)
-}
-
-func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
- got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
- require.Equal(t, []DefaultSubscriptionSetting{
- {GroupID: 11, ValidityDays: 30},
- {GroupID: 11, ValidityDays: 60},
- {GroupID: 12, ValidityDays: MaxValidityDays},
- }, got)
-}
diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go
index 9f0de6000..6611f9901 100644
--- a/backend/internal/service/settings_view.go
+++ b/backend/internal/service/settings_view.go
@@ -42,9 +42,8 @@ type SystemSettings struct {
SoraClientEnabled bool
CustomMenuItems string // JSON array of custom menu items
- DefaultConcurrency int
- DefaultBalance float64
- DefaultSubscriptions []DefaultSubscriptionSetting
+ DefaultConcurrency int
+ DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
@@ -62,14 +61,6 @@ type SystemSettings struct {
OpsRealtimeMonitoringEnabled bool
OpsQueryModeDefault string
OpsMetricsIntervalSeconds int
-
- // Claude Code version check
- MinClaudeCodeVersion string
-}
-
-type DefaultSubscriptionSetting struct {
- GroupID int64 `json:"group_id"`
- ValidityDays int `json:"validity_days"`
}
type PublicSettings struct {
diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go
index 46f322c82..a5c0c890d 100644
--- a/backend/internal/service/sora_generation_service_test.go
+++ b/backend/internal/service/sora_generation_service_test.go
@@ -162,12 +162,12 @@ func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, err
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
-func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
-func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
return nil
}
+func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
diff --git a/backend/internal/service/token_refresh_parallel_test.go b/backend/internal/service/token_refresh_parallel_test.go
new file mode 100644
index 000000000..c844ef934
--- /dev/null
+++ b/backend/internal/service/token_refresh_parallel_test.go
@@ -0,0 +1,439 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// --- 并行刷新专用 stub ---
+
+// concurrentTokenRefresherStub 记录并发度和调用次数
+type concurrentTokenRefresherStub struct {
+ canRefreshFn func(*Account) bool
+ needsRefreshFn func(*Account, time.Duration) bool
+ refreshDelay time.Duration
+ refreshErr error
+ credentials map[string]any
+ refreshCalls atomic.Int64
+ maxConcurrent atomic.Int64
+ currentActive atomic.Int64
+}
+
+func (r *concurrentTokenRefresherStub) CanRefresh(account *Account) bool {
+ if r.canRefreshFn != nil {
+ return r.canRefreshFn(account)
+ }
+ return true
+}
+
+func (r *concurrentTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool {
+ if r.needsRefreshFn != nil {
+ return r.needsRefreshFn(account, window)
+ }
+ return true
+}
+
+func (r *concurrentTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ r.refreshCalls.Add(1)
+ active := r.currentActive.Add(1)
+ // 记录峰值并发
+ for {
+ old := r.maxConcurrent.Load()
+ if active <= old || r.maxConcurrent.CompareAndSwap(old, active) {
+ break
+ }
+ }
+ if r.refreshDelay > 0 {
+ time.Sleep(r.refreshDelay)
+ }
+ r.currentActive.Add(-1)
+ if r.refreshErr != nil {
+ return nil, r.refreshErr
+ }
+ // 每次返回新 map,避免多 goroutine 共享同一 map 实例引发竞态
+ creds := make(map[string]any, len(r.credentials))
+ for k, v := range r.credentials {
+ creds[k] = v
+ }
+ return creds, nil
+}
+
+// concurrentTokenRefreshAccountRepo 线程安全的 account repo stub
+type concurrentTokenRefreshAccountRepo struct {
+ mockAccountRepoForGemini
+ mu sync.Mutex
+ updateCalls int
+ setErrorCalls int
+ activeAccounts []Account
+ updateErr error
+}
+
+func (r *concurrentTokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.updateCalls++
+ return r.updateErr
+}
+
+func (r *concurrentTokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.setErrorCalls++
+ return nil
+}
+
+func (r *concurrentTokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) {
+ out := make([]Account, len(r.activeAccounts))
+ copy(out, r.activeAccounts)
+ return out, nil
+}
+
+// --- 测试用例 ---
+
+func TestProcessRefresh_ParallelExecution(t *testing.T) {
+ accounts := make([]Account, 20)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(100 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshDelay: 20 * time.Millisecond,
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ start := time.Now()
+ svc.processRefresh()
+ elapsed := time.Since(start)
+
+ // 20 个账号每个 20ms,串行至少 400ms;并行(maxConcurrency=10)约 40-60ms
+ require.Equal(t, int64(20), refresher.refreshCalls.Load())
+ require.Less(t, elapsed, 300*time.Millisecond, "并行刷新应显著快于串行")
+ require.Greater(t, refresher.maxConcurrent.Load(), int64(1), "应有多个账号并发刷新")
+ require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过信号量限制")
+
+ repo.mu.Lock()
+ require.Equal(t, 20, repo.updateCalls)
+ repo.mu.Unlock()
+}
+
+func TestProcessRefresh_SemaphoreLimitsConcurrency(t *testing.T) {
+ accounts := make([]Account, 15)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(200 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshDelay: 50 * time.Millisecond,
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ require.Equal(t, int64(15), refresher.refreshCalls.Load())
+ require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过 maxConcurrency=10")
+}
+
+func TestProcessRefresh_StopInterruptsPhase2(t *testing.T) {
+ accounts := make([]Account, 30)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(300 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshDelay: 100 * time.Millisecond,
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ done := make(chan struct{})
+ go func() {
+ svc.processRefresh()
+ close(done)
+ }()
+
+ // 短暂等待让部分 goroutine 启动
+ time.Sleep(30 * time.Millisecond)
+ svc.Stop()
+
+ select {
+ case <-done:
+ // ok
+ case <-time.After(3 * time.Second):
+ t.Fatal("processRefresh 应在收到 stop 信号后及时退出")
+ }
+
+ // 因中断,不应刷新全部 30 个账号
+ require.Less(t, refresher.refreshCalls.Load(), int64(30), "stop 应中断后续任务提交")
+}
+
+func TestProcessRefresh_EmptyAccounts(t *testing.T) {
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: nil}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
+ refresher := &concurrentTokenRefresherStub{}
+ svc.refreshers = []TokenRefresher{refresher}
+
+ // 不应 panic
+ require.NotPanics(t, func() {
+ svc.processRefresh()
+ })
+ require.Equal(t, int64(0), refresher.refreshCalls.Load())
+}
+
+func TestProcessRefresh_NoAccountsNeedRefresh(t *testing.T) {
+ accounts := []Account{
+ {ID: 401, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 402, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ needsRefreshFn: func(a *Account, d time.Duration) bool { return false },
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ require.Equal(t, int64(0), refresher.refreshCalls.Load())
+}
+
+func TestProcessRefresh_MixedSuccessAndFailure(t *testing.T) {
+ accounts := []Account{
+ {ID: 501, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 502, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 503, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 504, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+
+ // 偶数 ID 成功,奇数 ID 失败
+ refresher := &concurrentTokenRefresherStub{
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ failRefresher := &concurrentTokenRefresherStub{
+ refreshErr: errors.New("refresh failed"),
+ }
+
+ // 使用 selectiveRefresher 按 ID 分流
+ selectiveRefresher := &selectiveTokenRefresherStub{
+ successRefresher: refresher,
+ failRefresher: failRefresher,
+ failIDs: map[int64]bool{501: true, 503: true},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{selectiveRefresher}
+
+ svc.processRefresh()
+
+ totalCalls := refresher.refreshCalls.Load() + failRefresher.refreshCalls.Load()
+ require.Equal(t, int64(4), totalCalls)
+ require.Equal(t, int64(2), refresher.refreshCalls.Load())
+ require.Equal(t, int64(2), failRefresher.refreshCalls.Load())
+}
+
+// selectiveTokenRefresherStub 按账号 ID 分流到不同的 refresher
+type selectiveTokenRefresherStub struct {
+ successRefresher *concurrentTokenRefresherStub
+ failRefresher *concurrentTokenRefresherStub
+ failIDs map[int64]bool
+}
+
+func (r *selectiveTokenRefresherStub) CanRefresh(account *Account) bool {
+ return true
+}
+
+func (r *selectiveTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool {
+ return true
+}
+
+func (r *selectiveTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ if r.failIDs[account.ID] {
+ return r.failRefresher.Refresh(ctx, account)
+ }
+ return r.successRefresher.Refresh(ctx, account)
+}
+
+func TestProcessRefresh_SingleAccount(t *testing.T) {
+ accounts := []Account{
+ {ID: 601, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ require.Equal(t, int64(1), refresher.refreshCalls.Load())
+ require.Equal(t, int64(1), refresher.maxConcurrent.Load())
+}
+
+func TestProcessRefresh_AllFailed(t *testing.T) {
+ accounts := make([]Account, 5)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(700 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshErr: errors.New("all fail"),
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ // 不应 panic
+ require.NotPanics(t, func() {
+ svc.processRefresh()
+ })
+ require.Equal(t, int64(5), refresher.refreshCalls.Load())
+
+ repo.mu.Lock()
+ require.Equal(t, 5, repo.setErrorCalls)
+ require.Equal(t, 0, repo.updateCalls)
+ repo.mu.Unlock()
+}
+
+func TestProcessRefresh_CanRefreshFilters(t *testing.T) {
+ accounts := []Account{
+ {ID: 801, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 802, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive},
+ {ID: 803, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ canRefreshFn: func(a *Account) bool { return a.Type == AccountTypeOAuth },
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ // 只有 OAuth 账号(ID 801, 803)应被刷新
+ require.Equal(t, int64(2), refresher.refreshCalls.Load())
+}
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index a37e0d0ac..ae373a831 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -6,11 +6,19 @@ import (
"log/slog"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
+const (
+ tokenRefreshDistributedLockPlatform = "token_refresh"
+ tokenRefreshDistributedLockMinTTL = 30 * time.Second
+ tokenRefreshDistributedLockMaxTTL = 10 * time.Minute
+ tokenRefreshDistributedLockTimeout = 2 * time.Second
+)
+
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
@@ -20,8 +28,9 @@ type TokenRefreshService struct {
cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
- stopCh chan struct{}
- wg sync.WaitGroup
+ stopCh chan struct{}
+ stopOnce sync.Once
+ wg sync.WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
@@ -87,7 +96,12 @@ func (s *TokenRefreshService) Start() {
// Stop 停止刷新服务
func (s *TokenRefreshService) Stop() {
- close(s.stopCh)
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ close(s.stopCh)
+ })
s.wg.Wait()
slog.Info("token_refresh.service_stopped")
}
@@ -118,9 +132,24 @@ func (s *TokenRefreshService) refreshLoop() {
}
}
+// refreshTask 封装一个待刷新的账号及其对应的刷新器
+type refreshTask struct {
+ account *Account
+ refresher TokenRefresher
+}
+
// processRefresh 执行一次刷新检查
+// 分两阶段:先串行收集需刷新的账号,再并行执行刷新(信号量限制并发数)
func (s *TokenRefreshService) processRefresh() {
- ctx := context.Background()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ go func() {
+ select {
+ case <-s.stopCh:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
// 计算刷新窗口
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
@@ -128,68 +157,188 @@ func (s *TokenRefreshService) processRefresh() {
// 获取所有active状态的账号
accounts, err := s.listActiveAccounts(ctx)
if err != nil {
+ if ctx.Err() != nil {
+ slog.Info("token_refresh.cycle_interrupted_by_stop")
+ return
+ }
slog.Error("token_refresh.list_accounts_failed", "error", err)
return
}
totalAccounts := len(accounts)
- oauthAccounts := 0 // 可刷新的OAuth账号数
- needsRefresh := 0 // 需要刷新的账号数
- refreshed, failed := 0, 0
+
+ // Phase 1: 收集需要刷新的账号(轻量操作,仍串行)
+ var tasks []refreshTask
+ oauthAccounts := 0
for i := range accounts {
account := &accounts[i]
-
- // 遍历所有刷新器,找到能处理此账号的
for _, refresher := range s.refreshers {
if !refresher.CanRefresh(account) {
continue
}
-
oauthAccounts++
-
- // 检查是否需要刷新
- if !refresher.NeedsRefresh(account, refreshWindow) {
- break // 不需要刷新,跳过
+ if refresher.NeedsRefresh(account, refreshWindow) {
+ if !s.tryAcquireDistributedRefreshLock(ctx, account) {
+ break
+ }
+ tasks = append(tasks, refreshTask{account: account, refresher: refresher})
}
+ break // 每个账号只由一个refresher处理
+ }
+ }
- needsRefresh++
+ needsRefresh := len(tasks)
+ if needsRefresh == 0 {
+ slog.Debug("token_refresh.cycle_completed",
+ "total", totalAccounts, "oauth", oauthAccounts,
+ "needs_refresh", 0, "refreshed", 0, "failed", 0)
+ return
+ }
+
+ // Phase 2: 并行刷新(带信号量限制)
+ const maxConcurrency = 10
+ sem := make(chan struct{}, maxConcurrency)
+ var refreshed, failed atomic.Int64
+ var wg sync.WaitGroup
+ interrupted := false
+
+submitLoop:
+ for _, task := range tasks {
+ // 检查停止信号
+ select {
+ case <-s.stopCh:
+ slog.Info("token_refresh.cycle_interrupted_by_stop")
+ interrupted = true
+ break submitLoop
+ default:
+ }
+
+ select {
+ case sem <- struct{}{}: // 获取信号量
+ case <-s.stopCh:
+ slog.Info("token_refresh.cycle_interrupted_by_stop")
+ interrupted = true
+ break submitLoop
+ case <-ctx.Done():
+ interrupted = true
+ break submitLoop
+ }
- // 执行刷新
- if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
+ wg.Add(1)
+ go func(t refreshTask) {
+ defer func() {
+ <-sem // 释放信号量
+ wg.Done()
+ }()
+
+ if err := s.refreshWithRetry(ctx, t.account, t.refresher); err != nil {
slog.Warn("token_refresh.account_refresh_failed",
- "account_id", account.ID,
- "account_name", account.Name,
+ "account_id", t.account.ID,
+ "account_name", t.account.Name,
"error", err,
)
- failed++
+ failed.Add(1)
} else {
slog.Info("token_refresh.account_refreshed",
- "account_id", account.ID,
- "account_name", account.Name,
+ "account_id", t.account.ID,
+ "account_name", t.account.Name,
)
- refreshed++
+ refreshed.Add(1)
}
-
- // 每个账号只由一个refresher处理
- break
- }
+ }(task)
}
- // 无刷新活动时降级为 Debug,有实际刷新活动时保持 Info
- if needsRefresh == 0 && failed == 0 {
+ wg.Wait()
+
+ r, f := int(refreshed.Load()), int(failed.Load())
+ if interrupted {
+ slog.Info("token_refresh.cycle_wait_completed_after_stop",
+ "needs_refresh", needsRefresh, "refreshed", r, "failed", f)
+ }
+ if needsRefresh == 0 && f == 0 {
slog.Debug("token_refresh.cycle_completed",
"total", totalAccounts, "oauth", oauthAccounts,
- "needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed)
+ "needs_refresh", needsRefresh, "refreshed", r, "failed", f)
} else {
slog.Info("token_refresh.cycle_completed",
- "total", totalAccounts,
- "oauth", oauthAccounts,
- "needs_refresh", needsRefresh,
- "refreshed", refreshed,
- "failed", failed,
+ "total", totalAccounts, "oauth", oauthAccounts,
+ "needs_refresh", needsRefresh, "refreshed", r, "failed", f)
+ }
+}
+
+func (s *TokenRefreshService) tryAcquireDistributedRefreshLock(ctx context.Context, account *Account) bool {
+ if s == nil || account == nil || account.ID <= 0 || s.schedulerCache == nil {
+ return true
+ }
+ lockTTL := s.tokenRefreshDistributedLockTTL()
+ if lockTTL <= 0 {
+ return true
+ }
+ lockCtx := ctx
+ if lockCtx == nil {
+ lockCtx = context.Background()
+ }
+ lockCtx, cancel := context.WithTimeout(lockCtx, tokenRefreshDistributedLockTimeout)
+ defer cancel()
+
+ lockBucket := SchedulerBucket{
+ GroupID: account.ID,
+ Platform: tokenRefreshDistributedLockPlatform,
+ Mode: normalizeTokenRefreshLockMode(account.Platform),
+ }
+ locked, err := s.schedulerCache.TryLockBucket(lockCtx, lockBucket, lockTTL)
+ if err != nil {
+ if ctx != nil && ctx.Err() != nil {
+ slog.Info("token_refresh.distributed_lock_canceled",
+ "account_id", account.ID,
+ "platform", account.Platform,
+ "error", err,
+ )
+ return false
+ }
+ slog.Warn("token_refresh.distributed_lock_failed",
+ "account_id", account.ID,
+ "platform", account.Platform,
+ "error", err,
+ "fail_open", true,
+ )
+ return true
+ }
+ if !locked {
+ slog.Debug("token_refresh.distributed_lock_held",
+ "account_id", account.ID,
+ "platform", account.Platform,
)
+ return false
+ }
+ return true
+}
+
+func normalizeTokenRefreshLockMode(platform string) string {
+ mode := strings.TrimSpace(platform)
+ if mode == "" {
+ return "unknown"
+ }
+ return mode
+}
+
+func (s *TokenRefreshService) tokenRefreshDistributedLockTTL() time.Duration {
+ if s == nil || s.cfg == nil {
+ return tokenRefreshDistributedLockMinTTL
}
+ checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
+ if checkInterval <= 0 {
+ return tokenRefreshDistributedLockMinTTL
+ }
+ ttl := checkInterval / 2
+ if ttl < tokenRefreshDistributedLockMinTTL {
+ ttl = tokenRefreshDistributedLockMinTTL
+ }
+ if ttl > tokenRefreshDistributedLockMaxTTL {
+ ttl = tokenRefreshDistributedLockMaxTTL
+ }
+ return ttl
}
// listActiveAccounts 获取所有active状态的账号
@@ -281,7 +430,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if attempt < s.cfg.MaxRetries {
// 指数退避:2^(attempt-1) * baseSeconds
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
- time.Sleep(backoff)
+ if err := s.waitRetryBackoff(ctx, backoff); err != nil {
+ return err
+ }
}
}
@@ -306,6 +457,23 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
return lastErr
}
+func (s *TokenRefreshService) waitRetryBackoff(ctx context.Context, backoff time.Duration) error {
+ if backoff <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(backoff)
+ defer timer.Stop()
+
+ select {
+ case <-timer.C:
+ return nil
+ case <-s.stopCh:
+ return context.Canceled
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go
index 8e16c6f5d..142a5c169 100644
--- a/backend/internal/service/token_refresh_service_test.go
+++ b/backend/internal/service/token_refresh_service_test.go
@@ -14,10 +14,11 @@ import (
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
- updateCalls int
- setErrorCalls int
- lastAccount *Account
- updateErr error
+ updateCalls int
+ setErrorCalls int
+ lastAccount *Account
+ updateErr error
+ activeAccounts []Account
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
@@ -31,6 +32,15 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorM
return nil
}
+func (r *tokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) {
+ if len(r.activeAccounts) == 0 {
+ return nil, nil
+ }
+ out := make([]Account, 0, len(r.activeAccounts))
+ out = append(out, r.activeAccounts...)
+ return out, nil
+}
+
type tokenCacheInvalidatorStub struct {
calls int
err error
@@ -42,8 +52,9 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account
}
type tokenRefresherStub struct {
- credentials map[string]any
- err error
+ credentials map[string]any
+ err error
+ refreshCalls int
}
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
@@ -55,12 +66,41 @@ func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuratio
}
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ r.refreshCalls++
if r.err != nil {
return nil, r.err
}
return r.credentials, nil
}
+type tokenRefreshSchedulerLockStub struct {
+ SchedulerCache
+ lockByAccount map[int64]bool
+ err error
+ calls []SchedulerBucket
+}
+
+func (s *tokenRefreshSchedulerLockStub) TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) {
+ _ = ctx
+ _ = ttl
+ s.calls = append(s.calls, bucket)
+ if s.err != nil {
+ return false, s.err
+ }
+ if s.lockByAccount != nil {
+ if ok, exists := s.lockByAccount[bucket.GroupID]; exists {
+ return ok, nil
+ }
+ }
+ return true, nil
+}
+
+func (s *tokenRefreshSchedulerLockStub) SetAccount(ctx context.Context, account *Account) error {
+ _ = ctx
+ _ = account
+ return nil
+}
+
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
@@ -89,6 +129,96 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
+func TestTokenRefreshService_ProcessRefresh_SkipsWhenDistributedLockHeld(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{
+ activeAccounts: []Account{
+ {ID: 31, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true},
+ },
+ }
+ lockStub := &tokenRefreshSchedulerLockStub{
+ lockByAccount: map[int64]bool{31: false},
+ }
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 2,
+ RefreshBeforeExpiryHours: 1,
+ SyncLinkedSoraAccounts: false,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}}
+ service.refreshers = []TokenRefresher{refresher}
+
+ service.processRefresh()
+
+ require.Len(t, lockStub.calls, 1)
+ require.Equal(t, int64(31), lockStub.calls[0].GroupID)
+ require.Equal(t, tokenRefreshDistributedLockPlatform, lockStub.calls[0].Platform)
+ require.Equal(t, PlatformOpenAI, lockStub.calls[0].Mode)
+ require.Equal(t, 0, refresher.refreshCalls, "lock held by another instance should skip refresh")
+ require.Equal(t, 0, repo.updateCalls)
+}
+
+func TestTokenRefreshService_ProcessRefresh_RefreshesWhenDistributedLockAcquired(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{
+ activeAccounts: []Account{
+ {ID: 32, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true},
+ },
+ }
+ lockStub := &tokenRefreshSchedulerLockStub{
+ lockByAccount: map[int64]bool{32: true},
+ }
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 2,
+ RefreshBeforeExpiryHours: 1,
+ SyncLinkedSoraAccounts: false,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}}
+ service.refreshers = []TokenRefresher{refresher}
+
+ service.processRefresh()
+
+ require.Len(t, lockStub.calls, 1)
+ require.Equal(t, 1, refresher.refreshCalls)
+ require.Equal(t, 1, repo.updateCalls, "lock acquired should allow refresh")
+}
+
+func TestTokenRefreshService_ProcessRefresh_FailOpenWhenDistributedLockError(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{
+ activeAccounts: []Account{
+ {ID: 33, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true},
+ },
+ }
+ lockStub := &tokenRefreshSchedulerLockStub{
+ err: errors.New("redis unavailable"),
+ }
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 2,
+ RefreshBeforeExpiryHours: 1,
+ SyncLinkedSoraAccounts: false,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}}
+ service.refreshers = []TokenRefresher{refresher}
+
+ service.processRefresh()
+
+ require.Len(t, lockStub.calls, 1)
+ require.Equal(t, 1, refresher.refreshCalls, "lock backend error should fail-open and continue refresh")
+ require.Equal(t, 1, repo.updateCalls)
+}
+
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
@@ -359,3 +489,56 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
})
}
}
+
+func TestTokenRefreshService_Stop_Idempotent(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
+
+ require.NotPanics(t, func() {
+ service.Stop()
+ service.Stop()
+ })
+}
+
+func TestTokenRefreshService_RefreshWithRetry_StopInterruptsBackoff(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 3,
+ RetryBackoffSeconds: 5,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
+ account := &Account{
+ ID: 21,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ err: errors.New("refresh failed"),
+ }
+
+ start := time.Now()
+ done := make(chan error, 1)
+ go func() {
+ done <- service.refreshWithRetry(context.Background(), account, refresher)
+ }()
+
+ time.Sleep(80 * time.Millisecond)
+ service.Stop()
+
+ select {
+ case err := <-done:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(600 * time.Millisecond):
+ t.Fatal("refreshWithRetry should exit quickly after service stop")
+ }
+ require.Less(t, time.Since(start), time.Second)
+ require.Equal(t, 0, repo.setErrorCalls, "stop 中断时不应落错误状态")
+}
diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go
index 0dd3cf45d..a4fc1c70e 100644
--- a/backend/internal/service/token_refresher.go
+++ b/backend/internal/service/token_refresher.go
@@ -7,6 +7,18 @@ import (
"time"
)
+const openAISoraSyncConcurrencyLimit = 8
+
+// needsRefreshWithoutExpiry 在 expires_at 缺失时判断是否需要刷新。
+// 通过 Account.UpdatedAt 避免每轮刷新周期都发起无效刷新:
+// 如果账号在 refreshWindow 内曾被更新,说明最近可能已刷新过,跳过本轮。
+func needsRefreshWithoutExpiry(account *Account, refreshWindow time.Duration) bool {
+ if refreshWindow <= 0 {
+ return true
+ }
+ return time.Since(account.UpdatedAt) >= refreshWindow
+}
+
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
type TokenRefresher interface {
@@ -46,7 +58,8 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool {
func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
- return false
+ // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮
+ return needsRefreshWithoutExpiry(account, refreshWindow)
}
return time.Until(*expiresAt) < refreshWindow
}
@@ -87,6 +100,10 @@ type OpenAITokenRefresher struct {
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
syncLinkedSora bool
+ syncLinkedSoraSem chan struct{}
+
+ // test hook: override sync execution target when needed.
+ syncLinkedSoraAccountsFn func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any)
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -94,6 +111,7 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo
return &OpenAITokenRefresher{
openaiOAuthService: openaiOAuthService,
accountRepo: accountRepo,
+ syncLinkedSoraSem: make(chan struct{}, openAISoraSyncConcurrencyLimit),
}
}
@@ -120,7 +138,8 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
- return false
+ // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮
+ return needsRefreshWithoutExpiry(account, refreshWindow)
}
return time.Until(*expiresAt) < refreshWindow
@@ -147,12 +166,58 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
// 异步同步关联的 Sora 账号(不阻塞主流程)
if r.accountRepo != nil && r.syncLinkedSora {
- go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
+ syncCredentials := copyCredentialsMap(newCredentials)
+ syncFn := r.syncLinkedSoraAccounts
+ if r.syncLinkedSoraAccountsFn != nil {
+ syncFn = r.syncLinkedSoraAccountsFn
+ }
+ if r.tryAcquireSyncLinkedSoraSlot() {
+ go func() {
+ defer r.releaseSyncLinkedSoraSlot()
+ syncFn(context.Background(), account.ID, syncCredentials)
+ }()
+ } else {
+ // 达到并发上限时回退为同步执行,避免 goroutine 无界堆积。
+ syncFn(ctx, account.ID, syncCredentials)
+ }
}
return newCredentials, nil
}
+func copyCredentialsMap(src map[string]any) map[string]any {
+ if len(src) == 0 {
+ return map[string]any{}
+ }
+ dst := make(map[string]any, len(src))
+ for k, v := range src {
+ dst[k] = v
+ }
+ return dst
+}
+
+func (r *OpenAITokenRefresher) tryAcquireSyncLinkedSoraSlot() bool {
+ if r == nil || r.syncLinkedSoraSem == nil {
+ return false
+ }
+ select {
+ case r.syncLinkedSoraSem <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (r *OpenAITokenRefresher) releaseSyncLinkedSoraSlot() {
+ if r == nil || r.syncLinkedSoraSem == nil {
+ return
+ }
+ select {
+ case <-r.syncLinkedSoraSem:
+ default:
+ }
+}
+
// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步)
// 该方法异步执行,失败只记录日志,不影响主流程
//
diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go
index 264d79125..27d786634 100644
--- a/backend/internal/service/token_refresher_test.go
+++ b/backend/internal/service/token_refresher_test.go
@@ -3,13 +3,38 @@
package service
import (
+ "context"
"strconv"
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
+type openAIOAuthClientStubForRefresher struct {
+ tokenResp *openai.TokenResponse
+ err error
+}
+
+func (s *openAIOAuthClientStubForRefresher) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
+ return nil, s.err
+}
+
+func (s *openAIOAuthClientStubForRefresher) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.tokenResp, nil
+}
+
+func (s *openAIOAuthClientStubForRefresher) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.tokenResp, nil
+}
+
func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) {
refresher := &ClaudeTokenRefresher{}
refreshWindow := 30 * time.Minute
@@ -64,26 +89,26 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) {
{
name: "expires_at missing",
credentials: map[string]any{},
- wantRefresh: false,
+ wantRefresh: true,
},
{
name: "expires_at is nil",
credentials: map[string]any{
"expires_at": nil,
},
- wantRefresh: false,
+ wantRefresh: true,
},
{
name: "expires_at is invalid string",
credentials: map[string]any{
"expires_at": "invalid",
},
- wantRefresh: false,
+ wantRefresh: true,
},
{
name: "credentials is nil",
credentials: nil,
- wantRefresh: false,
+ wantRefresh: true,
},
}
@@ -179,6 +204,36 @@ func TestClaudeTokenRefresher_NeedsRefresh_OutsideWindow(t *testing.T) {
}
}
+func TestNeedsRefreshWithoutExpiry_RecentlyUpdated(t *testing.T) {
+ refreshWindow := 30 * time.Minute
+
+ t.Run("recently_updated_skips_refresh", func(t *testing.T) {
+ // 账号近期更新过(5 分钟前),不需要刷新
+ account := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{},
+ UpdatedAt: time.Now().Add(-5 * time.Minute),
+ }
+ refresher := &ClaudeTokenRefresher{}
+ require.False(t, refresher.NeedsRefresh(account, refreshWindow),
+ "近期更新过的账号无 expires_at 时不应刷新")
+ })
+
+ t.Run("old_updated_needs_refresh", func(t *testing.T) {
+ // 账号很久没更新(2 小时前),需要刷新
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{},
+ UpdatedAt: time.Now().Add(-2 * time.Hour),
+ }
+ refresher := &OpenAITokenRefresher{}
+ require.True(t, refresher.NeedsRefresh(account, refreshWindow),
+ "长期未更新的账号无 expires_at 时应刷新")
+ })
+}
+
func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
refresher := &ClaudeTokenRefresher{}
@@ -266,3 +321,159 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
})
}
}
+
+func TestOpenAITokenRefresher_NeedsRefresh(t *testing.T) {
+ refresher := &OpenAITokenRefresher{}
+ refreshWindow := 30 * time.Minute
+
+ tests := []struct {
+ name string
+ credentials map[string]any
+ wantRefresh bool
+ }{
+ {
+ name: "expires_at missing",
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ wantRefresh: true,
+ },
+ {
+ name: "expires_at invalid",
+ credentials: map[string]any{
+ "expires_at": "invalid",
+ },
+ wantRefresh: true,
+ },
+ {
+ name: "expires_at expired",
+ credentials: map[string]any{
+ "expires_at": strconv.FormatInt(time.Now().Add(-time.Minute).Unix(), 10),
+ },
+ wantRefresh: true,
+ },
+ {
+ name: "expires_at far future",
+ credentials: map[string]any{
+ "expires_at": strconv.FormatInt(time.Now().Add(2*time.Hour).Unix(), 10),
+ },
+ wantRefresh: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: tt.credentials,
+ }
+ require.Equal(t, tt.wantRefresh, refresher.NeedsRefresh(account, refreshWindow))
+ })
+ }
+}
+
+func TestOpenAITokenRefresher_Refresh_AsyncSyncUsesCopiedCredentials(t *testing.T) {
+ oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{
+ tokenResp: &openai.TokenResponse{
+ AccessToken: "new_access_token",
+ RefreshToken: "new_refresh_token",
+ ExpiresIn: 3600,
+ },
+ })
+ refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{})
+ refresher.SetSyncLinkedSoraAccounts(true)
+ refresher.syncLinkedSoraSem = make(chan struct{}, 1)
+
+ readNow := make(chan struct{})
+ seenValue := make(chan string, 1)
+ refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
+ <-readNow
+ if v, ok := newCredentials["custom"].(string); ok {
+ seenValue <- v
+ return
+ }
+ seenValue <- ""
+ }
+
+ account := &Account{
+ ID: 1001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "old_refresh_token",
+ "client_id": "test-client",
+ "custom": "original",
+ },
+ }
+
+ newCredentials, err := refresher.Refresh(context.Background(), account)
+ require.NoError(t, err)
+ require.NotNil(t, newCredentials)
+
+ newCredentials["custom"] = "mutated_after_return"
+ close(readNow)
+
+ select {
+ case got := <-seenValue:
+ require.Equal(t, "original", got, "异步同步应使用 credentials 副本,避免并发写污染")
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("timed out waiting for sync hook")
+ }
+}
+
+func TestOpenAITokenRefresher_Refresh_FallsBackToSyncWhenLimiterFull(t *testing.T) {
+ oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{
+ tokenResp: &openai.TokenResponse{
+ AccessToken: "new_access_token",
+ RefreshToken: "new_refresh_token",
+ ExpiresIn: 3600,
+ },
+ })
+ refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{})
+ refresher.SetSyncLinkedSoraAccounts(true)
+ refresher.syncLinkedSoraSem = make(chan struct{}, 1)
+ refresher.syncLinkedSoraSem <- struct{}{} // 填满 limiter,强制走同步降级路径
+
+ entered := make(chan struct{})
+ releaseSync := make(chan struct{})
+ refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
+ close(entered)
+ <-releaseSync
+ }
+
+ account := &Account{
+ ID: 1002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "old_refresh_token",
+ "client_id": "test-client",
+ },
+ }
+
+ done := make(chan struct{})
+ go func() {
+ _, _ = refresher.Refresh(context.Background(), account)
+ close(done)
+ }()
+
+ select {
+ case <-entered:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("sync hook was not invoked")
+ }
+
+ select {
+ case <-done:
+ t.Fatal("Refresh should block when falling back to synchronous linked-sora sync")
+ default:
+ }
+
+ close(releaseSync)
+ select {
+ case <-done:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("Refresh did not finish after releasing synchronous sync hook")
+ }
+}
diff --git a/backend/internal/service/usage_billing_compensation_service.go b/backend/internal/service/usage_billing_compensation_service.go
new file mode 100644
index 000000000..47ea71618
--- /dev/null
+++ b/backend/internal/service/usage_billing_compensation_service.go
@@ -0,0 +1,256 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+const (
+ defaultUsageBillingCompensationInterval = 20 * time.Second
+ defaultUsageBillingCompensationBatchSize = 64
+ defaultUsageBillingCompensationTaskTimout = 8 * time.Second
+ defaultUsageBillingCompensationStaleAfter = 3 * time.Minute
+)
+
+// UsageBillingCompensationService retries pending usage charges recorded in billing_usage_entries.
+// It is a no-op unless usageLogRepo also implements UsageBillingEntryStore.
+type UsageBillingCompensationService struct {
+ usageLogRepo UsageLogRepository
+ userRepo UserRepository
+ userSubRepo UserSubscriptionRepository
+ billingCache *BillingCacheService
+ cfg *config.Config
+
+ startOnce sync.Once
+ stopOnce sync.Once
+ stopCh chan struct{}
+}
+
+func NewUsageBillingCompensationService(
+ usageLogRepo UsageLogRepository,
+ userRepo UserRepository,
+ userSubRepo UserSubscriptionRepository,
+ billingCache *BillingCacheService,
+ cfg *config.Config,
+) *UsageBillingCompensationService {
+ return &UsageBillingCompensationService{
+ usageLogRepo: usageLogRepo,
+ userRepo: userRepo,
+ userSubRepo: userSubRepo,
+ billingCache: billingCache,
+ cfg: cfg,
+ stopCh: make(chan struct{}),
+ }
+}
+
+func (s *UsageBillingCompensationService) Start() {
+ if s == nil || s.store() == nil {
+ return
+ }
+ if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+ return
+ }
+ s.startOnce.Do(func() {
+ slog.Info("usage_billing_compensation.started",
+ "interval", defaultUsageBillingCompensationInterval.String(),
+ "batch_size", defaultUsageBillingCompensationBatchSize,
+ )
+ go s.runLoop()
+ })
+}
+
+func (s *UsageBillingCompensationService) Stop() {
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ close(s.stopCh)
+ slog.Info("usage_billing_compensation.stopped")
+ })
+}
+
+func (s *UsageBillingCompensationService) runLoop() {
+ ticker := time.NewTicker(defaultUsageBillingCompensationInterval)
+ defer ticker.Stop()
+
+ s.processOnce()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.processOnce()
+ case <-s.stopCh:
+ return
+ }
+ }
+}
+
+func (s *UsageBillingCompensationService) processOnce() {
+ store := s.store()
+ if store == nil {
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultUsageBillingCompensationTaskTimout)
+ defer cancel()
+ go func() {
+ select {
+ case <-s.stopCh:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+
+ entries, err := store.ClaimUsageBillingEntries(ctx, defaultUsageBillingCompensationBatchSize, defaultUsageBillingCompensationStaleAfter)
+ if err != nil {
+ slog.Warn("usage_billing_compensation.claim_failed", "error", err)
+ return
+ }
+ for i := range entries {
+ if ctx.Err() != nil {
+ return
+ }
+ s.processEntry(ctx, entries[i])
+ }
+}
+
+func (s *UsageBillingCompensationService) processEntry(ctx context.Context, entry UsageBillingEntry) {
+ if entry.Applied || entry.DeltaUSD <= 0 {
+ s.markApplied(ctx, entry)
+ return
+ }
+
+ if err := s.applyEntry(ctx, entry); err != nil {
+ s.markRetry(ctx, entry, err)
+ return
+ }
+}
+
+func (s *UsageBillingCompensationService) applyEntry(ctx context.Context, entry UsageBillingEntry) error {
+ switch entry.BillingType {
+ case BillingTypeSubscription:
+ return s.applySubscriptionEntry(ctx, entry)
+ default:
+ return s.applyBalanceEntry(ctx, entry)
+ }
+}
+
+func (s *UsageBillingCompensationService) applyBalanceEntry(ctx context.Context, entry UsageBillingEntry) error {
+ if s.userRepo == nil {
+ return errors.New("user repository unavailable")
+ }
+
+ cacheDeducted := false
+ if s.billingCache != nil {
+ if err := s.billingCache.DeductBalanceCache(ctx, entry.UserID, entry.DeltaUSD); err != nil {
+ slog.Warn("usage_billing_compensation.balance_cache_deduct_failed",
+ "entry_id", entry.ID,
+ "user_id", entry.UserID,
+ "amount", entry.DeltaUSD,
+ "error", err,
+ )
+ _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID)
+ } else {
+ cacheDeducted = true
+ }
+ }
+
+ if err := s.runWithTx(ctx, func(txCtx context.Context) error {
+ if err := s.userRepo.DeductBalance(txCtx, entry.UserID, entry.DeltaUSD); err != nil {
+ return err
+ }
+ return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID)
+ }); err != nil {
+ if s.billingCache != nil && cacheDeducted {
+ _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID)
+ }
+ return err
+ }
+
+ return nil
+}
+
+func (s *UsageBillingCompensationService) applySubscriptionEntry(ctx context.Context, entry UsageBillingEntry) error {
+ if s.userSubRepo == nil {
+ return errors.New("subscription repository unavailable")
+ }
+ if entry.SubscriptionID == nil {
+ return errors.New("subscription_id is nil for subscription billing")
+ }
+
+ if err := s.runWithTx(ctx, func(txCtx context.Context) error {
+ if err := s.userSubRepo.IncrementUsage(txCtx, *entry.SubscriptionID, entry.DeltaUSD); err != nil {
+ return err
+ }
+ return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID)
+ }); err != nil {
+ return err
+ }
+
+ if s.billingCache != nil {
+ sub, err := s.userSubRepo.GetByID(ctx, *entry.SubscriptionID)
+ if err == nil && sub != nil {
+ _ = s.billingCache.InvalidateSubscription(ctx, entry.UserID, sub.GroupID)
+ }
+ }
+
+ return nil
+}
+
+func (s *UsageBillingCompensationService) markApplied(ctx context.Context, entry UsageBillingEntry) {
+ store := s.store()
+ if store == nil {
+ return
+ }
+ if err := store.MarkUsageBillingEntryApplied(ctx, entry.ID); err != nil {
+ slog.Warn("usage_billing_compensation.mark_applied_failed", "entry_id", entry.ID, "error", err)
+ }
+}
+
+func (s *UsageBillingCompensationService) markRetry(ctx context.Context, entry UsageBillingEntry, cause error) {
+ store := s.store()
+ if store == nil {
+ return
+ }
+ errMsg := strings.TrimSpace(cause.Error())
+ if len(errMsg) > 500 {
+ errMsg = errMsg[:500]
+ }
+ backoff := usageBillingRetryBackoff(entry.AttemptCount)
+ nextRetryAt := time.Now().Add(backoff)
+ if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil {
+ slog.Warn("usage_billing_compensation.mark_retry_failed",
+ "entry_id", entry.ID,
+ "next_retry_at", nextRetryAt,
+ "error", err,
+ )
+ return
+ }
+ slog.Warn("usage_billing_compensation.requeued",
+ "entry_id", entry.ID,
+ "attempt", entry.AttemptCount,
+ "next_retry_at", nextRetryAt,
+ "error", errMsg,
+ )
+}
+
+func (s *UsageBillingCompensationService) runWithTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if runner, ok := s.usageLogRepo.(UsageBillingTxRunner); ok && runner != nil {
+ return runner.WithUsageBillingTx(ctx, fn)
+ }
+ return fn(ctx)
+}
+
+func (s *UsageBillingCompensationService) store() UsageBillingEntryStore {
+ store, ok := s.usageLogRepo.(UsageBillingEntryStore)
+ if !ok {
+ return nil
+ }
+ return store
+}
diff --git a/backend/internal/service/usage_billing_compensation_service_test.go b/backend/internal/service/usage_billing_compensation_service_test.go
new file mode 100644
index 000000000..f45c78efe
--- /dev/null
+++ b/backend/internal/service/usage_billing_compensation_service_test.go
@@ -0,0 +1,230 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type usageBillingCompRepoStub struct {
+ UsageLogRepository
+
+ claimErr error
+ claims []UsageBillingEntry
+
+ markAppliedCalls int
+ markRetryCalls int
+ lastRetryID int64
+ lastRetryAt time.Time
+ lastRetryErr string
+ lastMarkAppliedCtx context.Context
+ lastTxCtx context.Context
+}
+
+func (s *usageBillingCompRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) {
+ return nil, ErrUsageBillingEntryNotFound
+}
+
+func (s *usageBillingCompRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) {
+ return entry, true, nil
+}
+
+func (s *usageBillingCompRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error {
+ s.markAppliedCalls++
+ s.lastMarkAppliedCtx = ctx
+ return nil
+}
+
+func (s *usageBillingCompRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error {
+ s.markRetryCalls++
+ s.lastRetryID = entryID
+ s.lastRetryAt = nextRetryAt
+ s.lastRetryErr = lastError
+ _ = ctx
+ return nil
+}
+
+func (s *usageBillingCompRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) {
+ if s.claimErr != nil {
+ return nil, s.claimErr
+ }
+ out := make([]UsageBillingEntry, len(s.claims))
+ copy(out, s.claims)
+ s.claims = nil
+ return out, nil
+}
+
+func (s *usageBillingCompRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ s.lastTxCtx = ctx
+ if fn == nil {
+ return nil
+ }
+ return fn(ctx)
+}
+
+type usageBillingCompUserRepoStub struct {
+ UserRepository
+
+ deductCalls int
+ deductErr error
+ lastDeductCtx context.Context
+}
+
+func (s *usageBillingCompUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
+ s.deductCalls++
+ s.lastDeductCtx = ctx
+ return s.deductErr
+}
+
+type usageBillingCompSubRepoStub struct {
+ UserSubscriptionRepository
+
+ incrementCalls int
+ incrementErr error
+ lastIncrementCtx context.Context
+}
+
+func (s *usageBillingCompSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
+ s.incrementCalls++
+ s.lastIncrementCtx = ctx
+ return s.incrementErr
+}
+
+type usageBillingCompCtxKey string
+
+func TestUsageBillingCompensationService_ProcessOnceBalanceSuccess(t *testing.T) {
+ repo := &usageBillingCompRepoStub{
+ claims: []UsageBillingEntry{
+ {
+ ID: 1,
+ UsageLogID: 1001,
+ UserID: 2001,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 1.23,
+ AttemptCount: 1,
+ },
+ },
+ }
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ svc.processOnce()
+
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 1, repo.markAppliedCalls)
+ require.Equal(t, 0, repo.markRetryCalls)
+}
+
+func TestUsageBillingCompensationService_ProcessOnceBalanceFailureRequeues(t *testing.T) {
+ repo := &usageBillingCompRepoStub{
+ claims: []UsageBillingEntry{
+ {
+ ID: 2,
+ UsageLogID: 1002,
+ UserID: 2002,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 2.34,
+ AttemptCount: 2,
+ },
+ },
+ }
+ userRepo := &usageBillingCompUserRepoStub{deductErr: errors.New("db down")}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ svc.processOnce()
+
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 0, repo.markAppliedCalls)
+ require.Equal(t, 1, repo.markRetryCalls)
+ require.Equal(t, int64(2), repo.lastRetryID)
+ require.NotZero(t, repo.lastRetryAt)
+ require.Contains(t, repo.lastRetryErr, "db down")
+}
+
+func TestUsageBillingCompensationService_ProcessOnceSubscriptionSuccess(t *testing.T) {
+ subID := int64(4003)
+ repo := &usageBillingCompRepoStub{
+ claims: []UsageBillingEntry{
+ {
+ ID: 3,
+ UsageLogID: 1003,
+ UserID: 2003,
+ SubscriptionID: &subID,
+ BillingType: BillingTypeSubscription,
+ DeltaUSD: 3.45,
+ AttemptCount: 1,
+ },
+ },
+ }
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ svc.processOnce()
+
+ require.Equal(t, 1, subRepo.incrementCalls)
+ require.Equal(t, 1, repo.markAppliedCalls)
+ require.Equal(t, 0, repo.markRetryCalls)
+}
+
+func TestUsageBillingCompensationService_ApplyBalanceEntryPropagatesContext(t *testing.T) {
+ repo := &usageBillingCompRepoStub{}
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ entry := UsageBillingEntry{
+ ID: 10,
+ UsageLogID: 1010,
+ UserID: 2010,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 1.11,
+ }
+ key := usageBillingCompCtxKey("trace")
+ ctx := context.WithValue(context.Background(), key, "balance")
+
+ err := svc.applyBalanceEntry(ctx, entry)
+ require.NoError(t, err)
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.NotNil(t, repo.lastTxCtx)
+ require.NotNil(t, userRepo.lastDeductCtx)
+ require.NotNil(t, repo.lastMarkAppliedCtx)
+ require.Equal(t, "balance", repo.lastTxCtx.Value(key))
+ require.Equal(t, "balance", userRepo.lastDeductCtx.Value(key))
+ require.Equal(t, "balance", repo.lastMarkAppliedCtx.Value(key))
+}
+
+func TestUsageBillingCompensationService_ApplySubscriptionEntryPropagatesContext(t *testing.T) {
+ subID := int64(4010)
+ repo := &usageBillingCompRepoStub{}
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ entry := UsageBillingEntry{
+ ID: 11,
+ UsageLogID: 1011,
+ UserID: 2011,
+ SubscriptionID: &subID,
+ BillingType: BillingTypeSubscription,
+ DeltaUSD: 2.22,
+ }
+ key := usageBillingCompCtxKey("trace")
+ ctx := context.WithValue(context.Background(), key, "subscription")
+
+ err := svc.applySubscriptionEntry(ctx, entry)
+ require.NoError(t, err)
+ require.Equal(t, 1, subRepo.incrementCalls)
+ require.NotNil(t, repo.lastTxCtx)
+ require.NotNil(t, subRepo.lastIncrementCtx)
+ require.NotNil(t, repo.lastMarkAppliedCtx)
+ require.Equal(t, "subscription", repo.lastTxCtx.Value(key))
+ require.Equal(t, "subscription", subRepo.lastIncrementCtx.Value(key))
+ require.Equal(t, "subscription", repo.lastMarkAppliedCtx.Value(key))
+}
diff --git a/backend/internal/service/usage_billing_entry.go b/backend/internal/service/usage_billing_entry.go
new file mode 100644
index 000000000..a24714082
--- /dev/null
+++ b/backend/internal/service/usage_billing_entry.go
@@ -0,0 +1,60 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "time"
+)
+
+var ErrUsageBillingEntryNotFound = errors.New("usage billing entry not found")
+
+type UsageBillingEntryStatus int16
+
+const (
+ UsageBillingEntryStatusPending UsageBillingEntryStatus = 0
+ UsageBillingEntryStatusProcessing UsageBillingEntryStatus = 1
+ UsageBillingEntryStatusApplied UsageBillingEntryStatus = 2
+)
+
+type UsageBillingEntry struct {
+ ID int64
+ UsageLogID int64
+ UserID int64
+ APIKeyID int64
+ SubscriptionID *int64
+ BillingType int8
+ Applied bool
+ DeltaUSD float64
+ Status UsageBillingEntryStatus
+ AttemptCount int
+ NextRetryAt time.Time
+ UpdatedAt time.Time
+ CreatedAt time.Time
+ LastError *string
+}
+
+type UsageBillingEntryStore interface {
+ GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error)
+ UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error)
+ MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error
+ MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error
+ ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error)
+}
+
+type UsageBillingTxRunner interface {
+ WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error
+}
+
+func usageBillingRetryBackoff(attempt int) time.Duration {
+ if attempt <= 1 {
+ return 30 * time.Second
+ }
+ backoff := 30 * time.Second
+ for i := 1; i < attempt && backoff < 30*time.Minute; i++ {
+ backoff *= 2
+ }
+ if backoff > 30*time.Minute {
+ return 30 * time.Minute
+ }
+ return backoff
+}
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 05fe50560..c8661c915 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -46,7 +46,7 @@ func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int
return 0, nil
}
func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
-func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
+func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index c71851901..2ada41adf 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -293,13 +293,6 @@ func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthC
return apiKeyService
}
-// ProvideSettingService wires SettingService with group reader for default subscription validation.
-func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupRepository, cfg *config.Config) *SettingService {
- svc := NewSettingService(settingRepo, cfg)
- svc.SetDefaultSubscriptionGroupReader(groupRepo)
- return svc
-}
-
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
@@ -342,7 +335,7 @@ var ProviderSet = wire.NewSet(
ProvideRateLimitService,
NewAccountUsageService,
NewAccountTestService,
- ProvideSettingService,
+ NewSettingService,
NewDataManagementService,
ProvideOpsSystemLogSink,
NewOpsService,
@@ -355,7 +348,6 @@ var ProviderSet = wire.NewSet(
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
- wire.Bind(new(DefaultSubscriptionAssigner), new(*SubscriptionService)),
ProvideConcurrencyService,
ProvideUserMessageQueueService,
NewUsageRecordWorkerPool,
diff --git a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
index d0ed5d6dd..93af0da7f 100644
--- a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
+++ b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
@@ -1,46 +1,37 @@
--- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping
+-- Add gemini-3.1-flash-image mapping keys without wiping existing custom mappings.
--
-- Background:
--- Antigravity now supports gemini-3.1-flash-image as the latest image generation model,
--- replacing the previous gemini-3-pro-image.
+-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model.
+-- Existing accounts may still contain gemini-3-pro-image aliases.
--
-- Strategy:
--- Directly overwrite the entire model_mapping with updated mappings
--- This ensures consistency with DefaultAntigravityModelMapping in constants.go
+-- Incrementally upsert only image-related keys in credentials.model_mapping:
+-- 1) add canonical 3.1 image keys
+-- 2) keep legacy 3-pro-image keys but remap them to 3.1 image for compatibility
+-- This preserves user custom mappings and avoids full mapping overwrite.
UPDATE accounts
SET credentials = jsonb_set(
- credentials,
- '{model_mapping}',
- '{
- "claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
- "claude-opus-4-6": "claude-opus-4-6-thinking",
- "claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
- "claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
- "claude-sonnet-4-6": "claude-sonnet-4-6",
- "claude-sonnet-4-5": "claude-sonnet-4-5",
- "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
- "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
- "claude-haiku-4-5": "claude-sonnet-4-5",
- "claude-haiku-4-5-20251001": "claude-sonnet-4-5",
- "gemini-2.5-flash": "gemini-2.5-flash",
- "gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
- "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
- "gemini-2.5-pro": "gemini-2.5-pro",
- "gemini-3-flash": "gemini-3-flash",
- "gemini-3-pro-high": "gemini-3-pro-high",
- "gemini-3-pro-low": "gemini-3-pro-low",
- "gemini-3-flash-preview": "gemini-3-flash",
- "gemini-3-pro-preview": "gemini-3-pro-high",
- "gemini-3.1-pro-high": "gemini-3.1-pro-high",
- "gemini-3.1-pro-low": "gemini-3.1-pro-low",
- "gemini-3.1-pro-preview": "gemini-3.1-pro-high",
- "gemini-3.1-flash-image": "gemini-3.1-flash-image",
- "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
- "gpt-oss-120b-medium": "gpt-oss-120b-medium",
- "tab_flash_lite_preview": "tab_flash_lite_preview"
- }'::jsonb
+ jsonb_set(
+ jsonb_set(
+ jsonb_set(
+ credentials,
+ '{model_mapping,gemini-3.1-flash-image}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
+ ),
+ '{model_mapping,gemini-3.1-flash-image-preview}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
+ ),
+ '{model_mapping,gemini-3-pro-image}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
+ ),
+ '{model_mapping,gemini-3-pro-image-preview}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
- AND credentials->'model_mapping' IS NOT NULL;
\ No newline at end of file
+ AND credentials->'model_mapping' IS NOT NULL;
diff --git a/backend/migrations/064_add_billing_usage_entry_retry_fields.sql b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql
new file mode 100644
index 000000000..aebb39295
--- /dev/null
+++ b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql
@@ -0,0 +1,27 @@
+-- 064_add_billing_usage_entry_retry_fields.sql
+-- Add retry-state columns for billing_usage_entries compensation worker.
+
+ALTER TABLE billing_usage_entries
+ ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0,
+ ADD COLUMN IF NOT EXISTS attempt_count INTEGER NOT NULL DEFAULT 0,
+ ADD COLUMN IF NOT EXISTS next_retry_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ ADD COLUMN IF NOT EXISTS last_error TEXT;
+
+-- Backfill status on legacy rows so it stays consistent with the applied flag.
+UPDATE billing_usage_entries
+SET status = CASE WHEN applied THEN 2 ELSE 0 END
+WHERE status NOT IN (0, 1, 2)
+ OR (applied = TRUE AND status <> 2)
+ OR (applied = FALSE AND status = 2);
+
+ALTER TABLE billing_usage_entries
+ DROP CONSTRAINT IF EXISTS chk_billing_usage_entries_status;
+
+ALTER TABLE billing_usage_entries
+ ADD CONSTRAINT chk_billing_usage_entries_status
+ CHECK (status IN (0, 1, 2));
+
+CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_retry
+ ON billing_usage_entries (status, next_retry_at, updated_at)
+ WHERE applied = FALSE;
diff --git a/deploy/Caddyfile b/deploy/Caddyfile
index b643fe9b8..cbd762b1c 100644
--- a/deploy/Caddyfile
+++ b/deploy/Caddyfile
@@ -30,6 +30,36 @@ api.sub2api.com {
# =========================================================================
# 反向代理配置
# =========================================================================
+ # OpenAI Responses(含 WebSocket/SSE)单独代理策略:
+ # 1) flush_interval -1:尽快转发流式分片,降低中间层缓冲导致的断流概率
+ # 2) versions 1.1:确保上游走标准 HTTP/1.1 Upgrade,避免协议协商差异
+ # 3) stream_timeout/stream_close_delay:为长连接提供更宽松生命周期
+ @openai_responses {
+ path /openai/v1/responses*
+ }
+ reverse_proxy @openai_responses localhost:8080 {
+ # 长连接/流式场景建议关闭代理缓冲
+ flush_interval -1
+ # 长连接超时窗口(避免长会话被代理层过早回收)
+ stream_timeout 24h
+ # 配置热重载时,给现有流预留关闭缓冲期
+ stream_close_delay 5m
+
+ # 传递真实客户端信息
+ header_up X-Real-IP {remote_host}
+ header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP}
+
+ transport http {
+ # WebSocket Upgrade 对上游统一使用 HTTP/1.1,更稳妥
+ versions 1.1
+ keepalive 120s
+ keepalive_idle_conns 256
+ read_buffer 32KB
+ write_buffer 32KB
+ compression off
+ }
+ }
+
reverse_proxy localhost:8080 {
# 健康检查
health_uri /health
@@ -45,9 +75,6 @@ api.sub2api.com {
# 传递真实客户端信息
# 兼容 Cloudflare 和直连:后端应优先读取 CF-Connecting-IP,其次 X-Real-IP
header_up X-Real-IP {remote_host}
- header_up X-Forwarded-For {remote_host}
- header_up X-Forwarded-Proto {scheme}
- header_up X-Forwarded-Host {host}
# 保留 Cloudflare 原始头(如果存在)
# 后端获取 IP 的优先级建议: CF-Connecting-IP → X-Real-IP → X-Forwarded-For
header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP}
diff --git a/deploy/Dockerfile b/deploy/Dockerfile
index b33203009..c9fcf3017 100644
--- a/deploy/Dockerfile
+++ b/deploy/Dockerfile
@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
-ARG GOLANG_IMAGE=golang:1.25.5-alpine
+ARG GOLANG_IMAGE=golang:1.25.7-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index e2eb3130c..da1a54a4c 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -134,12 +134,6 @@ security:
# Allow skipping TLS verification for proxy probe (debug only)
# 允许代理探测时跳过 TLS 证书验证(仅用于调试)
insecure_skip_verify: false
- proxy_fallback:
- # Allow auxiliary services (update check, pricing data) to fallback to direct
- # connection when proxy initialization fails. Does NOT affect AI gateway connections.
- # 辅助服务(更新检查、定价数据拉取)代理初始化失败时是否允许回退直连。
- # 不影响 AI 账号网关连接。默认 false:fail-fast 防止 IP 泄露。
- allow_direct_on_error: false
# =============================================================================
# Gateway Configuration
@@ -207,10 +201,12 @@ gateway:
openai_passthrough_allow_timeout_headers: false
# OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
openai_ws:
- # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
- mode_router_v2_enabled: false
- # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
- ingress_mode_default: shared
+ # 新版 WS mode 路由(默认开启)。关闭时保持当前 legacy 实现行为。
+ mode_router_v2_enabled: true
+ # ingress 默认模式:off|ctx_pool|passthrough(仅 mode_router_v2_enabled=true 生效)
+ # 建议:常规场景选 ctx_pool;需要“原样透传 + 不走连接池”时选 passthrough;紧急回滚选 off。
+ # Recommendation: use ctx_pool for normal traffic, passthrough for raw relay without pool reuse, off for emergency rollback.
+ ingress_mode_default: ctx_pool
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
enabled: true
# 按账号类型细分开关
@@ -233,7 +229,7 @@ gateway:
# 协议 feature 开关,v2 优先于 v1
responses_websockets: false
responses_websockets_v2: true
- # 连接池参数(按账号池化复用)
+ # 连接池参数(按账号池化复用,仅 ctx_pool 模式生效;passthrough 模式不使用连接池)
max_conns_per_account: 128
min_idle_per_account: 4
max_idle_per_account: 12
@@ -248,6 +244,10 @@ gateway:
write_timeout_seconds: 120
pool_target_utilization: 0.7
queue_limit_per_conn: 64
+ # 上游 WebSocket 连接最大存活时间(秒)。
+ # OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。
+ # 默认 3300(55 分钟);设为 0 则禁用超龄轮换。
+ upstream_conn_max_age_seconds: 3300
# 流式写出批量 flush 参数
event_flush_batch_size: 1
event_flush_interval_ms: 10
@@ -265,7 +265,7 @@ gateway:
# payload_schema 日志采样率(0-1);降低热路径日志放大
payload_log_sample_rate: 0.2
# 调度与粘连参数
- lb_top_k: 7
+ lb_top_k: 999
sticky_session_ttl_seconds: 3600
# 会话哈希迁移兼容开关:新 key 未命中时回退读取旧 SHA-256 key
session_hash_read_old_fallback: true
@@ -282,6 +282,24 @@ gateway:
queue: 0.7
error_rate: 0.8
ttft: 0.5
+ # OpenAI HTTP upstream protocol strategy
+ # OpenAI HTTP 上游协议策略
+ openai_http2:
+ # Enable OpenAI HTTP/2 preference (default on)
+ # 启用 OpenAI HTTP/2 优先策略(默认开启)
+ enabled: true
+ # Allow fallback to HTTP/1.1 for incompatible HTTP proxies
+ # 当 HTTP 代理不兼容时允许回退到 HTTP/1.1
+ allow_proxy_fallback_to_http1: true
+ # Fallback triggers after N HTTP/2 compatibility errors within window
+ # 在窗口期内累计 N 次 HTTP/2 兼容错误后触发回退
+ fallback_error_threshold: 2
+ # Error counting window (seconds)
+ # 错误计数窗口(秒)
+ fallback_window_seconds: 60
+ # How long to stay in HTTP/1.1 fallback mode (seconds)
+ # 进入 HTTP/1.1 回退态后的持续时间(秒)
+ fallback_ttl_seconds: 600
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts
@@ -316,6 +334,15 @@ gateway:
# SSE max line size in bytes (default: 40MB)
# SSE 单行最大字节数(默认 40MB)
max_line_size: 41943040
+ # Usage record worker pool (bounded queue)
+ # 使用量记录异步池(有界队列)
+ usage_record:
+ # queue overflow policy: drop/sample/sync
+ # 队列溢出策略:drop/sample/sync
+ overflow_policy: sync
+ # only used when overflow_policy=sample
+ # 仅在 overflow_policy=sample 时生效
+ overflow_sample_percent: 10
# Log upstream error response body summary (safe/truncated; does not log request content)
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
log_upstream_error_body: true
diff --git a/docs/ADMIN_PAYMENT_INTEGRATION_API.md b/docs/ADMIN_PAYMENT_INTEGRATION_API.md
deleted file mode 100644
index 4cc215948..000000000
--- a/docs/ADMIN_PAYMENT_INTEGRATION_API.md
+++ /dev/null
@@ -1,241 +0,0 @@
-# ADMIN_PAYMENT_INTEGRATION_API
-
-> 单文件中英双语文档 / Single-file bilingual documentation (Chinese + English)
-
----
-
-## 中文
-
-### 目标
-本文档用于对接外部支付系统(如 `sub2apipay`)与 Sub2API 的 Admin API,覆盖:
-- 支付成功后充值
-- 用户查询
-- 人工余额修正
-- 前端购买页参数透传
-
-### 基础地址
-- 生产:`https://`
-- Beta:`http://:8084`
-
-### 认证
-推荐使用:
-- `x-api-key: admin-<64hex>`
-- `Content-Type: application/json`
-- 幂等接口额外传:`Idempotency-Key`
-
-说明:管理员 JWT 也可访问 admin 路由,但服务间调用建议使用 Admin API Key。
-
-### 1) 一步完成创建并兑换
-`POST /api/v1/admin/redeem-codes/create-and-redeem`
-
-用途:原子完成“创建兑换码 + 兑换到指定用户”。
-
-请求头:
-- `x-api-key`
-- `Idempotency-Key`
-
-请求体示例:
-```json
-{
- "code": "s2p_cm1234567890",
- "type": "balance",
- "value": 100.0,
- "user_id": 123,
- "notes": "sub2apipay order: cm1234567890"
-}
-```
-
-幂等语义:
-- 同 `code` 且 `used_by` 一致:`200`
-- 同 `code` 但 `used_by` 不一致:`409`
-- 缺少 `Idempotency-Key`:`400`(`IDEMPOTENCY_KEY_REQUIRED`)
-
-curl 示例:
-```bash
-curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: pay-cm1234567890-success" \
- -H "Content-Type: application/json" \
- -d '{
- "code":"s2p_cm1234567890",
- "type":"balance",
- "value":100.00,
- "user_id":123,
- "notes":"sub2apipay order: cm1234567890"
- }'
-```
-
-### 2) 查询用户(可选前置校验)
-`GET /api/v1/admin/users/:id`
-
-```bash
-curl -s "${BASE}/api/v1/admin/users/123" \
- -H "x-api-key: ${KEY}"
-```
-
-### 3) 余额调整(已有接口)
-`POST /api/v1/admin/users/:id/balance`
-
-用途:人工补偿 / 扣减,支持 `set` / `add` / `subtract`。
-
-请求体示例(扣减):
-```json
-{
- "balance": 100.0,
- "operation": "subtract",
- "notes": "manual correction"
-}
-```
-
-```bash
-curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: balance-subtract-cm1234567890" \
- -H "Content-Type: application/json" \
- -d '{
- "balance":100.00,
- "operation":"subtract",
- "notes":"manual correction"
- }'
-```
-
-### 4) 购买页 URL Query 透传(iframe / 新窗口一致)
-当 Sub2API 打开 `purchase_subscription_url` 时,会统一追加:
-- `user_id`
-- `token`
-- `theme`(`light` / `dark`)
-- `ui_mode`(固定 `embedded`)
-
-示例:
-```text
-https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded
-```
-
-### 5) 失败处理建议
-- 支付成功与充值成功分状态落库
-- 回调验签成功后立即标记“支付成功”
-- 支付成功但充值失败的订单允许后续重试
-- 重试保持相同 `code`,并使用新的 `Idempotency-Key`
-
-### 6) `doc_url` 配置建议
-- 查看链接:`https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-- 下载链接:`https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-
----
-
-## English
-
-### Purpose
-This document describes the minimal Sub2API Admin API surface for external payment integrations (for example, `sub2apipay`), including:
-- Recharge after payment success
-- User lookup
-- Manual balance correction
-- Purchase page query parameter forwarding
-
-### Base URL
-- Production: `https://`
-- Beta: `http://:8084`
-
-### Authentication
-Recommended headers:
-- `x-api-key: admin-<64hex>`
-- `Content-Type: application/json`
-- `Idempotency-Key` for idempotent endpoints
-
-Note: Admin JWT can also access admin routes, but Admin API Key is recommended for server-to-server integration.
-
-### 1) Create and Redeem in one step
-`POST /api/v1/admin/redeem-codes/create-and-redeem`
-
-Use case: atomically create a redeem code and redeem it to a target user.
-
-Headers:
-- `x-api-key`
-- `Idempotency-Key`
-
-Request body:
-```json
-{
- "code": "s2p_cm1234567890",
- "type": "balance",
- "value": 100.0,
- "user_id": 123,
- "notes": "sub2apipay order: cm1234567890"
-}
-```
-
-Idempotency behavior:
-- Same `code` and same `used_by`: `200`
-- Same `code` but different `used_by`: `409`
-- Missing `Idempotency-Key`: `400` (`IDEMPOTENCY_KEY_REQUIRED`)
-
-curl example:
-```bash
-curl -X POST "${BASE}/api/v1/admin/redeem-codes/create-and-redeem" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: pay-cm1234567890-success" \
- -H "Content-Type: application/json" \
- -d '{
- "code":"s2p_cm1234567890",
- "type":"balance",
- "value":100.00,
- "user_id":123,
- "notes":"sub2apipay order: cm1234567890"
- }'
-```
-
-### 2) Query User (optional pre-check)
-`GET /api/v1/admin/users/:id`
-
-```bash
-curl -s "${BASE}/api/v1/admin/users/123" \
- -H "x-api-key: ${KEY}"
-```
-
-### 3) Balance Adjustment (existing API)
-`POST /api/v1/admin/users/:id/balance`
-
-Use case: manual correction with `set` / `add` / `subtract`.
-
-Request body example (`subtract`):
-```json
-{
- "balance": 100.0,
- "operation": "subtract",
- "notes": "manual correction"
-}
-```
-
-```bash
-curl -X POST "${BASE}/api/v1/admin/users/123/balance" \
- -H "x-api-key: ${KEY}" \
- -H "Idempotency-Key: balance-subtract-cm1234567890" \
- -H "Content-Type: application/json" \
- -d '{
- "balance":100.00,
- "operation":"subtract",
- "notes":"manual correction"
- }'
-```
-
-### 4) Purchase URL query forwarding (iframe and new tab)
-When Sub2API opens `purchase_subscription_url`, it appends:
-- `user_id`
-- `token`
-- `theme` (`light` / `dark`)
-- `ui_mode` (fixed: `embedded`)
-
-Example:
-```text
-https://pay.example.com/pay?user_id=123&token=&theme=light&ui_mode=embedded
-```
-
-### 5) Failure handling recommendations
-- Persist payment success and recharge success as separate states
-- Mark payment as successful immediately after verified callback
-- Allow retry for orders with payment success but recharge failure
-- Keep the same `code` for retry, and use a new `Idempotency-Key`
-
-### 6) Recommended `doc_url`
-- View URL: `https://github.com/Wei-Shaw/sub2api/blob/main/ADMIN_PAYMENT_INTEGRATION_API.md`
-- Download URL: `https://raw.githubusercontent.com/Wei-Shaw/sub2api/main/ADMIN_PAYMENT_INTEGRATION_API.md`
diff --git a/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts
new file mode 100644
index 000000000..5baccfe8c
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts
@@ -0,0 +1,184 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import {
+ deleteBulkEditTemplate,
+ getBulkEditTemplates,
+ getBulkEditTemplateVersions,
+ rollbackBulkEditTemplate,
+ upsertBulkEditTemplate
+} from '../admin/bulkEditTemplates'
+import { apiClient } from '../client'
+
+vi.mock('../client', () => ({
+ apiClient: {
+ get: vi.fn(),
+ post: vi.fn(),
+ delete: vi.fn()
+ }
+}))
+
+describe('admin settings bulk-edit templates api', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ it('requests template list with expected query params', async () => {
+ (apiClient.get as any).mockResolvedValue({
+ data: {
+ items: [
+ {
+ id: 'tpl-1',
+ name: 'Template',
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ share_scope: 'team',
+ group_ids: [],
+ state: {},
+ created_by: 1,
+ updated_by: 1,
+ created_at: 1,
+ updated_at: 2
+ }
+ ]
+ }
+ })
+
+ const items = await getBulkEditTemplates({
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ scope_group_ids: [3, 9]
+ })
+
+ expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', {
+ params: {
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ scope_group_ids: '3,9'
+ }
+ })
+ expect(items).toHaveLength(1)
+ expect(items[0].id).toBe('tpl-1')
+ })
+
+ it('returns empty list when response items is invalid', async () => {
+ (apiClient.get as any).mockResolvedValue({ data: { items: null } })
+ const items = await getBulkEditTemplates({})
+ expect(items).toEqual([])
+ })
+
+ it('posts upsert payload and returns saved template', async () => {
+ (apiClient.post as any).mockResolvedValue({
+ data: {
+ id: 'tpl-2',
+ name: 'Shared',
+ scope_platform: 'openai',
+ scope_type: 'apikey',
+ share_scope: 'groups',
+ group_ids: [1],
+ state: { enableProxy: true },
+ created_by: 2,
+ updated_by: 2,
+ created_at: 10,
+ updated_at: 11
+ }
+ })
+
+ const saved = await upsertBulkEditTemplate({
+ id: 'tpl-2',
+ name: 'Shared',
+ scope_platform: 'openai',
+ scope_type: 'apikey',
+ share_scope: 'groups',
+ group_ids: [1],
+ state: { enableProxy: true }
+ })
+
+ expect(apiClient.post).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', {
+ id: 'tpl-2',
+ name: 'Shared',
+ scope_platform: 'openai',
+ scope_type: 'apikey',
+ share_scope: 'groups',
+ group_ids: [1],
+ state: { enableProxy: true }
+ })
+ expect(saved.id).toBe('tpl-2')
+ })
+
+ it('requests template versions with scope group params', async () => {
+ (apiClient.get as any).mockResolvedValue({
+ data: {
+ items: [
+ {
+ version_id: 'ver-1',
+ share_scope: 'team',
+ group_ids: [],
+ state: { enableOpenAIWSMode: true },
+ updated_by: 11,
+ updated_at: 100
+ }
+ ]
+ }
+ })
+
+ const items = await getBulkEditTemplateVersions('tpl-2', { scope_group_ids: [5, 8] })
+
+ expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-2/versions', {
+ params: {
+ scope_group_ids: '5,8'
+ }
+ })
+ expect(items).toHaveLength(1)
+ expect(items[0].version_id).toBe('ver-1')
+ })
+
+ it('returns empty versions list when payload is invalid', async () => {
+ (apiClient.get as any).mockResolvedValue({ data: { items: undefined } })
+
+ const items = await getBulkEditTemplateVersions('tpl-any')
+
+ expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-any/versions', {
+ params: {}
+ })
+ expect(items).toEqual([])
+ })
+
+ it('posts rollback request with optional query params', async () => {
+ (apiClient.post as any).mockResolvedValue({
+ data: {
+ id: 'tpl-3',
+ name: 'Rollbacked',
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ share_scope: 'private',
+ group_ids: [],
+ state: { enableOpenAIPassthrough: false },
+ created_by: 1,
+ updated_by: 2,
+ created_at: 10,
+ updated_at: 12
+ }
+ })
+
+ const saved = await rollbackBulkEditTemplate(
+ 'tpl-3',
+ { version_id: 'ver-2' },
+ { scope_group_ids: [2] }
+ )
+
+ expect(apiClient.post).toHaveBeenCalledWith(
+ '/admin/settings/bulk-edit-templates/tpl-3/rollback',
+ { version_id: 'ver-2' },
+ { params: { scope_group_ids: '2' } }
+ )
+ expect(saved.id).toBe('tpl-3')
+ })
+
+ it('calls delete endpoint for template removal', async () => {
+ (apiClient.delete as any).mockResolvedValue({ data: { deleted: true } })
+
+ const result = await deleteBulkEditTemplate('tpl-9')
+
+ expect(apiClient.delete).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-9')
+ expect(result).toEqual({ deleted: true })
+ })
+})
diff --git a/frontend/src/api/admin/bulkEditTemplates.ts b/frontend/src/api/admin/bulkEditTemplates.ts
new file mode 100644
index 000000000..45c5e8ac1
--- /dev/null
+++ b/frontend/src/api/admin/bulkEditTemplates.ts
@@ -0,0 +1,129 @@
+import { apiClient } from '../client'
+import type { AccountPlatform, AccountType } from '@/types'
+
+export type BulkEditTemplateShareScope = 'private' | 'team' | 'groups'
+
+export interface BulkEditTemplateRecord<TState = Record<string, unknown>> {
+ id: string
+ name: string
+ scope_platform: AccountPlatform | ''
+ scope_type: AccountType | ''
+ share_scope: BulkEditTemplateShareScope
+ group_ids: number[]
+ state: TState
+ created_by: number
+ updated_by: number
+ created_at: number
+ updated_at: number
+}
+
+export interface BulkEditTemplateVersionRecord<TState = Record<string, unknown>> {
+ version_id: string
+ share_scope: BulkEditTemplateShareScope
+ group_ids: number[]
+ state: TState
+ updated_by: number
+ updated_at: number
+}
+
+export interface GetBulkEditTemplatesParams {
+ scope_platform?: AccountPlatform | ''
+ scope_type?: AccountType | ''
+ scope_group_ids?: number[]
+}
+
+export interface GetBulkEditTemplateVersionsParams {
+ scope_group_ids?: number[]
+}
+
+export interface UpsertBulkEditTemplateRequest<TState = Record<string, unknown>> {
+ id?: string
+ name: string
+ scope_platform: AccountPlatform | ''
+ scope_type: AccountType | ''
+ share_scope: BulkEditTemplateShareScope
+ group_ids: number[]
+ state: TState
+}
+
+export interface RollbackBulkEditTemplateRequest {
+ version_id: string
+}
+
+export async function getBulkEditTemplates<TState = Record<string, unknown>>(
+ params: GetBulkEditTemplatesParams
+): Promise<BulkEditTemplateRecord<TState>[]> {
+ const query: Record<string, string> = {}
+ if (params.scope_platform) query.scope_platform = params.scope_platform
+ if (params.scope_type) query.scope_type = params.scope_type
+ if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+ query.scope_group_ids = params.scope_group_ids.join(',')
+ }
+
+ const { data } = await apiClient.get<{ items: BulkEditTemplateRecord<TState>[] }>(
+ '/admin/settings/bulk-edit-templates',
+ { params: query }
+ )
+ return Array.isArray(data.items) ? data.items : []
+}
+
+export async function getBulkEditTemplateVersions<TState = Record<string, unknown>>(
+ templateID: string,
+ params: GetBulkEditTemplateVersionsParams = {}
+): Promise<BulkEditTemplateVersionRecord<TState>[]> {
+ const query: Record<string, string> = {}
+ if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+ query.scope_group_ids = params.scope_group_ids.join(',')
+ }
+
+ const { data } = await apiClient.get<{ items: BulkEditTemplateVersionRecord<TState>[] }>(
+ `/admin/settings/bulk-edit-templates/${templateID}/versions`,
+ { params: query }
+ )
+ return Array.isArray(data.items) ? data.items : []
+}
+
+export async function upsertBulkEditTemplate<TState = Record<string, unknown>>(
+ request: UpsertBulkEditTemplateRequest<TState>
+): Promise<BulkEditTemplateRecord<TState>> {
+ const { data } = await apiClient.post<BulkEditTemplateRecord<TState>>(
+ '/admin/settings/bulk-edit-templates',
+ request
+ )
+ return data
+}
+
+export async function rollbackBulkEditTemplate<TState = Record<string, unknown>>(
+ templateID: string,
+ request: RollbackBulkEditTemplateRequest,
+ params: GetBulkEditTemplateVersionsParams = {}
+): Promise<BulkEditTemplateRecord<TState>> {
+ const query: Record<string, string> = {}
+ if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+ query.scope_group_ids = params.scope_group_ids.join(',')
+ }
+
+ const { data } = await apiClient.post<BulkEditTemplateRecord<TState>>(
+ `/admin/settings/bulk-edit-templates/${templateID}/rollback`,
+ request,
+ { params: query }
+ )
+ return data
+}
+
+export async function deleteBulkEditTemplate(templateID: string): Promise<{ deleted: boolean }> {
+ const { data } = await apiClient.delete<{ deleted: boolean }>(
+ `/admin/settings/bulk-edit-templates/${templateID}`
+ )
+ return data
+}
+
+const bulkEditTemplatesAPI = {
+ getBulkEditTemplates,
+ getBulkEditTemplateVersions,
+ upsertBulkEditTemplate,
+ rollbackBulkEditTemplate,
+ deleteBulkEditTemplate
+}
+
+export default bulkEditTemplatesAPI
diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts
index 54bd92a49..a5113dd1f 100644
--- a/frontend/src/api/admin/dashboard.ts
+++ b/frontend/src/api/admin/dashboard.ts
@@ -8,7 +8,6 @@ import type {
DashboardStats,
TrendDataPoint,
ModelStat,
- GroupStat,
ApiKeyUsageTrendPoint,
UserUsageTrendPoint,
UsageRequestType
@@ -102,34 +101,6 @@ export async function getModelStats(params?: ModelStatsParams): Promise {
- const { data } = await apiClient.get('/admin/dashboard/groups', { params })
- return data
-}
-
export interface ApiKeyTrendParams extends TrendParams {
limit?: number
}
@@ -232,7 +203,6 @@ export const dashboardAPI = {
getRealtimeMetrics,
getUsageTrend,
getModelStats,
- getGroupStats,
getApiKeyUsageTrend,
getUserUsageTrend,
getBatchUsersUsage,
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index 5db998e57..a2c82ecbf 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -12,6 +12,7 @@ import redeemAPI from './redeem'
import promoAPI from './promo'
import announcementsAPI from './announcements'
import settingsAPI from './settings'
+import bulkEditTemplatesAPI from './bulkEditTemplates'
import systemAPI from './system'
import subscriptionsAPI from './subscriptions'
import usageAPI from './usage'
@@ -21,7 +22,6 @@ import userAttributesAPI from './userAttributes'
import opsAPI from './ops'
import errorPassthroughAPI from './errorPassthrough'
import dataManagementAPI from './dataManagement'
-import apiKeysAPI from './apiKeys'
/**
* Unified admin API object for convenient access
@@ -36,6 +36,7 @@ export const adminAPI = {
promo: promoAPI,
announcements: announcementsAPI,
settings: settingsAPI,
+ bulkEditTemplates: bulkEditTemplatesAPI,
system: systemAPI,
subscriptions: subscriptionsAPI,
usage: usageAPI,
@@ -44,8 +45,7 @@ export const adminAPI = {
userAttributes: userAttributesAPI,
ops: opsAPI,
errorPassthrough: errorPassthroughAPI,
- dataManagement: dataManagementAPI,
- apiKeys: apiKeysAPI
+ dataManagement: dataManagementAPI
}
export {
@@ -58,6 +58,7 @@ export {
promoAPI,
announcementsAPI,
settingsAPI,
+ bulkEditTemplatesAPI,
systemAPI,
subscriptionsAPI,
usageAPI,
@@ -66,8 +67,7 @@ export {
userAttributesAPI,
opsAPI,
errorPassthroughAPI,
- dataManagementAPI,
- apiKeysAPI
+ dataManagementAPI
}
export default adminAPI
diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts
index 52855a040..a3ed6c6ff 100644
--- a/frontend/src/api/admin/settings.ts
+++ b/frontend/src/api/admin/settings.ts
@@ -6,11 +6,6 @@
import { apiClient } from '../client'
import type { CustomMenuItem } from '@/types'
-export interface DefaultSubscriptionSetting {
- group_id: number
- validity_days: number
-}
-
/**
* System settings interface
*/
@@ -26,7 +21,6 @@ export interface SystemSettings {
// Default settings
default_balance: number
default_concurrency: number
- default_subscriptions: DefaultSubscriptionSetting[]
// OEM settings
site_name: string
site_logo: string
@@ -75,9 +69,6 @@ export interface SystemSettings {
ops_realtime_monitoring_enabled: boolean
ops_query_mode_default: 'auto' | 'raw' | 'preagg' | string
ops_metrics_interval_seconds: number
-
- // Claude Code version check
- min_claude_code_version: string
}
export interface UpdateSettingsRequest {
@@ -89,7 +80,6 @@ export interface UpdateSettingsRequest {
totp_enabled?: boolean // TOTP 双因素认证
default_balance?: number
default_concurrency?: number
- default_subscriptions?: DefaultSubscriptionSetting[]
site_name?: string
site_logo?: string
site_subtitle?: string
@@ -127,7 +117,6 @@ export interface UpdateSettingsRequest {
ops_realtime_monitoring_enabled?: boolean
ops_query_mode_default?: 'auto' | 'raw' | 'preagg' | string
ops_metrics_interval_seconds?: number
- min_claude_code_version?: string
}
/**
diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts
index d36a2a5a9..287aef967 100644
--- a/frontend/src/api/admin/users.ts
+++ b/frontend/src/api/admin/users.ts
@@ -4,7 +4,7 @@
*/
import { apiClient } from '../client'
-import type { AdminUser, UpdateUserRequest, PaginatedResponse, ApiKey } from '@/types'
+import type { AdminUser, UpdateUserRequest, PaginatedResponse } from '@/types'
/**
* List all users with pagination
@@ -145,8 +145,8 @@ export async function toggleStatus(id: number, status: 'active' | 'disabled'): P
* @param id - User ID
* @returns List of user's API keys
*/
-export async function getUserApiKeys(id: number): Promise<PaginatedResponse<ApiKey>> {
- const { data } = await apiClient.get<PaginatedResponse<ApiKey>>(`/admin/users/${id}/api-keys`)
+export async function getUserApiKeys(id: number): Promise<PaginatedResponse<unknown>> {
+ const { data } = await apiClient.get<PaginatedResponse<unknown>>(`/admin/users/${id}/api-keys`)
return data
}
diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts
index 95f9ff318..22db5a44a 100644
--- a/frontend/src/api/client.ts
+++ b/frontend/src/api/client.ts
@@ -267,7 +267,6 @@ apiClient.interceptors.response.use(
return Promise.reject({
status,
code: apiData.code,
- error: apiData.error,
message: apiData.message || apiData.detail || error.message
})
}
diff --git a/frontend/src/components/account/AccountCapacityCell.vue b/frontend/src/components/account/AccountCapacityCell.vue
index 2a4babf20..ae338aca1 100644
--- a/frontend/src/components/account/AccountCapacityCell.vue
+++ b/frontend/src/components/account/AccountCapacityCell.vue
@@ -52,25 +52,6 @@
{{ account.max_sessions }}
-
-
-
-
-
- {{ currentRPM }}
- /
- {{ account.base_rpm }}
- {{ rpmStrategyTag }}
-
-
@@ -144,15 +125,19 @@ const windowCostClass = computed(() => {
const limit = props.account.window_cost_limit || 0
const reserve = props.account.window_cost_sticky_reserve || 10
+ // >= 阈值+预留: 完全不可调度 (红色)
if (current >= limit + reserve) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
+ // >= 阈值: 仅粘性会话 (橙色)
if (current >= limit) {
return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
}
+ // >= 80% 阈值: 警告 (黄色)
if (current >= limit * 0.8) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
+ // 正常 (绿色)
return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
})
@@ -180,12 +165,15 @@ const sessionLimitClass = computed(() => {
const current = activeSessions.value
const max = props.account.max_sessions || 0
+ // >= 最大: 完全占满 (红色)
if (current >= max) {
return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
}
+ // >= 80%: 警告 (黄色)
if (current >= max * 0.8) {
return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
}
+ // 正常 (绿色)
return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
})
@@ -203,89 +191,6 @@ const sessionLimitTooltip = computed(() => {
return t('admin.accounts.capacity.sessions.normal', { idle })
})
-// 是否显示 RPM 限制
-const showRpmLimit = computed(() => {
- return (
- isAnthropicOAuthOrSetupToken.value &&
- props.account.base_rpm !== undefined &&
- props.account.base_rpm !== null &&
- props.account.base_rpm > 0
- )
-})
-
-// 当前 RPM 计数
-const currentRPM = computed(() => props.account.current_rpm ?? 0)
-
-// RPM 策略
-const rpmStrategy = computed(() => props.account.rpm_strategy || 'tiered')
-
-// RPM 策略标签
-const rpmStrategyTag = computed(() => {
- return rpmStrategy.value === 'sticky_exempt' ? '[S]' : '[T]'
-})
-
-// RPM buffer 计算(与后端一致:base <= 0 时 buffer 为 0)
-const rpmBuffer = computed(() => {
- const base = props.account.base_rpm || 0
- return props.account.rpm_sticky_buffer ?? (base > 0 ? Math.max(1, Math.floor(base / 5)) : 0)
-})
-
-// RPM 状态样式
-const rpmClass = computed(() => {
- if (!showRpmLimit.value) return ''
-
- const current = currentRPM.value
- const base = props.account.base_rpm ?? 0
- const buffer = rpmBuffer.value
-
- if (rpmStrategy.value === 'tiered') {
- if (current >= base + buffer) {
- return 'bg-red-100 text-red-700 dark:bg-red-900/30 dark:text-red-400'
- }
- if (current >= base) {
- return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
- }
- } else {
- if (current >= base) {
- return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
- }
- }
- if (current >= base * 0.8) {
- return 'bg-yellow-100 text-yellow-700 dark:bg-yellow-900/30 dark:text-yellow-400'
- }
- return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
-})
-
-// RPM 提示文字(增强版:显示策略、区域、缓冲区)
-const rpmTooltip = computed(() => {
- if (!showRpmLimit.value) return ''
-
- const current = currentRPM.value
- const base = props.account.base_rpm ?? 0
- const buffer = rpmBuffer.value
-
- if (rpmStrategy.value === 'tiered') {
- if (current >= base + buffer) {
- return t('admin.accounts.capacity.rpm.tieredBlocked', { buffer })
- }
- if (current >= base) {
- return t('admin.accounts.capacity.rpm.tieredStickyOnly', { buffer })
- }
- if (current >= base * 0.8) {
- return t('admin.accounts.capacity.rpm.tieredWarning')
- }
- return t('admin.accounts.capacity.rpm.tieredNormal')
- } else {
- if (current >= base) {
- return t('admin.accounts.capacity.rpm.stickyExemptOver')
- }
- if (current >= base * 0.8) {
- return t('admin.accounts.capacity.rpm.stickyExemptWarning')
- }
- return t('admin.accounts.capacity.rpm.stickyExemptNormal')
- }
-})
-
// 格式化费用显示
const formatCost = (value: number | null | undefined) => {
if (value === null || value === undefined) return '0'
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue
index 1c83e6583..163bb391b 100644
--- a/frontend/src/components/account/BulkEditAccountModal.vue
+++ b/frontend/src/components/account/BulkEditAccountModal.vue
@@ -1,7 +1,7 @@
@@ -21,18 +21,279 @@
-
-
-
-
- {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
-
+
+ {{ t('admin.accounts.filters.platform') }}: {{ scopePlatformLabel }} /
+ {{ t('admin.accounts.filters.type') }}: {{ scopeTypeLabel }}
+
+
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateTitle') }}
+
+
+ {{
+ t('admin.accounts.bulkEdit.templateScopeHint', {
+ platform: scopePlatformLabel,
+ type: scopeTypeLabel
+ })
+ }}
+
+
+ {{ t('admin.accounts.bulkEdit.templateLoading') }}
+
+
+
+ {{ scopedTemplateRecords.length }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateShareGroupsHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionTitle') }}
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionHint') }}
+
+
+
+ {{ templateVersionRecords.length }}
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionLoading') }}
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionEmpty') }}
+
+
+
+ -
+
+
+ {{ formatTemplateVersionUpdatedAt(version.updatedAt) }}
+
+
+ {{ resolveTemplateShareScopeLabel(version.shareScope) }}
+ · {{ version.groupIDs.join(', ') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+