From ce67cf80f08f3da9ffb5c28e2b07c2651e043bb9 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 17:32:49 +0800
Subject: [PATCH 01/13] =?UTF-8?q?feat(openai-ws):=20=E4=BC=98=E5=8C=96=20o?=
=?UTF-8?q?penai=20websocket=20mode=20v2=20=E7=BD=91=E5=85=B3?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- 对齐 test 分支的 OpenAI WebSocket Mode v2 优化实现。
- 统一协议入口与模式治理,增强 previous_response_id 续链恢复。
- 优化调度/重试/降级策略与 WS 热路径性能,并补齐可观测性与管理端配置配套。
---
README.md | 28 +
README_CN.md | 28 +
backend/.gosec.json | 5 +
backend/Dockerfile | 2 +-
backend/cmd/server/VERSION | 2 +-
backend/cmd/server/wire.go | 4 +-
backend/cmd/server/wire_gen.go | 9 +-
backend/go.mod | 3 -
backend/internal/config/config.go | 122 +-
backend/internal/config/config_test.go | 22 +-
.../internal/handler/admin/account_handler.go | 3 +
.../admin/account_handler_bulk_update_test.go | 62 +
.../handler/admin/admin_service_stub_test.go | 11 +-
.../handler/admin/data_management_handler.go | 24 +-
.../setting_handler_bulk_edit_template.go | 228 ++
...setting_handler_bulk_edit_template_test.go | 592 +++
backend/internal/handler/gateway_handler.go | 25 +-
backend/internal/handler/gateway_helper.go | 2 -
.../handler/openai_gateway_handler.go | 170 +-
.../handler/openai_gateway_handler_test.go | 28 +
.../internal/handler/ops_error_logger_test.go | 60 -
backend/internal/handler/setting_handler.go | 4 +-
.../handler/sora_client_handler_test.go | 34 +-
.../internal/handler/sora_gateway_handler.go | 27 +-
.../handler/sora_gateway_handler_test.go | 3 -
.../handler/usage_record_submit_helper.go | 61 +
.../handler/usage_record_submit_task_test.go | 67 +
backend/internal/handler/wire.go | 4 +-
backend/internal/pkg/response/response.go | 34 +-
.../internal/pkg/response/response_test.go | 76 +-
backend/internal/repository/account_repo.go | 5 +
backend/internal/repository/billing_cache.go | 31 +-
.../billing_cache_integration_test.go | 10 +-
backend/internal/repository/gateway_cache.go | 95 +
.../gateway_cache_integration_test.go | 44 +
backend/internal/repository/group_repo.go | 32 -
backend/internal/repository/http_upstream.go | 504 ++-
.../http_upstream_benchmark_test.go | 2 +-
.../internal/repository/http_upstream_test.go | 441 +++
.../repository/openai_oauth_service.go | 12 +-
backend/internal/repository/usage_log_repo.go | 315 ++
backend/internal/server/api_contract_test.go | 6 +-
backend/internal/server/routes/admin.go | 45 +-
.../server/routes/admin_routes_test.go | 34 +
backend/internal/service/account.go | 99 +-
.../account_openai_passthrough_test.go | 42 +-
backend/internal/service/account_service.go | 19 +-
.../internal/service/account_usage_service.go | 1 -
backend/internal/service/admin_service.go | 267 +-
.../service/admin_service_bulk_update_test.go | 154 +-
.../auth_service_turnstile_register_test.go | 1 -
.../internal/service/billing_cache_service.go | 11 +-
.../internal/service/claude_token_provider.go | 18 +-
.../service/claude_token_provider_test.go | 28 +
.../service/data_management_service.go | 10 +-
backend/internal/service/domain_constants.go | 10 +-
backend/internal/service/gateway_service.go | 162 +-
.../internal/service/http_upstream_profile.go | 41 +
.../service/http_upstream_profile_test.go | 39 +
.../service/openai_account_scheduler.go | 1172 +++++-
.../service/openai_account_scheduler_test.go | 3363 ++++++++++++++++-
.../service/openai_client_transport.go | 19 +-
.../service/openai_client_transport_test.go | 37 +-
.../openai_gateway_record_usage_test.go | 783 ++++
.../service/openai_gateway_service.go | 603 ++-
.../openai_gateway_service_hotpath_test.go | 123 +
.../service/openai_gateway_service_test.go | 134 +
.../internal/service/openai_oauth_service.go | 23 +-
.../service/openai_sse_zero_alloc_test.go | 276 ++
.../internal/service/openai_sticky_compat.go | 25 +-
.../service/openai_ws_client_preempt_test.go | 1076 ++++++
backend/internal/service/openai_ws_common.go | 54 +
.../service/openai_ws_fallback_test.go | 289 ++
.../internal/service/openai_ws_forwarder.go | 3193 +++++++---------
.../openai_ws_forwarder_benchmark_test.go | 141 +
..._ws_forwarder_hotpath_optimization_test.go | 59 +
...openai_ws_forwarder_ingress_policy_test.go | 154 +
...penai_ws_forwarder_ingress_session_test.go | 2483 ------------
.../openai_ws_forwarder_ingress_test.go | 787 +++-
.../service/openai_ws_forwarder_panic_test.go | 107 +
.../openai_ws_forwarder_recovery_test.go | 691 ++++
.../openai_ws_forwarder_success_test.go | 1306 -------
.../openai_ws_forwarder_turn_error_test.go | 53 +
.../service/openai_ws_hotpath_perf_test.go | 931 +++++
.../service/openai_ws_ingress_context_pool.go | 1586 ++++++++
.../openai_ws_ingress_context_pool_test.go | 2496 ++++++++++++
.../service/openai_ws_ingress_normalizer.go | 39 +
.../openai_ws_ingress_normalizer_test.go | 193 +
backend/internal/service/openai_ws_pool.go | 1706 ---------
.../service/openai_ws_pool_benchmark_test.go | 58 -
.../internal/service/openai_ws_pool_test.go | 1709 ---------
.../openai_ws_protocol_forward_test.go | 16 +
.../service/openai_ws_protocol_resolver.go | 2 +-
.../openai_ws_protocol_resolver_test.go | 44 +-
.../internal/service/openai_ws_recovery.go | 758 ++++
.../internal/service/openai_ws_state_store.go | 586 ++-
.../service/openai_ws_state_store_test.go | 361 +-
.../service/openai_ws_test_helpers_test.go | 277 ++
backend/internal/service/openai_ws_turn.go | 1430 +++++++
.../service/openai_ws_upstream_pump_test.go | 1894 ++++++++++
backend/internal/service/redeem_service.go | 11 +-
.../service/setting_bulk_edit_template.go | 770 ++++
.../setting_bulk_edit_template_test.go | 860 +++++
backend/internal/service/setting_service.go | 28 +-
.../service/sora_generation_service_test.go | 3 -
.../service/token_refresh_parallel_test.go | 439 +++
.../internal/service/token_refresh_service.go | 240 +-
.../service/token_refresh_service_test.go | 195 +-
backend/internal/service/token_refresher.go | 71 +-
.../internal/service/token_refresher_test.go | 219 +-
.../usage_billing_compensation_service.go | 256 ++
...usage_billing_compensation_service_test.go | 231 ++
.../internal/service/usage_billing_entry.go | 60 +
backend/internal/service/wire.go | 2 +-
..._gemini31_flash_image_to_model_mapping.sql | 61 +-
...4_add_billing_usage_entry_retry_fields.sql | 27 +
deploy/Caddyfile | 33 +-
deploy/Dockerfile | 2 +-
deploy/config.example.yaml | 39 +-
.../settings.bulkEditTemplates.spec.ts | 184 +
frontend/src/api/admin/bulkEditTemplates.ts | 129 +
frontend/src/api/admin/index.ts | 10 +-
.../account/BulkEditAccountModal.vue | 1046 ++++-
.../account/BulkEditAccountScopedModal.vue | 84 +
.../components/account/CreateAccountModal.vue | 8 +-
.../components/account/EditAccountModal.vue | 8 +-
.../__tests__/BulkEditAccountModal.spec.ts | 33 +-
.../account/__tests__/bulkEditPayload.spec.ts | 247 ++
.../__tests__/bulkEditScopeProfile.spec.ts | 78 +
.../bulkEditTemplateRemoteMapper.spec.ts | 79 +
.../__tests__/bulkEditTemplateState.spec.ts | 81 +
.../__tests__/bulkEditTemplateStore.spec.ts | 115 +
.../src/components/account/bulkEditPayload.ts | 199 +
.../account/bulkEditScopeProfile.ts | 69 +
.../BulkEditAnthropicApiKeyModal.vue | 32 +
.../BulkEditAnthropicOAuthModal.vue | 32 +
.../BulkEditAnthropicSetupTokenModal.vue | 32 +
.../BulkEditAntigravityOAuthModal.vue | 32 +
.../BulkEditAntigravityUpstreamModal.vue | 32 +
.../BulkEditGeminiApiKeyModal.vue | 32 +
.../BulkEditGeminiOAuthModal.vue | 32 +
.../BulkEditOpenAIApiKeyModal.vue | 32 +
.../BulkEditOpenAIOAuthModal.vue | 32 +
.../BulkEditSoraApiKeyModal.vue | 32 +
.../bulkEditScoped/BulkEditSoraOAuthModal.vue | 32 +
.../account/bulkEditTemplateRemoteMapper.ts | 55 +
.../account/bulkEditTemplateState.ts | 196 +
.../account/bulkEditTemplateStore.ts | 117 +
frontend/src/components/account/index.ts | 1 +
frontend/src/i18n/locales/en.ts | 60 +-
frontend/src/i18n/locales/zh.ts | 63 +-
.../src/utils/__tests__/openaiWsMode.spec.ts | 19 +-
frontend/src/utils/openaiWsMode.ts | 11 +-
.../__tests__/accountsBulkEditScope.spec.ts | 100 +
frontend/src/views/admin/AccountsView.vue | 359 +-
frontend/src/views/admin/SettingsView.vue | 25 +
frontend/src/views/admin/UsageView.vue | 5 +-
.../src/views/admin/accountsBulkEditScope.ts | 114 +
.../proposal.md | 206 +
.../design.md | 183 +
.../proposal.md | 96 +
.../backend-performance-hotspots/spec.md | 113 +
.../tasks.md | 28 +
.../design.md | 152 +
.../proposal.md | 111 +
.../review-rounds.md | 95 +
.../frontend-bundle-optimization/spec.md | 29 +
.../frontend-compatibility-rollout/spec.md | 48 +
.../frontend-runtime-performance/spec.md | 41 +
.../tasks.md | 47 +
.../design.md | 212 ++
.../proposal.md | 81 +
.../specs/openai-ws-v2-performance/spec.md | 85 +
.../tasks.md | 51 +
.../proposal.md | 13 +
.../specs/schedule-account/spec.md | 7 +
.../tasks.md | 4 +
.../proposal.md | 24 +
.../specs/frontend-routing/spec.md | 24 +
.../tasks.md | 12 +
.../proposal.md | 33 +
.../specs/timing-wheel/spec.md | 19 +
.../tasks.md | 21 +
.../.openspec.yaml | 2 +
.../design.md | 119 +
.../final-acceptance-report.md | 67 +
.../proposal.md | 29 +
.../review.md | 78 +
.../signoff-and-rollout.md | 104 +
.../specs/openai-oauth-performance/spec.md | 43 +
.../tasks.md | 41 +
.../2026-02-25-\303\247/.openspec.yaml" | 2 +
.../archive/2026-02-25-\303\247/design.md" | 165 +
.../archive/2026-02-25-\303\247/proposal.md" | 85 +
.../specs/openai-ws-v2-performance/spec.md" | 98 +
.../archive/2026-02-25-\303\247/tasks.md" | 50 +
.../2026-02-25-\303\247/validation.md" | 5 +
.../.openspec.yaml | 2 +
.../design.md | 174 +
.../proposal.md | 66 +
.../specs/usage-request-type/spec.md | 93 +
.../tasks.md | 47 +
.../design.md | 61 +
.../proposal.md | 69 +
.../specs/openai-ws-v2-performance/spec.md | 79 +
.../tasks.md | 40 +
.../.openspec.yaml | 2 +
.../design.md | 487 +++
.../proposal.md | 356 ++
.../review-rounds.md | 43 +
.../sora-client-mockup.html | 3088 +++++++++++++++
.../specs/sora-account-apikey/spec.md | 82 +
.../specs/sora-client-ui/spec.md | 305 ++
.../specs/sora-generation-gateway/spec.md | 129 +
.../specs/sora-generation-history/spec.md | 138 +
.../specs/sora-s3-media-storage/spec.md | 104 +
.../specs/sora-s3-settings/spec.md | 39 +
.../specs/sora-user-storage-quota/spec.md | 91 +
.../tasks.md | 150 +
.../.openspec.yaml | 2 +
.../design.md | 148 +
.../proposal.md | 92 +
.../review-rounds.md | 69 +
.../specs/build-optimization/spec.md | 65 +
.../specs/cache-optimization/spec.md | 60 +
.../specs/database-optimization/spec.md | 102 +
.../specs/hotpath-optimization/spec.md | 141 +
.../specs/logging-optimization/spec.md | 56 +
.../specs/middleware-optimization/spec.md | 47 +
.../backend-performance-optimization/tasks.md | 76 +
openspec/specs/frontend-routing/spec.md | 28 +
.../specs/openai-oauth-performance/spec.md | 47 +
.../specs/openai-ws-v2-performance/spec.md | 179 +
openspec/specs/schedule-account/spec.md | 11 +
openspec/specs/sora-account-apikey/spec.md | 82 +
openspec/specs/sora-client-ui/spec.md | 305 ++
.../specs/sora-generation-gateway/spec.md | 129 +
.../specs/sora-generation-history/spec.md | 138 +
openspec/specs/sora-s3-media-storage/spec.md | 104 +
openspec/specs/sora-s3-settings/spec.md | 39 +
.../specs/sora-user-storage-quota/spec.md | 91 +
openspec/specs/timing-wheel/spec.md | 44 +
openspec/specs/usage-request-type/spec.md | 97 +
243 files changed, 42643 insertions(+), 10426 deletions(-)
create mode 100644 backend/.gosec.json
create mode 100644 backend/internal/handler/admin/account_handler_bulk_update_test.go
create mode 100644 backend/internal/handler/admin/setting_handler_bulk_edit_template.go
create mode 100644 backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go
create mode 100644 backend/internal/handler/usage_record_submit_helper.go
create mode 100644 backend/internal/server/routes/admin_routes_test.go
create mode 100644 backend/internal/service/http_upstream_profile.go
create mode 100644 backend/internal/service/http_upstream_profile_test.go
create mode 100644 backend/internal/service/openai_gateway_record_usage_test.go
create mode 100644 backend/internal/service/openai_sse_zero_alloc_test.go
create mode 100644 backend/internal/service/openai_ws_client_preempt_test.go
create mode 100644 backend/internal/service/openai_ws_common.go
create mode 100644 backend/internal/service/openai_ws_forwarder_ingress_policy_test.go
delete mode 100644 backend/internal/service/openai_ws_forwarder_ingress_session_test.go
create mode 100644 backend/internal/service/openai_ws_forwarder_panic_test.go
create mode 100644 backend/internal/service/openai_ws_forwarder_recovery_test.go
delete mode 100644 backend/internal/service/openai_ws_forwarder_success_test.go
create mode 100644 backend/internal/service/openai_ws_forwarder_turn_error_test.go
create mode 100644 backend/internal/service/openai_ws_hotpath_perf_test.go
create mode 100644 backend/internal/service/openai_ws_ingress_context_pool.go
create mode 100644 backend/internal/service/openai_ws_ingress_context_pool_test.go
create mode 100644 backend/internal/service/openai_ws_ingress_normalizer.go
create mode 100644 backend/internal/service/openai_ws_ingress_normalizer_test.go
delete mode 100644 backend/internal/service/openai_ws_pool.go
delete mode 100644 backend/internal/service/openai_ws_pool_benchmark_test.go
delete mode 100644 backend/internal/service/openai_ws_pool_test.go
create mode 100644 backend/internal/service/openai_ws_recovery.go
create mode 100644 backend/internal/service/openai_ws_test_helpers_test.go
create mode 100644 backend/internal/service/openai_ws_turn.go
create mode 100644 backend/internal/service/openai_ws_upstream_pump_test.go
create mode 100644 backend/internal/service/setting_bulk_edit_template.go
create mode 100644 backend/internal/service/setting_bulk_edit_template_test.go
create mode 100644 backend/internal/service/token_refresh_parallel_test.go
create mode 100644 backend/internal/service/usage_billing_compensation_service.go
create mode 100644 backend/internal/service/usage_billing_compensation_service_test.go
create mode 100644 backend/internal/service/usage_billing_entry.go
create mode 100644 backend/migrations/064_add_billing_usage_entry_retry_fields.sql
create mode 100644 frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts
create mode 100644 frontend/src/api/admin/bulkEditTemplates.ts
create mode 100644 frontend/src/components/account/BulkEditAccountScopedModal.vue
create mode 100644 frontend/src/components/account/__tests__/bulkEditPayload.spec.ts
create mode 100644 frontend/src/components/account/__tests__/bulkEditScopeProfile.spec.ts
create mode 100644 frontend/src/components/account/__tests__/bulkEditTemplateRemoteMapper.spec.ts
create mode 100644 frontend/src/components/account/__tests__/bulkEditTemplateState.spec.ts
create mode 100644 frontend/src/components/account/__tests__/bulkEditTemplateStore.spec.ts
create mode 100644 frontend/src/components/account/bulkEditPayload.ts
create mode 100644 frontend/src/components/account/bulkEditScopeProfile.ts
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditAnthropicApiKeyModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditAnthropicOAuthModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditAnthropicSetupTokenModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditAntigravityOAuthModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditAntigravityUpstreamModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditGeminiApiKeyModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditGeminiOAuthModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditOpenAIApiKeyModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditOpenAIOAuthModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditSoraApiKeyModal.vue
create mode 100644 frontend/src/components/account/bulkEditScoped/BulkEditSoraOAuthModal.vue
create mode 100644 frontend/src/components/account/bulkEditTemplateRemoteMapper.ts
create mode 100644 frontend/src/components/account/bulkEditTemplateState.ts
create mode 100644 frontend/src/components/account/bulkEditTemplateStore.ts
create mode 100644 frontend/src/views/__tests__/accountsBulkEditScope.spec.ts
create mode 100644 frontend/src/views/admin/accountsBulkEditScope.ts
create mode 100644 openspec/changes/2026-02-26-openai-http-path-extreme-performance/proposal.md
create mode 100644 openspec/changes/2026-02-26-optimize-backend-hotpath-performance/design.md
create mode 100644 openspec/changes/2026-02-26-optimize-backend-hotpath-performance/proposal.md
create mode 100644 openspec/changes/2026-02-26-optimize-backend-hotpath-performance/specs/backend-performance-hotspots/spec.md
create mode 100644 openspec/changes/2026-02-26-optimize-backend-hotpath-performance/tasks.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/design.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/proposal.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/review-rounds.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-bundle-optimization/spec.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-compatibility-rollout/spec.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-runtime-performance/spec.md
create mode 100644 openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/tasks.md
create mode 100644 openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/design.md
create mode 100644 openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/proposal.md
create mode 100644 openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/specs/openai-ws-v2-performance/spec.md
create mode 100644 openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/tasks.md
create mode 100644 openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/proposal.md
create mode 100644 openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/specs/schedule-account/spec.md
create mode 100644 openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/tasks.md
create mode 100644 openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/proposal.md
create mode 100644 openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/specs/frontend-routing/spec.md
create mode 100644 openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/tasks.md
create mode 100644 openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/proposal.md
create mode 100644 openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/specs/timing-wheel/spec.md
create mode 100644 openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/tasks.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/.openspec.yaml
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/design.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/final-acceptance-report.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/proposal.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/review.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/signoff-and-rollout.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/specs/openai-oauth-performance/spec.md
create mode 100644 openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/tasks.md
create mode 100644 "openspec/changes/archive/2026-02-25-\303\247/.openspec.yaml"
create mode 100644 "openspec/changes/archive/2026-02-25-\303\247/design.md"
create mode 100644 "openspec/changes/archive/2026-02-25-\303\247/proposal.md"
create mode 100644 "openspec/changes/archive/2026-02-25-\303\247/specs/openai-ws-v2-performance/spec.md"
create mode 100644 "openspec/changes/archive/2026-02-25-\303\247/tasks.md"
create mode 100644 "openspec/changes/archive/2026-02-25-\303\247/validation.md"
create mode 100644 openspec/changes/archive/2026-02-26-add-usage-request-type-enum/.openspec.yaml
create mode 100644 openspec/changes/archive/2026-02-26-add-usage-request-type-enum/design.md
create mode 100644 openspec/changes/archive/2026-02-26-add-usage-request-type-enum/proposal.md
create mode 100644 openspec/changes/archive/2026-02-26-add-usage-request-type-enum/specs/usage-request-type/spec.md
create mode 100644 openspec/changes/archive/2026-02-26-add-usage-request-type-enum/tasks.md
create mode 100644 openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/design.md
create mode 100644 openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/proposal.md
create mode 100644 openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/specs/openai-ws-v2-performance/spec.md
create mode 100644 openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/tasks.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/.openspec.yaml
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/design.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/proposal.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/review-rounds.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/sora-client-mockup.html
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-account-apikey/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-client-ui/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-gateway/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-history/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-media-storage/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-settings/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-user-storage-quota/spec.md
create mode 100644 openspec/changes/archive/2026-02-27-sora-client-s3-storage/tasks.md
create mode 100644 openspec/changes/backend-performance-optimization/.openspec.yaml
create mode 100644 openspec/changes/backend-performance-optimization/design.md
create mode 100644 openspec/changes/backend-performance-optimization/proposal.md
create mode 100644 openspec/changes/backend-performance-optimization/review-rounds.md
create mode 100644 openspec/changes/backend-performance-optimization/specs/build-optimization/spec.md
create mode 100644 openspec/changes/backend-performance-optimization/specs/cache-optimization/spec.md
create mode 100644 openspec/changes/backend-performance-optimization/specs/database-optimization/spec.md
create mode 100644 openspec/changes/backend-performance-optimization/specs/hotpath-optimization/spec.md
create mode 100644 openspec/changes/backend-performance-optimization/specs/logging-optimization/spec.md
create mode 100644 openspec/changes/backend-performance-optimization/specs/middleware-optimization/spec.md
create mode 100644 openspec/changes/backend-performance-optimization/tasks.md
create mode 100644 openspec/specs/frontend-routing/spec.md
create mode 100644 openspec/specs/openai-oauth-performance/spec.md
create mode 100644 openspec/specs/openai-ws-v2-performance/spec.md
create mode 100644 openspec/specs/schedule-account/spec.md
create mode 100644 openspec/specs/sora-account-apikey/spec.md
create mode 100644 openspec/specs/sora-client-ui/spec.md
create mode 100644 openspec/specs/sora-generation-gateway/spec.md
create mode 100644 openspec/specs/sora-generation-history/spec.md
create mode 100644 openspec/specs/sora-s3-media-storage/spec.md
create mode 100644 openspec/specs/sora-s3-settings/spec.md
create mode 100644 openspec/specs/sora-user-storage-quota/spec.md
create mode 100644 openspec/specs/timing-wheel/spec.md
create mode 100644 openspec/specs/usage-request-type/spec.md
diff --git a/README.md b/README.md
index 1e2f22907..c91f050eb 100644
--- a/README.md
+++ b/README.md
@@ -58,6 +58,34 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot
---
+## Codex CLI WebSocket v2 Example
+
+To enable OpenAI WebSocket Mode v2 in Codex CLI with Sub2API, add the following to `~/.codex/config.toml`:
+
+```toml
+model_provider = "aicodx2api"
+model = "gpt-5.3-codex"
+review_model = "gpt-5.3-codex"
+model_reasoning_effort = "xhigh"
+disable_response_storage = true
+network_access = "enabled"
+windows_wsl_setup_acknowledged = true
+
+[model_providers.aicodx2api]
+name = "aicodx2api"
+base_url = "https://api.sub2api.ai"
+wire_api = "responses"
+supports_websockets = true
+requires_openai_auth = true
+
+[features]
+responses_websockets_v2 = true
+```
+
+After updating the config, restart Codex CLI.
+
+---
+
## Deployment
### Method 1: Script Installation (Recommended)
diff --git a/README_CN.md b/README_CN.md
index 9da089b74..22a772b5a 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -62,6 +62,34 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅(
- 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。
- 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。
+## Codex CLI 开启 OpenAI WebSocket Mode v2 示例配置
+
+如需在 Codex CLI 中通过 Sub2API 启用 OpenAI WebSocket Mode v2,可将以下配置写入 `~/.codex/config.toml`:
+
+```toml
+model_provider = "aicodx2api"
+model = "gpt-5.3-codex"
+review_model = "gpt-5.3-codex"
+model_reasoning_effort = "xhigh"
+disable_response_storage = true
+network_access = "enabled"
+windows_wsl_setup_acknowledged = true
+
+[model_providers.aicodx2api]
+name = "aicodx2api"
+base_url = "https://api.sub2api.ai"
+wire_api = "responses"
+supports_websockets = true
+requires_openai_auth = true
+
+[features]
+responses_websockets_v2 = true
+```
+
+配置更新后,重启 Codex CLI 使其生效。
+
+---
+
## 部署方式
### 方式一:脚本安装(推荐)
diff --git a/backend/.gosec.json b/backend/.gosec.json
new file mode 100644
index 000000000..7a8ccb6a1
--- /dev/null
+++ b/backend/.gosec.json
@@ -0,0 +1,5 @@
+{
+ "global": {
+ "exclude": "G704,G101,G103,G104,G109,G115,G201,G202,G301,G302,G304,G306,G404"
+ }
+}
diff --git a/backend/Dockerfile b/backend/Dockerfile
index aeb20fdb6..6db2b1756 100644
--- a/backend/Dockerfile
+++ b/backend/Dockerfile
@@ -1,4 +1,4 @@
-FROM golang:1.25.7-alpine
+FROM registry-1.docker.io/library/golang:1.25.7-alpine
WORKDIR /app
diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION
index 32844913e..c98f2c2f4 100644
--- a/backend/cmd/server/VERSION
+++ b/backend/cmd/server/VERSION
@@ -1 +1 @@
-0.1.88
\ No newline at end of file
+0.1.85.21
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index cbf89ba3b..5044f7ee0 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -210,9 +210,9 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
- {"OpenAIWSPool", func() error {
+ {"OpenAIWSCtxPool", func() error {
if openAIGateway != nil {
- openAIGateway.CloseOpenAIWSPool()
+ openAIGateway.CloseOpenAIWSCtxPool()
}
return nil
}},
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 90709f5bc..587374d56 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -193,8 +193,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
- adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig, settingService)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
@@ -203,7 +202,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig)
soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService)
soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig)
- handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo)
+ handlerSettingHandler := handler.ProvideSettingHandler(settingService, accountRepository, buildInfo)
totpHandler := handler.NewTotpHandler(totpService)
idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig)
idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig)
@@ -393,9 +392,9 @@ func provideCleanup(
antigravityOAuth.Stop()
return nil
}},
- {"OpenAIWSPool", func() error {
+ {"OpenAIWSCtxPool", func() error {
if openAIGateway != nil {
- openAIGateway.CloseOpenAIWSPool()
+ openAIGateway.CloseOpenAIWSCtxPool()
}
return nil
}},
diff --git a/backend/go.mod b/backend/go.mod
index a34c9fff9..d9f421a00 100644
--- a/backend/go.mod
+++ b/backend/go.mod
@@ -178,10 +178,7 @@ require (
golang.org/x/mod v0.32.0 // indirect
golang.org/x/sys v0.41.0 // indirect
golang.org/x/text v0.34.0 // indirect
- golang.org/x/tools v0.41.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect
- google.golang.org/grpc v1.75.1 // indirect
- google.golang.org/protobuf v1.36.10 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
modernc.org/libc v1.67.6 // indirect
modernc.org/mathutil v1.7.1 // indirect
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 763ed829c..f0aa5a0b1 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -265,13 +265,8 @@ type CSPConfig struct {
}
type ProxyFallbackConfig struct {
- // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。
- // 仅影响以下非 AI 账号连接的辅助服务:
- // - GitHub Release 更新检查
- // - 定价数据拉取
- // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity),
- // 这些关键路径的代理失败始终返回错误,不会回退直连。
- // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。
+ // AllowDirectOnError 当代理初始化失败时是否允许回退直连。
+ // 默认 false:避免因代理配置错误导致 IP 泄露/关联。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
@@ -371,6 +366,8 @@ type GatewayConfig struct {
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
// OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
+ // OpenAIHTTP2: OpenAI HTTP 上游协议策略(默认启用 HTTP/2,可按代理能力回退 HTTP/1.1)
+ OpenAIHTTP2 GatewayOpenAIHTTP2Config `mapstructure:"openai_http2"`
// HTTP 上游连接池配置(性能优化:支持高并发场景调优)
// MaxIdleConns: 所有主机的最大空闲连接总数
@@ -457,12 +454,27 @@ type GatewayConfig struct {
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
}
+// GatewayOpenAIHTTP2Config OpenAI HTTP 上游协议配置。
+// 默认启用 HTTP/2,多路复用提升并发效率;在部分代理不兼容时按策略回退 HTTP/1.1。
+type GatewayOpenAIHTTP2Config struct {
+ // Enabled: 是否启用 OpenAI HTTP/2 优先策略
+ Enabled bool `mapstructure:"enabled"`
+ // AllowProxyFallbackToHTTP1: 代理不兼容 HTTP/2 时是否允许回退 HTTP/1.1
+ AllowProxyFallbackToHTTP1 bool `mapstructure:"allow_proxy_fallback_to_http1"`
+ // FallbackErrorThreshold: 在窗口期内触发回退所需的连续错误次数
+ FallbackErrorThreshold int `mapstructure:"fallback_error_threshold"`
+ // FallbackWindowSeconds: 连续错误计数窗口(秒)
+ FallbackWindowSeconds int `mapstructure:"fallback_window_seconds"`
+ // FallbackTTLSeconds: 进入 HTTP/1.1 回退态后的持续时间(秒)
+ FallbackTTLSeconds int `mapstructure:"fallback_ttl_seconds"`
+}
+
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。
// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。
type GatewayOpenAIWSConfig struct {
- // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为)
+ // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 true;关闭时保持 legacy 行为)
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
- // IngressModeDefault: ingress 默认模式(off/shared/dedicated)
+ // IngressModeDefault: ingress 默认模式(off/ctx_pool)
IngressModeDefault string `mapstructure:"ingress_mode_default"`
// Enabled: 全局总开关(默认 true)
Enabled bool `mapstructure:"enabled"`
@@ -502,8 +514,9 @@ type GatewayOpenAIWSConfig struct {
// APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor))
APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
- ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
- WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
+ ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
+ ClientReadIdleTimeoutSeconds int `mapstructure:"client_read_idle_timeout_seconds"`
+ WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
// EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数)
@@ -525,6 +538,11 @@ type GatewayOpenAIWSConfig struct {
// PayloadLogSampleRate: payload_schema 日志采样率(0-1)
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
+ // UpstreamConnMaxAgeSeconds: 上游 WebSocket 连接最大存活时间(秒)。
+ // OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。
+ // 默认 3300(55 分钟);设为 0 则禁用超龄轮换。
+ UpstreamConnMaxAgeSeconds int `mapstructure:"upstream_conn_max_age_seconds"`
+
// 账号调度与粘连参数
LBTopK int `mapstructure:"lb_top_k"`
// StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL
@@ -541,6 +559,32 @@ type GatewayOpenAIWSConfig struct {
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
+
+ // SchedulerP2CEnabled: 启用 P2C(Power-of-Two-Choices)选择算法替代 Top-K 加权采样
+ SchedulerP2CEnabled bool `mapstructure:"scheduler_p2c_enabled"`
+
+ // Softmax 温度采样:替代线性平移的概率选择策略
+ SchedulerSoftmaxEnabled bool `mapstructure:"scheduler_softmax_enabled"`
+ SchedulerSoftmaxTemperature float64 `mapstructure:"scheduler_softmax_temperature"`
+
+ // 账号级熔断器
+ SchedulerCircuitBreakerEnabled bool `mapstructure:"scheduler_circuit_breaker_enabled"`
+ SchedulerCircuitBreakerFailThreshold int `mapstructure:"scheduler_circuit_breaker_fail_threshold"`
+ SchedulerCircuitBreakerCooldownSec int `mapstructure:"scheduler_circuit_breaker_cooldown_sec"`
+ SchedulerCircuitBreakerHalfOpenMax int `mapstructure:"scheduler_circuit_breaker_half_open_max"`
+
+ // 条件性 Sticky Session 释放:当粘连账号不健康时主动释放,回退到负载均衡
+ StickyReleaseEnabled bool `mapstructure:"sticky_release_enabled"`
+ StickyReleaseErrorThreshold float64 `mapstructure:"sticky_release_error_threshold"`
+
+ // Per-model TTFT tracking
+ SchedulerPerModelTTFTEnabled bool `mapstructure:"scheduler_per_model_ttft_enabled"`
+ SchedulerPerModelTTFTMaxModels int `mapstructure:"scheduler_per_model_ttft_max_models"`
+
+ // SchedulerTrendEnabled: 启用负载趋势预测(线性回归外推),在打分时对 loadFactor 施加趋势修正
+ SchedulerTrendEnabled bool `mapstructure:"scheduler_trend_enabled"`
+ // SchedulerTrendMaxSlope: 趋势斜率归一化上限(每秒负载百分比变化率);0 或负数使用默认值 5.0
+ SchedulerTrendMaxSlope float64 `mapstructure:"scheduler_trend_max_slope"`
}
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。
@@ -1110,9 +1154,6 @@ func setDefaults() {
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
- // Security - disable direct fallback on proxy error
- viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
-
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
@@ -1270,8 +1311,8 @@ func setDefaults() {
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚)
viper.SetDefault("gateway.openai_ws.enabled", true)
- viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
- viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared")
+ viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", true)
+ viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
@@ -1302,6 +1343,7 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
+ viper.SetDefault("gateway.openai_ws.upstream_conn_max_age_seconds", 3300)
viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
@@ -1314,6 +1356,12 @@ func setDefaults() {
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
+ // OpenAI HTTP upstream protocol strategy
+ viper.SetDefault("gateway.openai_http2.enabled", true)
+ viper.SetDefault("gateway.openai_http2.allow_proxy_fallback_to_http1", true)
+ viper.SetDefault("gateway.openai_http2.fallback_error_threshold", 2)
+ viper.SetDefault("gateway.openai_http2.fallback_window_seconds", 60)
+ viper.SetDefault("gateway.openai_http2.fallback_ttl_seconds", 600)
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
@@ -1358,7 +1406,7 @@ func setDefaults() {
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
- viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
+ viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySync)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
@@ -1423,6 +1471,9 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
+ // Security - proxy fallback
+ viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
+
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
@@ -1901,6 +1952,15 @@ func (c *Config) Validate() error {
(c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) {
return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds")
}
+ if c.Gateway.OpenAIHTTP2.FallbackErrorThreshold < 0 {
+ return fmt.Errorf("gateway.openai_http2.fallback_error_threshold must be non-negative")
+ }
+ if c.Gateway.OpenAIHTTP2.FallbackWindowSeconds < 0 {
+ return fmt.Errorf("gateway.openai_http2.fallback_window_seconds must be non-negative")
+ }
+ if c.Gateway.OpenAIHTTP2.FallbackTTLSeconds < 0 {
+ return fmt.Errorf("gateway.openai_http2.fallback_ttl_seconds must be non-negative")
+ }
// 兼容旧键 sticky_previous_response_ttl_seconds
if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 {
c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds
@@ -1969,11 +2029,16 @@ func (c *Config) Validate() error {
if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 {
return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative")
}
+ if c.Gateway.OpenAIWS.ResponsesWebsockets && !c.Gateway.OpenAIWS.ResponsesWebsocketsV2 {
+ return fmt.Errorf("gateway.openai_ws.responses_websockets (v1) is not supported; enable gateway.openai_ws.responses_websockets_v2")
+ }
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" {
switch mode {
- case "off", "shared", "dedicated":
+ case "off", "ctx_pool":
+ case "shared", "dedicated":
+ slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool", "value", mode)
default:
- return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated")
+ return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool")
}
}
if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" {
@@ -1986,6 +2051,9 @@ func (c *Config) Validate() error {
if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 {
return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]")
}
+ if c.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds < 0 {
+ return fmt.Errorf("gateway.openai_ws.upstream_conn_max_age_seconds must be non-negative")
+ }
if c.Gateway.OpenAIWS.LBTopK <= 0 {
return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive")
}
@@ -2013,6 +2081,22 @@ func (c *Config) Validate() error {
if weightSum <= 0 {
return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero")
}
+ // Validate new scheduler/sticky-release config ranges.
+ if c.Gateway.OpenAIWS.SchedulerSoftmaxTemperature < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_softmax_temperature must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.StickyReleaseErrorThreshold < 0 || c.Gateway.OpenAIWS.StickyReleaseErrorThreshold > 1 {
+ return fmt.Errorf("gateway.openai_ws.sticky_release_error_threshold must be within [0,1]")
+ }
+ if c.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_fail_threshold must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_cooldown_sec must be non-negative")
+ }
+ if c.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax < 0 {
+ return fmt.Errorf("gateway.openai_ws.scheduler_circuit_breaker_half_open_max must be non-negative")
+ }
if c.Gateway.MaxLineSize < 0 {
return fmt.Errorf("gateway.max_line_size must be non-negative")
}
diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go
index e3b592e2c..05d880cb5 100644
--- a/backend/internal/config/config_test.go
+++ b/backend/internal/config/config_test.go
@@ -150,11 +150,11 @@ func TestLoadDefaultOpenAIWSConfig(t *testing.T) {
if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" {
t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict")
}
- if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
- t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false")
+ if !cfg.Gateway.OpenAIWS.ModeRouterV2Enabled {
+ t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = false, want true")
}
- if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" {
- t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared")
+ if cfg.Gateway.OpenAIWS.IngressModeDefault != "ctx_pool" {
+ t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "ctx_pool")
}
}
@@ -1373,7 +1373,7 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
wantErr: "gateway.openai_ws.store_disabled_conn_mode",
},
{
- name: "ingress_mode_default 必须为 off|shared|dedicated",
+ name: "ingress_mode_default 必须为 off|ctx_pool",
mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" },
wantErr: "gateway.openai_ws.ingress_mode_default",
},
@@ -1387,6 +1387,14 @@ func TestValidateConfig_OpenAIWSRules(t *testing.T) {
mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 },
wantErr: "gateway.openai_ws.retry_total_budget_ms",
},
+ {
+ name: "responses_websockets v1-only 配置不允许",
+ mutate: func(c *Config) {
+ c.Gateway.OpenAIWS.ResponsesWebsockets = true
+ c.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
+ },
+ wantErr: "gateway.openai_ws.responses_websockets",
+ },
{
name: "lb_top_k 必须为正数",
mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 },
@@ -1637,8 +1645,8 @@ func TestLoad_DefaultGatewayUsageRecordConfig(t *testing.T) {
if cfg.Gateway.UsageRecord.TaskTimeoutSeconds != 5 {
t.Fatalf("task_timeout_seconds = %d, want 5", cfg.Gateway.UsageRecord.TaskTimeoutSeconds)
}
- if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
- t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySample)
+ if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySync {
+ t.Fatalf("overflow_policy = %s, want %s", cfg.Gateway.UsageRecord.OverflowPolicy, UsageRecordOverflowPolicySync)
}
if cfg.Gateway.UsageRecord.OverflowSamplePercent != 10 {
t.Fatalf("overflow_sample_percent = %d, want 10", cfg.Gateway.UsageRecord.OverflowSamplePercent)
diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go
index 98ead2841..6cd2ce7a3 100644
--- a/backend/internal/handler/admin/account_handler.go
+++ b/backend/internal/handler/admin/account_handler.go
@@ -137,6 +137,7 @@ type BulkUpdateAccountsRequest struct {
RateMultiplier *float64 `json:"rate_multiplier"`
Status string `json:"status" binding:"omitempty,oneof=active inactive error"`
Schedulable *bool `json:"schedulable"`
+ AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
GroupIDs *[]int64 `json:"group_ids"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
@@ -1098,6 +1099,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
req.RateMultiplier != nil ||
req.Status != "" ||
req.Schedulable != nil ||
+ req.AutoPauseOnExpired != nil ||
req.GroupIDs != nil ||
len(req.Credentials) > 0 ||
len(req.Extra) > 0
@@ -1116,6 +1118,7 @@ func (h *AccountHandler) BulkUpdate(c *gin.Context) {
RateMultiplier: req.RateMultiplier,
Status: req.Status,
Schedulable: req.Schedulable,
+ AutoPauseOnExpired: req.AutoPauseOnExpired,
GroupIDs: req.GroupIDs,
Credentials: req.Credentials,
Extra: req.Extra,
diff --git a/backend/internal/handler/admin/account_handler_bulk_update_test.go b/backend/internal/handler/admin/account_handler_bulk_update_test.go
new file mode 100644
index 000000000..c2dfdf746
--- /dev/null
+++ b/backend/internal/handler/admin/account_handler_bulk_update_test.go
@@ -0,0 +1,62 @@
+package admin
+
+import (
+ "bytes"
+ "encoding/json"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupAccountBulkUpdateRouter(adminSvc *stubAdminService) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ accountHandler := NewAccountHandler(adminSvc, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ router.POST("/api/v1/admin/accounts/bulk-update", accountHandler.BulkUpdate)
+ return router
+}
+
+func TestAccountHandlerBulkUpdate_ForwardsAutoPauseOnExpired(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountBulkUpdateRouter(adminSvc)
+
+ body, err := json.Marshal(map[string]any{
+ "account_ids": []int64{1},
+ "auto_pause_on_expired": true,
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+ require.NotNil(t, adminSvc.lastBulkUpdateInput)
+ require.NotNil(t, adminSvc.lastBulkUpdateInput.AutoPauseOnExpired)
+ require.True(t, *adminSvc.lastBulkUpdateInput.AutoPauseOnExpired)
+}
+
+func TestAccountHandlerBulkUpdate_RejectsEmptyUpdates(t *testing.T) {
+ adminSvc := newStubAdminService()
+ router := setupAccountBulkUpdateRouter(adminSvc)
+
+ body, err := json.Marshal(map[string]any{
+ "account_ids": []int64{1},
+ })
+ require.NoError(t, err)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/bulk-update", bytes.NewReader(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+
+ var resp map[string]any
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Contains(t, resp["message"], "No updates provided")
+}
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index f3b99ddbe..efd9ee13c 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -31,7 +31,8 @@ type stubAdminService struct {
platform string
groupIDs []int64
}
- mu sync.Mutex
+ lastBulkUpdateInput *service.BulkUpdateAccountsInput
+ mu sync.Mutex
}
func newStubAdminService() *stubAdminService {
@@ -236,10 +237,10 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
- if s.bulkUpdateAccountErr != nil {
- return nil, s.bulkUpdateAccountErr
- }
- return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
+ s.mu.Lock()
+ s.lastBulkUpdateInput = input
+ s.mu.Unlock()
+ return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go
index 02fc766f9..69a0b5b51 100644
--- a/backend/internal/handler/admin/data_management_handler.go
+++ b/backend/internal/handler/admin/data_management_handler.go
@@ -1,7 +1,6 @@
package admin
import (
- "context"
"strconv"
"strings"
@@ -14,34 +13,13 @@ import (
)
type DataManagementHandler struct {
- dataManagementService dataManagementService
+ dataManagementService *service.DataManagementService
}
func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler {
return &DataManagementHandler{dataManagementService: dataManagementService}
}
-type dataManagementService interface {
- GetConfig(ctx context.Context) (service.DataManagementConfig, error)
- UpdateConfig(ctx context.Context, cfg service.DataManagementConfig) (service.DataManagementConfig, error)
- ValidateS3(ctx context.Context, cfg service.DataManagementS3Config) (service.DataManagementTestS3Result, error)
- CreateBackupJob(ctx context.Context, input service.DataManagementCreateBackupJobInput) (service.DataManagementBackupJob, error)
- ListSourceProfiles(ctx context.Context, sourceType string) ([]service.DataManagementSourceProfile, error)
- CreateSourceProfile(ctx context.Context, input service.DataManagementCreateSourceProfileInput) (service.DataManagementSourceProfile, error)
- UpdateSourceProfile(ctx context.Context, input service.DataManagementUpdateSourceProfileInput) (service.DataManagementSourceProfile, error)
- DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error
- SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (service.DataManagementSourceProfile, error)
- ListS3Profiles(ctx context.Context) ([]service.DataManagementS3Profile, error)
- CreateS3Profile(ctx context.Context, input service.DataManagementCreateS3ProfileInput) (service.DataManagementS3Profile, error)
- UpdateS3Profile(ctx context.Context, input service.DataManagementUpdateS3ProfileInput) (service.DataManagementS3Profile, error)
- DeleteS3Profile(ctx context.Context, profileID string) error
- SetActiveS3Profile(ctx context.Context, profileID string) (service.DataManagementS3Profile, error)
- ListBackupJobs(ctx context.Context, input service.DataManagementListBackupJobsInput) (service.DataManagementListBackupJobsResult, error)
- GetBackupJob(ctx context.Context, jobID string) (service.DataManagementBackupJob, error)
- EnsureAgentEnabled(ctx context.Context) error
- GetAgentHealth(ctx context.Context) service.DataManagementAgentHealth
-}
-
type TestS3ConnectionRequest struct {
Endpoint string `json:"endpoint"`
Region string `json:"region" binding:"required"`
diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go
new file mode 100644
index 000000000..c63ed9af2
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template.go
@@ -0,0 +1,228 @@
+package admin
+
+import (
+ "fmt"
+ "strconv"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+)
+
+type UpsertBulkEditTemplateRequest struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ ScopePlatform string `json:"scope_platform"`
+ ScopeType string `json:"scope_type"`
+ ShareScope string `json:"share_scope"`
+ GroupIDs []int64 `json:"group_ids"`
+ State map[string]any `json:"state"`
+}
+
+type RollbackBulkEditTemplateRequest struct {
+ VersionID string `json:"version_id"`
+}
+
+// ListBulkEditTemplates 获取批量编辑模板列表
+// GET /api/v1/admin/settings/bulk-edit-templates
+func (h *SettingHandler) ListBulkEditTemplates(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids"))
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ items, listErr := h.settingService.ListBulkEditTemplates(c.Request.Context(), service.BulkEditTemplateQuery{
+ ScopePlatform: c.Query("scope_platform"),
+ ScopeType: c.Query("scope_type"),
+ ScopeGroupIDs: scopeGroupIDs,
+ RequesterUserID: subject.UserID,
+ })
+ if listErr != nil {
+ response.ErrorFrom(c, listErr)
+ return
+ }
+
+ response.Success(c, gin.H{"items": items})
+}
+
+// UpsertBulkEditTemplate 创建/更新批量编辑模板
+// POST /api/v1/admin/settings/bulk-edit-templates
+func (h *SettingHandler) UpsertBulkEditTemplate(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ var req UpsertBulkEditTemplateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ item, upsertErr := h.settingService.UpsertBulkEditTemplate(
+ c.Request.Context(),
+ service.BulkEditTemplateUpsertInput{
+ ID: req.ID,
+ Name: req.Name,
+ ScopePlatform: req.ScopePlatform,
+ ScopeType: req.ScopeType,
+ ShareScope: req.ShareScope,
+ GroupIDs: req.GroupIDs,
+ State: req.State,
+ RequesterUserID: subject.UserID,
+ },
+ )
+ if upsertErr != nil {
+ response.ErrorFrom(c, upsertErr)
+ return
+ }
+
+ response.Success(c, item)
+}
+
+// DeleteBulkEditTemplate 删除批量编辑模板
+// DELETE /api/v1/admin/settings/bulk-edit-templates/:template_id
+func (h *SettingHandler) DeleteBulkEditTemplate(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ templateID := strings.TrimSpace(c.Param("template_id"))
+ if templateID == "" {
+ response.BadRequest(c, "template_id is required")
+ return
+ }
+
+ if err := h.settingService.DeleteBulkEditTemplate(c.Request.Context(), templateID, subject.UserID); err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ response.Success(c, gin.H{"deleted": true})
+}
+
+// ListBulkEditTemplateVersions 获取模板版本历史
+// GET /api/v1/admin/settings/bulk-edit-templates/:template_id/versions
+func (h *SettingHandler) ListBulkEditTemplateVersions(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ templateID := strings.TrimSpace(c.Param("template_id"))
+ if templateID == "" {
+ response.BadRequest(c, "template_id is required")
+ return
+ }
+
+ scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids"))
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ items, listErr := h.settingService.ListBulkEditTemplateVersions(
+ c.Request.Context(),
+ service.BulkEditTemplateVersionQuery{
+ TemplateID: templateID,
+ ScopeGroupIDs: scopeGroupIDs,
+ RequesterUserID: subject.UserID,
+ },
+ )
+ if listErr != nil {
+ response.ErrorFrom(c, listErr)
+ return
+ }
+
+ response.Success(c, gin.H{"items": items})
+}
+
+// RollbackBulkEditTemplate 回滚模板到指定版本
+// POST /api/v1/admin/settings/bulk-edit-templates/:template_id/rollback
+func (h *SettingHandler) RollbackBulkEditTemplate(c *gin.Context) {
+ subject, ok := middleware.GetAuthSubjectFromContext(c)
+ if !ok || subject.UserID <= 0 {
+ response.Unauthorized(c, "Unauthorized")
+ return
+ }
+
+ templateID := strings.TrimSpace(c.Param("template_id"))
+ if templateID == "" {
+ response.BadRequest(c, "template_id is required")
+ return
+ }
+
+ scopeGroupIDs, err := parseScopeGroupIDs(c.Query("scope_group_ids"))
+ if err != nil {
+ response.BadRequest(c, err.Error())
+ return
+ }
+
+ var req RollbackBulkEditTemplateRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ item, rollbackErr := h.settingService.RollbackBulkEditTemplate(
+ c.Request.Context(),
+ service.BulkEditTemplateRollbackInput{
+ TemplateID: templateID,
+ VersionID: req.VersionID,
+ ScopeGroupIDs: scopeGroupIDs,
+ RequesterUserID: subject.UserID,
+ },
+ )
+ if rollbackErr != nil {
+ response.ErrorFrom(c, rollbackErr)
+ return
+ }
+
+ response.Success(c, item)
+}
+
+func parseScopeGroupIDs(raw string) ([]int64, error) {
+ trimmed := strings.TrimSpace(raw)
+ if trimmed == "" {
+ return nil, nil
+ }
+
+ parts := strings.Split(trimmed, ",")
+ if len(parts) == 0 {
+ return nil, nil
+ }
+
+ seen := make(map[int64]struct{}, len(parts))
+ groupIDs := make([]int64, 0, len(parts))
+ for _, part := range parts {
+ candidate := strings.TrimSpace(part)
+ if candidate == "" {
+ continue
+ }
+
+ groupID, err := strconv.ParseInt(candidate, 10, 64)
+ if err != nil || groupID <= 0 {
+ return nil, fmt.Errorf("scope_group_ids must be comma-separated positive integers")
+ }
+ if _, exists := seen[groupID]; exists {
+ continue
+ }
+ seen[groupID] = struct{}{}
+ groupIDs = append(groupIDs, groupID)
+ }
+
+ return groupIDs, nil
+}
diff --git a/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go
new file mode 100644
index 000000000..712f1bf89
--- /dev/null
+++ b/backend/internal/handler/admin/setting_handler_bulk_edit_template_test.go
@@ -0,0 +1,592 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strconv"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type settingHandlerTemplateRepoStub struct {
+ values map[string]string
+}
+
+func newSettingHandlerTemplateRepoStub() *settingHandlerTemplateRepoStub {
+ return &settingHandlerTemplateRepoStub{values: map[string]string{}}
+}
+
+func (s *settingHandlerTemplateRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &service.Setting{Key: key, Value: value}, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", service.ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) Set(ctx context.Context, key, value string) error {
+ s.values[key] = value
+ return nil
+}
+
+func (s *settingHandlerTemplateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ for key, value := range settings {
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *settingHandlerTemplateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *settingHandlerTemplateRepoStub) Delete(ctx context.Context, key string) error {
+ delete(s.values, key)
+ return nil
+}
+
+type failingSettingRepoStub struct{}
+
+func (s *failingSettingRepoStub) Get(ctx context.Context, key string) (*service.Setting, error) {
+ return nil, errors.New("boom")
+}
+func (s *failingSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ return "", errors.New("boom")
+}
+func (s *failingSettingRepoStub) Set(ctx context.Context, key, value string) error {
+ return errors.New("boom")
+}
+func (s *failingSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ return nil, errors.New("boom")
+}
+func (s *failingSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ return errors.New("boom")
+}
+func (s *failingSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ return nil, errors.New("boom")
+}
+func (s *failingSettingRepoStub) Delete(ctx context.Context, key string) error {
+ return errors.New("boom")
+}
+
+func setupBulkEditTemplateRouter() *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ repo := newSettingHandlerTemplateRepoStub()
+ settingService := service.NewSettingService(repo, nil)
+ handler := NewSettingHandler(settingService, nil, nil, nil, nil)
+
+ router := gin.New()
+ router.Use(func(c *gin.Context) {
+ uid := int64(1)
+ if header := c.GetHeader("X-User-ID"); header != "" {
+ if parsed, err := strconv.ParseInt(header, 10, 64); err == nil && parsed > 0 {
+ uid = parsed
+ }
+ }
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: uid})
+ c.Next()
+ })
+
+ router.GET("/api/v1/admin/settings/bulk-edit-templates", handler.ListBulkEditTemplates)
+ router.POST("/api/v1/admin/settings/bulk-edit-templates", handler.UpsertBulkEditTemplate)
+ router.DELETE("/api/v1/admin/settings/bulk-edit-templates/:template_id", handler.DeleteBulkEditTemplate)
+ router.GET("/api/v1/admin/settings/bulk-edit-templates/:template_id/versions", handler.ListBulkEditTemplateVersions)
+ router.POST("/api/v1/admin/settings/bulk-edit-templates/:template_id/rollback", handler.RollbackBulkEditTemplate)
+
+ return router
+}
+
+func decodeResponseDataMap(t *testing.T, body []byte) map[string]any {
+ t.Helper()
+ var payload response.Response
+ require.NoError(t, json.Unmarshal(body, &payload))
+ if payload.Data == nil {
+ return map[string]any{}
+ }
+ asMap, ok := payload.Data.(map[string]any)
+ require.True(t, ok)
+ return asMap
+}
+
+func TestSettingHandlerBulkEditTemplate_CRUDFlow(t *testing.T) {
+ router := setupBulkEditTemplateRouter()
+
+ createBody := map[string]any{
+ "name": "OpenAI OAuth Baseline",
+ "scope_platform": "openai",
+ "scope_type": "oauth",
+ "share_scope": "team",
+ "state": map[string]any{
+ "enableOpenAIPassthrough": true,
+ },
+ }
+ raw, err := json.Marshal(createBody)
+ require.NoError(t, err)
+
+ createRec := httptest.NewRecorder()
+ createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
+ createReq.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(createRec, createReq)
+ require.Equal(t, http.StatusOK, createRec.Code)
+
+ createData := decodeResponseDataMap(t, createRec.Body.Bytes())
+ templateID, ok := createData["id"].(string)
+ require.True(t, ok)
+ require.NotEmpty(t, templateID)
+
+ listRec := httptest.NewRecorder()
+ listReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth",
+ nil,
+ )
+ router.ServeHTTP(listRec, listReq)
+ require.Equal(t, http.StatusOK, listRec.Code)
+
+ listData := decodeResponseDataMap(t, listRec.Body.Bytes())
+ items, ok := listData["items"].([]any)
+ require.True(t, ok)
+ require.Len(t, items, 1)
+
+ deleteRec := httptest.NewRecorder()
+ deleteReq := httptest.NewRequest(
+ http.MethodDelete,
+ "/api/v1/admin/settings/bulk-edit-templates/"+templateID,
+ nil,
+ )
+ router.ServeHTTP(deleteRec, deleteReq)
+ require.Equal(t, http.StatusOK, deleteRec.Code)
+
+ listAfterDeleteRec := httptest.NewRecorder()
+ listAfterDeleteReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth",
+ nil,
+ )
+ router.ServeHTTP(listAfterDeleteRec, listAfterDeleteReq)
+ require.Equal(t, http.StatusOK, listAfterDeleteRec.Code)
+
+ listAfterDeleteData := decodeResponseDataMap(t, listAfterDeleteRec.Body.Bytes())
+ itemsAfterDelete, ok := listAfterDeleteData["items"].([]any)
+ require.True(t, ok)
+ require.Len(t, itemsAfterDelete, 0)
+}
+
+func TestSettingHandlerBulkEditTemplate_VersionsAndRollback(t *testing.T) {
+ router := setupBulkEditTemplateRouter()
+
+ createBody := map[string]any{
+ "name": "Rollback Target",
+ "scope_platform": "openai",
+ "scope_type": "oauth",
+ "share_scope": "groups",
+ "group_ids": []int64{2},
+ "state": map[string]any{
+ "enableOpenAIWSMode": true,
+ },
+ }
+ createRaw, err := json.Marshal(createBody)
+ require.NoError(t, err)
+
+ createRec := httptest.NewRecorder()
+ createReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates",
+ bytes.NewReader(createRaw),
+ )
+ createReq.Header.Set("Content-Type", "application/json")
+ createReq.Header.Set("X-User-ID", "9")
+ router.ServeHTTP(createRec, createReq)
+ require.Equal(t, http.StatusOK, createRec.Code)
+ createData := decodeResponseDataMap(t, createRec.Body.Bytes())
+ templateID := createData["id"].(string)
+
+ updateBody := map[string]any{
+ "id": templateID,
+ "name": "Rollback Target",
+ "scope_platform": "openai",
+ "scope_type": "oauth",
+ "share_scope": "team",
+ "group_ids": []int64{},
+ "state": map[string]any{
+ "enableOpenAIWSMode": false,
+ },
+ }
+ updateRaw, err := json.Marshal(updateBody)
+ require.NoError(t, err)
+
+ updateRec := httptest.NewRecorder()
+ updateReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates",
+ bytes.NewReader(updateRaw),
+ )
+ updateReq.Header.Set("Content-Type", "application/json")
+ updateReq.Header.Set("X-User-ID", "9")
+ router.ServeHTTP(updateRec, updateReq)
+ require.Equal(t, http.StatusOK, updateRec.Code)
+
+ versionsRec := httptest.NewRecorder()
+ versionsReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2",
+ nil,
+ )
+ versionsReq.Header.Set("X-User-ID", "9")
+ router.ServeHTTP(versionsRec, versionsReq)
+ require.Equal(t, http.StatusOK, versionsRec.Code)
+ versionsData := decodeResponseDataMap(t, versionsRec.Body.Bytes())
+ versions := versionsData["items"].([]any)
+ require.Len(t, versions, 1)
+ versionID := versions[0].(map[string]any)["version_id"].(string)
+
+ rollbackBody := map[string]any{"version_id": versionID}
+ rollbackRaw, err := json.Marshal(rollbackBody)
+ require.NoError(t, err)
+
+ rollbackRec := httptest.NewRecorder()
+ rollbackReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/rollback?scope_group_ids=2",
+ bytes.NewReader(rollbackRaw),
+ )
+ rollbackReq.Header.Set("Content-Type", "application/json")
+ rollbackReq.Header.Set("X-User-ID", "9")
+ router.ServeHTTP(rollbackRec, rollbackReq)
+ require.Equal(t, http.StatusOK, rollbackRec.Code)
+ rollbackData := decodeResponseDataMap(t, rollbackRec.Body.Bytes())
+ require.Equal(t, "groups", rollbackData["share_scope"])
+ groupIDs := rollbackData["group_ids"].([]any)
+ require.Equal(t, []any{float64(2)}, groupIDs)
+ state := rollbackData["state"].(map[string]any)
+ require.Equal(t, true, state["enableOpenAIWSMode"])
+
+ versionsAfterRollbackRec := httptest.NewRecorder()
+ versionsAfterRollbackReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates/"+templateID+"/versions?scope_group_ids=2",
+ nil,
+ )
+ versionsAfterRollbackReq.Header.Set("X-User-ID", "9")
+ router.ServeHTTP(versionsAfterRollbackRec, versionsAfterRollbackReq)
+ require.Equal(t, http.StatusOK, versionsAfterRollbackRec.Code)
+ versionsAfterRollbackData := decodeResponseDataMap(t, versionsAfterRollbackRec.Body.Bytes())
+ versionsAfterRollback := versionsAfterRollbackData["items"].([]any)
+ require.Len(t, versionsAfterRollback, 2)
+}
+
+func TestSettingHandlerBulkEditTemplate_Validation(t *testing.T) {
+ router := setupBulkEditTemplateRouter()
+
+ invalidCreateBody := map[string]any{
+ "name": "Groups Template",
+ "scope_platform": "openai",
+ "scope_type": "oauth",
+ "share_scope": "groups",
+ "group_ids": []int64{},
+ "state": map[string]any{},
+ }
+ raw, err := json.Marshal(invalidCreateBody)
+ require.NoError(t, err)
+
+ invalidCreateRec := httptest.NewRecorder()
+ invalidCreateReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
+ invalidCreateReq.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(invalidCreateRec, invalidCreateReq)
+ require.Equal(t, http.StatusBadRequest, invalidCreateRec.Code)
+
+ invalidListRec := httptest.NewRecorder()
+ invalidListReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates?scope_group_ids=abc",
+ nil,
+ )
+ router.ServeHTTP(invalidListRec, invalidListReq)
+ require.Equal(t, http.StatusBadRequest, invalidListRec.Code)
+}
+
+func TestSettingHandlerBulkEditTemplate_PrivateVisibilityAndDeletePermission(t *testing.T) {
+ router := setupBulkEditTemplateRouter()
+
+ createBody := map[string]any{
+ "name": "Private Template",
+ "scope_platform": "openai",
+ "scope_type": "oauth",
+ "share_scope": "private",
+ "state": map[string]any{
+ "enableBaseUrl": true,
+ },
+ }
+ raw, err := json.Marshal(createBody)
+ require.NoError(t, err)
+
+ createRec := httptest.NewRecorder()
+ createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
+ createReq.Header.Set("Content-Type", "application/json")
+ createReq.Header.Set("X-User-ID", "100")
+ router.ServeHTTP(createRec, createReq)
+ require.Equal(t, http.StatusOK, createRec.Code)
+
+ createData := decodeResponseDataMap(t, createRec.Body.Bytes())
+ templateID := createData["id"].(string)
+
+ listByOtherRec := httptest.NewRecorder()
+ listByOtherReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth",
+ nil,
+ )
+ listByOtherReq.Header.Set("X-User-ID", "200")
+ router.ServeHTTP(listByOtherRec, listByOtherReq)
+ require.Equal(t, http.StatusOK, listByOtherRec.Code)
+
+ listByOtherData := decodeResponseDataMap(t, listByOtherRec.Body.Bytes())
+ items, ok := listByOtherData["items"].([]any)
+ require.True(t, ok)
+ require.Len(t, items, 0)
+
+ deleteByOtherRec := httptest.NewRecorder()
+ deleteByOtherReq := httptest.NewRequest(
+ http.MethodDelete,
+ "/api/v1/admin/settings/bulk-edit-templates/"+templateID,
+ nil,
+ )
+ deleteByOtherReq.Header.Set("X-User-ID", "200")
+ router.ServeHTTP(deleteByOtherRec, deleteByOtherReq)
+ require.Equal(t, http.StatusForbidden, deleteByOtherRec.Code)
+
+ deleteByOwnerRec := httptest.NewRecorder()
+ deleteByOwnerReq := httptest.NewRequest(
+ http.MethodDelete,
+ "/api/v1/admin/settings/bulk-edit-templates/"+templateID,
+ nil,
+ )
+ deleteByOwnerReq.Header.Set("X-User-ID", "100")
+ router.ServeHTTP(deleteByOwnerRec, deleteByOwnerReq)
+ require.Equal(t, http.StatusOK, deleteByOwnerRec.Code)
+}
+
+func TestSettingHandlerBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) {
+ router := setupBulkEditTemplateRouter()
+
+ createBody := map[string]any{
+ "name": "Group Shared",
+ "scope_platform": "openai",
+ "scope_type": "oauth",
+ "share_scope": "groups",
+ "group_ids": []int64{3, 8},
+ "state": map[string]any{"enableOpenAIWSMode": true},
+ }
+ raw, err := json.Marshal(createBody)
+ require.NoError(t, err)
+
+ createRec := httptest.NewRecorder()
+ createReq := httptest.NewRequest(http.MethodPost, "/api/v1/admin/settings/bulk-edit-templates", bytes.NewReader(raw))
+ createReq.Header.Set("Content-Type", "application/json")
+ createReq.Header.Set("X-User-ID", "1")
+ router.ServeHTTP(createRec, createReq)
+ require.Equal(t, http.StatusOK, createRec.Code)
+
+ invisibleRec := httptest.NewRecorder()
+ invisibleReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=9",
+ nil,
+ )
+ invisibleReq.Header.Set("X-User-ID", "2")
+ router.ServeHTTP(invisibleRec, invisibleReq)
+ require.Equal(t, http.StatusOK, invisibleRec.Code)
+ invisibleData := decodeResponseDataMap(t, invisibleRec.Body.Bytes())
+ invisibleItems := invisibleData["items"].([]any)
+ require.Len(t, invisibleItems, 0)
+
+ visibleRec := httptest.NewRecorder()
+ visibleReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates?scope_platform=openai&scope_type=oauth&scope_group_ids=8",
+ nil,
+ )
+ visibleReq.Header.Set("X-User-ID", "2")
+ router.ServeHTTP(visibleRec, visibleReq)
+ require.Equal(t, http.StatusOK, visibleRec.Code)
+ visibleData := decodeResponseDataMap(t, visibleRec.Body.Bytes())
+ visibleItems := visibleData["items"].([]any)
+ require.Len(t, visibleItems, 1)
+}
+
+func TestSettingHandlerBulkEditTemplate_UnauthorizedAndInvalidRequests(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ repo := newSettingHandlerTemplateRepoStub()
+ settingService := service.NewSettingService(repo, nil)
+ handler := NewSettingHandler(settingService, nil, nil, nil, nil)
+
+ router := gin.New()
+ router.GET("/list", handler.ListBulkEditTemplates)
+ router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions)
+ router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate)
+ router.POST("/upsert", handler.UpsertBulkEditTemplate)
+ router.DELETE("/delete/:template_id", handler.DeleteBulkEditTemplate)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/list", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/upsert", bytes.NewBufferString("{bad-json"))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodDelete, "/delete/%20", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/versions/abc", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/rollback/abc", bytes.NewBufferString(`{"version_id":"v1"}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusUnauthorized, rec.Code)
+}
+
+func TestParseScopeGroupIDs(t *testing.T) {
+ ids, err := parseScopeGroupIDs("")
+ require.NoError(t, err)
+ require.Nil(t, ids)
+
+ ids, err = parseScopeGroupIDs("1, 2,2,3")
+ require.NoError(t, err)
+ require.Equal(t, []int64{1, 2, 3}, ids)
+
+ _, err = parseScopeGroupIDs("x,2")
+ require.Error(t, err)
+}
+
+func TestSettingHandlerBulkEditTemplate_BindErrorAndMissingTemplateID(t *testing.T) {
+ router := setupBulkEditTemplateRouter()
+
+ bindErrRec := httptest.NewRecorder()
+ bindErrReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates",
+ bytes.NewBufferString("{bad-json"),
+ )
+ bindErrReq.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(bindErrRec, bindErrReq)
+ require.Equal(t, http.StatusBadRequest, bindErrRec.Code)
+
+ missingIDRec := httptest.NewRecorder()
+ missingIDReq := httptest.NewRequest(
+ http.MethodDelete,
+ "/api/v1/admin/settings/bulk-edit-templates/%20",
+ nil,
+ )
+ router.ServeHTTP(missingIDRec, missingIDReq)
+ require.Equal(t, http.StatusBadRequest, missingIDRec.Code)
+
+ invalidScopeRec := httptest.NewRecorder()
+ invalidScopeReq := httptest.NewRequest(
+ http.MethodGet,
+ "/api/v1/admin/settings/bulk-edit-templates/abc/versions?scope_group_ids=bad",
+ nil,
+ )
+ router.ServeHTTP(invalidScopeRec, invalidScopeReq)
+ require.Equal(t, http.StatusBadRequest, invalidScopeRec.Code)
+
+ rollbackMissingIDRec := httptest.NewRecorder()
+ rollbackMissingIDReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates/%20/rollback",
+ bytes.NewBufferString(`{"version_id":"v1"}`),
+ )
+ rollbackMissingIDReq.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rollbackMissingIDRec, rollbackMissingIDReq)
+ require.Equal(t, http.StatusBadRequest, rollbackMissingIDRec.Code)
+
+ rollbackInvalidScopeRec := httptest.NewRecorder()
+ rollbackInvalidScopeReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates/abc/rollback?scope_group_ids=bad",
+ bytes.NewBufferString(`{"version_id":"v1"}`),
+ )
+ rollbackInvalidScopeReq.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rollbackInvalidScopeRec, rollbackInvalidScopeReq)
+ require.Equal(t, http.StatusBadRequest, rollbackInvalidScopeRec.Code)
+
+ rollbackBindErrRec := httptest.NewRecorder()
+ rollbackBindErrReq := httptest.NewRequest(
+ http.MethodPost,
+ "/api/v1/admin/settings/bulk-edit-templates/abc/rollback",
+ bytes.NewBufferString("{bad-json"),
+ )
+ rollbackBindErrReq.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rollbackBindErrRec, rollbackBindErrReq)
+ require.Equal(t, http.StatusBadRequest, rollbackBindErrRec.Code)
+}
+
+func TestSettingHandlerBulkEditTemplate_ListErrorFromService(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ settingService := service.NewSettingService(&failingSettingRepoStub{}, nil)
+ handler := NewSettingHandler(settingService, nil, nil, nil, nil)
+ router := gin.New()
+ router.Use(func(c *gin.Context) {
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: 1})
+ c.Next()
+ })
+ router.GET("/list", handler.ListBulkEditTemplates)
+ router.GET("/versions/:template_id", handler.ListBulkEditTemplateVersions)
+ router.POST("/rollback/:template_id", handler.RollbackBulkEditTemplate)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodGet, "/list", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodGet, "/versions/tpl-1", nil)
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+
+ rec = httptest.NewRecorder()
+ req = httptest.NewRequest(http.MethodPost, "/rollback/tpl-1", bytes.NewBufferString(`{"version_id":"v1"}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 2bd59f322..64e8ad549 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -1411,23 +1411,10 @@ func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.gateway.messages"),
- zap.Any("panic", recovered),
- ).Error("gateway.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
+ submitUsageRecordTaskWithFallback(
+ "handler.gateway.messages",
+ h.usageRecordWorkerPool,
+ h.cfg,
+ task,
+ )
}
diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go
index 09e6c09ba..2715825c2 100644
--- a/backend/internal/handler/gateway_helper.go
+++ b/backend/internal/handler/gateway_helper.go
@@ -29,8 +29,6 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.
if parsedReq != nil {
c.Set(claudeCodeParsedRequestContextKey, parsedReq)
}
-
- ua := c.GetHeader("User-Agent")
// Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。
if !claudeCodeValidator.ValidateUserAgent(ua) {
ctx := service.SetClaudeCodeClient(c.Request.Context(), false)
diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go
index 4bbd17bae..9cc39bb91 100644
--- a/backend/internal/handler/openai_gateway_handler.go
+++ b/backend/internal/handler/openai_gateway_handler.go
@@ -9,6 +9,7 @@ import (
"runtime/debug"
"strconv"
"strings"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -32,6 +33,7 @@ type OpenAIGatewayHandler struct {
usageRecordWorkerPool *service.UsageRecordWorkerPool
errorPassthroughService *service.ErrorPassthroughService
concurrencyHelper *ConcurrencyHelper
+ cfg *config.Config
maxAccountSwitches int
}
@@ -60,6 +62,7 @@ func NewOpenAIGatewayHandler(
usageRecordWorkerPool: usageRecordWorkerPool,
errorPassthroughService: errorPassthroughService,
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
+ cfg: cfg,
maxAccountSwitches: maxAccountSwitches,
}
}
@@ -113,30 +116,29 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
- setOpsRequestContext(c, "", false, body)
-
- // 校验请求体 JSON 合法性
+ // 校验请求体 JSON 合法性,避免畸形 JSON 被 gjson 部分解析后继续下游处理。
if !gjson.ValidBytes(body) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
- // 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
- modelResult := gjson.GetBytes(body, "model")
+ // 使用 GetManyBytes 一次扫描提取所有字段,避免多次遍历大请求体
+ results := gjson.GetManyBytes(body, "model", "stream", "previous_response_id")
+ modelResult := results[0]
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
return
}
reqModel := modelResult.String()
- streamResult := gjson.GetBytes(body, "stream")
+ streamResult := results[1]
if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
return
}
reqStream := streamResult.Bool()
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
- previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
+ previousResponseID := strings.TrimSpace(results[2].String())
if previousResponseID != "" {
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
reqLog = reqLog.With(
@@ -155,6 +157,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
setOpsRequestContext(c, reqModel, reqStream, body)
+ // 缓存已提取的 meta 到 context,供 Service 层复用,避免重复解析请求体。
+ // prompt_cache_key 也在此处提前提取,确保 meta 设置后只读不写,避免并发竞态。
+ promptCacheKey := strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
+ c.Set(service.OpenAIRequestMetaKey, &service.OpenAIRequestMeta{
+ Model: reqModel,
+ Stream: reqStream,
+ PromptCacheKey: promptCacheKey,
+ })
+
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
return
@@ -269,7 +280,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
h.gatewayService.RecordOpenAIAccountSwitch()
failedAccountIDs[account.ID] = struct{}{}
lastFailoverErr = failoverErr
@@ -286,7 +297,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
)
continue
}
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
fields := []zap.Field{
zap.Int64("account_id", account.ID),
@@ -301,26 +312,32 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
return
}
if result != nil {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ var ttftMsVal float64
+ if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 {
+ ttftMsVal = float64(*result.FirstTokenMs)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, ttftMsVal)
} else {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil, reqModel, 0)
}
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
userAgent := c.GetHeader("User-Agent")
clientIP := ip.GetClientIP(c)
+ requestID := strings.TrimSpace(c.GetHeader("X-Request-ID"))
// 使用量记录通过有界 worker 池提交,避免请求热路径创建无界 goroutine。
h.submitUsageRecordTask(func(ctx context.Context) {
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- UserAgent: userAgent,
- IPAddress: clientIP,
- APIKeyService: h.apiKeyService,
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ FallbackRequestID: requestID,
+ APIKeyService: h.apiKeyService,
}); err != nil {
logger.L().With(
zap.String("component", "handler.openai_gateway.responses"),
@@ -550,9 +567,10 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
reqLog.Info("openai.websocket_ingress_started")
clientIP := ip.GetClientIP(c)
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
+ requestID := strings.TrimSpace(c.GetHeader("X-Request-ID"))
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
+ CompressionMode: coderws.CompressionNoContextTakeover,
})
if err != nil {
reqLog.Warn("openai.websocket_accept_failed",
@@ -674,6 +692,11 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}
account := selection.Account
+ wsIngressMode, wsModeRouterV2Enabled := h.resolveOpenAIWSIngressMode(account)
+ reqLog = reqLog.With(
+ zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled),
+ zap.String("openai_ws_ingress_mode", wsIngressMode),
+ )
accountMaxConcurrency := account.Concurrency
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
@@ -717,13 +740,19 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
zap.String("account_name", account.Name),
zap.String("schedule_layer", scheduleDecision.Layer),
zap.Int("candidate_count", scheduleDecision.CandidateCount),
+ zap.Bool("openai_ws_mode_router_v2_enabled", wsModeRouterV2Enabled),
+ zap.String("openai_ws_ingress_mode", wsIngressMode),
)
+ var turnScheduleReported atomic.Bool
hooks := &service.OpenAIWSIngressHooks{
BeforeTurn: func(turn int) error {
if turn == 1 {
return nil
}
+ if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
+ return service.NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "billing check failed", err)
+ }
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
releaseTurnSlots()
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
@@ -753,20 +782,54 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
},
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
releaseTurnSlots()
- if turnErr != nil || result == nil {
+ if turnErr != nil {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
+ turnScheduleReported.Store(true)
+ if partialResult, ok := service.OpenAIWSIngressTurnPartialResult(turnErr); ok && partialResult != nil {
+ h.submitUsageRecordTask(func(taskCtx context.Context) {
+ if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
+ Result: partialResult,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ FallbackRequestID: requestID,
+ APIKeyService: h.apiKeyService,
+ }); err != nil {
+ reqLog.Error("openai.websocket_record_partial_usage_failed",
+ zap.Int64("account_id", account.ID),
+ zap.String("request_id", partialResult.RequestID),
+ zap.Error(err),
+ )
+ }
+ })
+ }
+ return
+ }
+ if result == nil {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
+ turnScheduleReported.Store(true)
return
}
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
+ var turnTTFTMs float64
+ if result.FirstTokenMs != nil && *result.FirstTokenMs > 0 {
+ turnTTFTMs = float64(*result.FirstTokenMs)
+ }
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs, reqModel, turnTTFTMs)
+ turnScheduleReported.Store(true)
h.submitUsageRecordTask(func(taskCtx context.Context) {
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
- Result: result,
- APIKey: apiKey,
- User: apiKey.User,
- Account: account,
- Subscription: subscription,
- UserAgent: userAgent,
- IPAddress: clientIP,
- APIKeyService: h.apiKeyService,
+ Result: result,
+ APIKey: apiKey,
+ User: apiKey.User,
+ Account: account,
+ Subscription: subscription,
+ UserAgent: userAgent,
+ IPAddress: clientIP,
+ FallbackRequestID: requestID,
+ APIKeyService: h.apiKeyService,
}); err != nil {
reqLog.Error("openai.websocket_record_usage_failed",
zap.Int64("account_id", account.ID),
@@ -779,7 +842,9 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
}
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
- h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
+ if !turnScheduleReported.Load() {
+ h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil, reqModel, 0)
+ }
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
reqLog.Warn("openai.websocket_proxy_failed",
zap.Int64("account_id", account.ID),
@@ -882,25 +947,12 @@ func getContextInt64(c *gin.Context, key string) (int64, bool) {
}
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.openai_gateway.responses"),
- zap.Any("panic", recovered),
- ).Error("openai.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
+ submitUsageRecordTaskWithFallback(
+ "handler.openai_gateway.responses",
+ h.usageRecordWorkerPool,
+ h.cfg,
+ task,
+ )
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
@@ -1030,6 +1082,24 @@ func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64)
return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID)
}
+func (h *OpenAIGatewayHandler) resolveOpenAIWSIngressMode(account *service.Account) (mode string, modeRouterV2Enabled bool) {
+ if account == nil {
+ return "account_missing", false
+ }
+ if h == nil || h.cfg == nil {
+ return "config_missing", false
+ }
+ modeRouterV2Enabled = h.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
+ if !modeRouterV2Enabled {
+ return "legacy", false
+ }
+ resolvedMode := account.ResolveOpenAIResponsesWebSocketV2Mode(h.cfg.Gateway.OpenAIWS.IngressModeDefault)
+ if resolvedMode == "" {
+ resolvedMode = service.OpenAIWSIngressModeOff
+ }
+ return resolvedMode, true
+}
+
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
if r == nil {
return false
diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go
index a26b3a0c3..043ea3d60 100644
--- a/backend/internal/handler/openai_gateway_handler_test.go
+++ b/backend/internal/handler/openai_gateway_handler_test.go
@@ -431,6 +431,34 @@ func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
require.Contains(t, w.Body.String(), "previous_response_id must be a response.id")
}
+func TestOpenAIResponses_InvalidJSONBodyReturnsBadRequest(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ w := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(w)
+ c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader(
+ `{"model":"gpt-5.1","stream":false,invalid}`,
+ ))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ groupID := int64(2)
+ c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{
+ ID: 201,
+ GroupID: &groupID,
+ User: &service.User{ID: 1},
+ })
+ c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
+ UserID: 1,
+ Concurrency: 1,
+ })
+
+ h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil)
+ h.Responses(c)
+
+ require.Equal(t, http.StatusBadRequest, w.Code)
+ require.Contains(t, w.Body.String(), "Failed to parse request body")
+}
+
func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) {
gin.SetMode(gin.TestMode)
diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go
index 679dd4cef..731b36ab9 100644
--- a/backend/internal/handler/ops_error_logger_test.go
+++ b/backend/internal/handler/ops_error_logger_test.go
@@ -214,63 +214,3 @@ func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) {
})
require.Equal(t, http.StatusNoContent, rec.Code)
}
-
-func TestIsKnownOpsErrorType(t *testing.T) {
- known := []string{
- "invalid_request_error",
- "authentication_error",
- "rate_limit_error",
- "billing_error",
- "subscription_error",
- "upstream_error",
- "overloaded_error",
- "api_error",
- "not_found_error",
- "forbidden_error",
- }
- for _, k := range known {
- require.True(t, isKnownOpsErrorType(k), "expected known: %s", k)
- }
-
- unknown := []string{"", "null", "", "random_error", "some_new_type", "\u003e"}
- for _, u := range unknown {
- require.False(t, isKnownOpsErrorType(u), "expected unknown: %q", u)
- }
-}
-
-func TestNormalizeOpsErrorType(t *testing.T) {
- tests := []struct {
- name string
- errType string
- code string
- want string
- }{
- // Known types pass through.
- {"known invalid_request_error", "invalid_request_error", "", "invalid_request_error"},
- {"known rate_limit_error", "rate_limit_error", "", "rate_limit_error"},
- {"known upstream_error", "upstream_error", "", "upstream_error"},
-
- // Unknown/garbage types are rejected and fall through to code-based or default.
- {"nil literal from upstream", "", "", "api_error"},
- {"null string", "null", "", "api_error"},
- {"random string", "something_weird", "", "api_error"},
-
- // Unknown type but known code still maps correctly.
- {"nil with INSUFFICIENT_BALANCE code", "", "INSUFFICIENT_BALANCE", "billing_error"},
- {"nil with USAGE_LIMIT_EXCEEDED code", "", "USAGE_LIMIT_EXCEEDED", "subscription_error"},
-
- // Empty type falls through to code-based mapping.
- {"empty type with balance code", "", "INSUFFICIENT_BALANCE", "billing_error"},
- {"empty type with subscription code", "", "SUBSCRIPTION_NOT_FOUND", "subscription_error"},
- {"empty type no code", "", "", "api_error"},
-
- // Known type overrides conflicting code-based mapping.
- {"known type overrides conflicting code", "rate_limit_error", "INSUFFICIENT_BALANCE", "rate_limit_error"},
- }
- for _, tt := range tests {
- t.Run(tt.name, func(t *testing.T) {
- got := normalizeOpsErrorType(tt.errType, tt.code)
- require.Equal(t, tt.want, got)
- })
- }
-}
diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go
index 2141a9ee5..1344747f7 100644
--- a/backend/internal/handler/setting_handler.go
+++ b/backend/internal/handler/setting_handler.go
@@ -11,13 +11,15 @@ import (
// SettingHandler 公开设置处理器(无需认证)
type SettingHandler struct {
settingService *service.SettingService
+ accountRepo service.AccountRepository
version string
}
// NewSettingHandler 创建公开设置处理器
-func NewSettingHandler(settingService *service.SettingService, version string) *SettingHandler {
+func NewSettingHandler(settingService *service.SettingService, accountRepo service.AccountRepository, version string) *SettingHandler {
return &SettingHandler{
settingService: settingService,
+ accountRepo: accountRepo,
version: version,
}
}
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
index 5df7fa0a5..523b016c7 100644
--- a/backend/internal/handler/sora_client_handler_test.go
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -945,9 +945,6 @@ func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, i
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
// ==================== NewSoraClientHandler ====================
@@ -2050,13 +2047,19 @@ func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.
return nil, nil, nil
}
func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) {
- return nil, nil
+ return r.accounts, nil
}
func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) {
return nil, nil
}
-func (r *stubAccountRepoForHandler) ListByPlatform(context.Context, string) ([]service.Account, error) {
- return nil, nil
+func (r *stubAccountRepoForHandler) ListByPlatform(_ context.Context, platform string) ([]service.Account, error) {
+ filtered := make([]service.Account, 0, len(r.accounts))
+ for _, account := range r.accounts {
+ if account.Platform == platform {
+ filtered = append(filtered, account)
+ }
+ }
+ return filtered, nil
}
func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil }
func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error {
@@ -2184,7 +2187,7 @@ func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Ac
func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService {
return service.NewGatewayService(
accountRepo, nil, nil, nil, nil, nil, nil, nil,
- nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
+ nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil,
)
}
@@ -2218,7 +2221,6 @@ func TestProcessGeneration_SelectAccountError(t *testing.T) {
}
func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2238,7 +2240,6 @@ func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) {
}
func TestProcessGeneration_ForwardError(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2297,7 +2298,6 @@ func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) {
}
func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2359,7 +2359,6 @@ func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) {
}
func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
genService := service.NewSoraGenerationService(repo, nil, nil)
@@ -2390,7 +2389,6 @@ func TestProcessGeneration_FullSuccessUpstream(t *testing.T) {
}
func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
sourceServer := newFakeSourceServer()
defer sourceServer.Close()
fakeS3 := newFakeS3Server("ok")
@@ -2438,7 +2436,6 @@ func TestProcessGeneration_FullSuccessWithS3(t *testing.T) {
}
func TestProcessGeneration_MarkCompletedFails(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复")
repo := newStubSoraGenRepo()
repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
// 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败
@@ -2606,7 +2603,6 @@ func TestDeleteGeneration_DeleteError(t *testing.T) {
// ==================== fetchUpstreamModels 测试 ====================
func TestFetchUpstreamModels_NilGateway(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
h := &SoraClientHandler{}
_, err := h.fetchUpstreamModels(context.Background())
require.Error(t, err)
@@ -2614,7 +2610,6 @@ func TestFetchUpstreamModels_NilGateway(t *testing.T) {
}
func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{accounts: nil}
gatewayService := newMinimalGatewayService(accountRepo)
h := &SoraClientHandler{gatewayService: gatewayService}
@@ -2624,7 +2619,6 @@ func TestFetchUpstreamModels_NoAccounts(t *testing.T) {
}
func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true},
@@ -2638,7 +2632,6 @@ func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) {
}
func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
accountRepo := &stubAccountRepoForHandler{
accounts: []service.Account{
{ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true,
@@ -2653,7 +2646,6 @@ func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) {
}
func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
// GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com"
// 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败
accountRepo := &stubAccountRepoForHandler{
@@ -2669,7 +2661,6 @@ func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) {
}
func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
}))
@@ -2689,7 +2680,6 @@ func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) {
}
func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("not json"))
@@ -2710,7 +2700,6 @@ func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) {
}
func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[]}`))
@@ -2731,7 +2720,6 @@ func TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) {
}
func TestFetchUpstreamModels_Success(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 验证请求头
require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization"))
@@ -2755,7 +2743,6 @@ func TestFetchUpstreamModels_Success(t *testing.T) {
}
func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) {
- t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`))
@@ -2790,7 +2777,6 @@ func TestGetModelFamilies_CachesLocalConfig(t *testing.T) {
}
func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) {
- t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`))
diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go
index 48c1e451b..a0045aa53 100644
--- a/backend/internal/handler/sora_gateway_handler.go
+++ b/backend/internal/handler/sora_gateway_handler.go
@@ -41,6 +41,7 @@ type SoraGatewayHandler struct {
soraTLSEnabled bool
soraMediaSigningKey string
soraMediaRoot string
+ cfg *config.Config
}
// NewSoraGatewayHandler creates a new SoraGatewayHandler
@@ -83,6 +84,7 @@ func NewSoraGatewayHandler(
soraTLSEnabled: soraTLSEnabled,
soraMediaSigningKey: signKey,
soraMediaRoot: mediaRoot,
+ cfg: cfg,
}
}
@@ -451,25 +453,12 @@ func generateOpenAISessionHash(c *gin.Context, body []byte) string {
}
func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
- if task == nil {
- return
- }
- if h.usageRecordWorkerPool != nil {
- h.usageRecordWorkerPool.Submit(task)
- return
- }
- // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
- ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
- defer cancel()
- defer func() {
- if recovered := recover(); recovered != nil {
- logger.L().With(
- zap.String("component", "handler.sora_gateway.chat_completions"),
- zap.Any("panic", recovered),
- ).Error("sora.usage_record_task_panic_recovered")
- }
- }()
- task(ctx)
+ submitUsageRecordTaskWithFallback(
+ "handler.sora_gateway.chat_completions",
+ h.usageRecordWorkerPool,
+ h.cfg,
+ task,
+ )
}
func (h *SoraGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go
index 355cdb7ac..68a040847 100644
--- a/backend/internal/handler/sora_gateway_handler_test.go
+++ b/backend/internal/handler/sora_gateway_handler_test.go
@@ -320,9 +320,6 @@ func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTi
func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
return nil, nil
}
-func (s *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
- return nil, nil
-}
func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, nil
}
diff --git a/backend/internal/handler/usage_record_submit_helper.go b/backend/internal/handler/usage_record_submit_helper.go
new file mode 100644
index 000000000..3110c001a
--- /dev/null
+++ b/backend/internal/handler/usage_record_submit_helper.go
@@ -0,0 +1,61 @@
+package handler
+
+import (
+ "context"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "go.uber.org/zap"
+)
+
+func submitUsageRecordTaskWithFallback(
+ component string,
+ pool *service.UsageRecordWorkerPool,
+ cfg *config.Config,
+ task service.UsageRecordTask,
+) {
+ if task == nil {
+ return
+ }
+ if pool != nil {
+ mode := pool.Submit(task)
+ if mode != service.UsageRecordSubmitModeDropped {
+ return
+ }
+ // 队列溢出导致 submit 丢弃时,同步兜底执行,避免 usage 漏记费。
+ logger.L().With(
+ zap.String("component", component),
+ zap.String("submit_mode", mode.String()),
+ ).Warn("usage_record.task_submit_dropped_sync_fallback")
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), usageRecordSyncFallbackTimeout(cfg))
+ defer cancel()
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.L().With(
+ zap.String("component", component),
+ zap.Any("panic", recovered),
+ ).Error("usage_record.task_panic_recovered")
+ }
+ }()
+ task(ctx)
+}
+
+func usageRecordSyncFallbackTimeout(cfg *config.Config) time.Duration {
+ timeout := 10 * time.Second
+ if cfg != nil && cfg.Gateway.UsageRecord.TaskTimeoutSeconds > 0 {
+ timeout = time.Duration(cfg.Gateway.UsageRecord.TaskTimeoutSeconds) * time.Second
+ }
+ // keep a sane bound on synchronous fallback to limit request-path blocking.
+ if timeout < time.Second {
+ return time.Second
+ }
+ if timeout > 10*time.Second {
+ return 10 * time.Second
+ }
+ return timeout
+}
+
diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go
index c7c48e14b..20e8e87c3 100644
--- a/backend/internal/handler/usage_record_submit_task_test.go
+++ b/backend/internal/handler/usage_record_submit_task_test.go
@@ -6,6 +6,7 @@ import (
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
@@ -54,6 +55,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.
require.True(t, called.Load())
}
+func TestGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) {
+ pool := newUsageRecordTestPool(t)
+ pool.Stop()
+ h := &GatewayHandler{usageRecordWorkerPool: pool}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if _, ok := ctx.Deadline(); !ok {
+ t.Fatal("expected deadline in dropped sync fallback context")
+ }
+ called.Store(true)
+ })
+
+ require.True(t, called.Load(), "dropped task should run via sync fallback")
+}
+
func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &GatewayHandler{}
require.NotPanics(t, func() {
@@ -93,6 +110,40 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) {
}
}
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) {
+ pool := newUsageRecordTestPool(t)
+ pool.Stop()
+ h := &OpenAIGatewayHandler{usageRecordWorkerPool: pool}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if _, ok := ctx.Deadline(); !ok {
+ t.Fatal("expected deadline in dropped sync fallback context")
+ }
+ called.Store(true)
+ })
+
+ require.True(t, called.Load(), "dropped task should run via sync fallback")
+}
+
+func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithConfigFallbackTimeout(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.UsageRecord.TaskTimeoutSeconds = 2
+ h := &OpenAIGatewayHandler{cfg: cfg}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ deadline, ok := ctx.Deadline()
+ require.True(t, ok, "expected deadline in fallback context")
+ remaining := time.Until(deadline)
+ require.LessOrEqual(t, remaining, 2200*time.Millisecond)
+ require.GreaterOrEqual(t, remaining, 1200*time.Millisecond)
+ called.Store(true)
+ })
+
+ require.True(t, called.Load())
+}
+
func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *testing.T) {
h := &OpenAIGatewayHandler{}
var called atomic.Bool
@@ -160,6 +211,22 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPoolSyncFallback(t *test
require.True(t, called.Load())
}
+func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPoolDroppedSyncFallback(t *testing.T) {
+ pool := newUsageRecordTestPool(t)
+ pool.Stop()
+ h := &SoraGatewayHandler{usageRecordWorkerPool: pool}
+ var called atomic.Bool
+
+ h.submitUsageRecordTask(func(ctx context.Context) {
+ if _, ok := ctx.Deadline(); !ok {
+ t.Fatal("expected deadline in dropped sync fallback context")
+ }
+ called.Store(true)
+ })
+
+ require.True(t, called.Load(), "dropped task should run via sync fallback")
+}
+
func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) {
h := &SoraGatewayHandler{}
require.NotPanics(t, func() {
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index 76f5a9796..572260571 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -62,8 +62,8 @@ func ProvideSystemHandler(updateService *service.UpdateService, lockService *ser
}
// ProvideSettingHandler creates SettingHandler with version from BuildInfo
-func ProvideSettingHandler(settingService *service.SettingService, buildInfo BuildInfo) *SettingHandler {
- return NewSettingHandler(settingService, buildInfo.Version)
+func ProvideSettingHandler(settingService *service.SettingService, accountRepo service.AccountRepository, buildInfo BuildInfo) *SettingHandler {
+ return NewSettingHandler(settingService, accountRepo, buildInfo.Version)
}
// ProvideHandlers creates the Handlers struct
diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go
index 0519c2cc1..f09bee8dd 100644
--- a/backend/internal/pkg/response/response.go
+++ b/backend/internal/pkg/response/response.go
@@ -2,6 +2,8 @@
package response
import (
+ "context"
+ "errors"
"log"
"math"
"net/http"
@@ -75,17 +77,45 @@ func ErrorFrom(c *gin.Context, err error) bool {
return false
}
- statusCode, status := infraerrors.ToHTTP(err)
+ normalizedErr := normalizeHTTPError(c, err)
+ statusCode, status := infraerrors.ToHTTP(normalizedErr)
// Log internal errors with full details for debugging
if statusCode >= 500 && c.Request != nil {
- log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
+ log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(normalizedErr.Error()))
}
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
return true
}
+func normalizeHTTPError(c *gin.Context, err error) error {
+ if err == nil {
+ return nil
+ }
+ if c == nil || c.Request == nil {
+ return err
+ }
+ if isClientCanceledError(c.Request.Context(), err) {
+ return infraerrors.ClientClosed("CLIENT_CLOSED", "client closed request").WithCause(err)
+ }
+ return err
+}
+
+func isClientCanceledError(reqCtx context.Context, err error) bool {
+ if reqCtx == nil {
+ return false
+ }
+ // 只有请求上下文本身被取消时,才认为是客户端断开;
+ // 避免将服务端主动 cancel 导致的 context.Canceled 误归为 499。
+ if errors.Is(err, context.Canceled) && errors.Is(reqCtx.Err(), context.Canceled) {
+ return true
+ }
+
+ // Some drivers can surface deadline errors after the request context was already canceled.
+ return errors.Is(err, context.DeadlineExceeded) && errors.Is(reqCtx.Err(), context.Canceled)
+}
+
// BadRequest 返回400错误
func BadRequest(c *gin.Context, message string) {
Error(c, http.StatusBadRequest, message)
diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go
index 0debce5fd..64918ca8b 100644
--- a/backend/internal/pkg/response/response_test.go
+++ b/backend/internal/pkg/response/response_test.go
@@ -3,8 +3,10 @@
package response
import (
+ "context"
"encoding/json"
"errors"
+ "fmt"
"net/http"
"net/http/httptest"
"testing"
@@ -107,11 +109,12 @@ func TestErrorFrom(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
- name string
- err error
- wantWritten bool
- wantHTTPCode int
- wantBody Response
+ name string
+ err error
+ cancelRequestContext bool
+ wantWritten bool
+ wantHTTPCode int
+ wantBody Response
}{
{
name: "nil_error",
@@ -184,12 +187,75 @@ func TestErrorFrom(t *testing.T) {
Message: errors2.UnknownMessage,
},
},
+ {
+ name: "context_canceled_without_request_cancel_remains_500",
+ err: context.Canceled,
+ wantWritten: true,
+ wantHTTPCode: http.StatusInternalServerError,
+ wantBody: Response{
+ Code: http.StatusInternalServerError,
+ Message: errors2.UnknownMessage,
+ },
+ },
+ {
+ name: "context_canceled_maps_to_499",
+ err: context.Canceled,
+ cancelRequestContext: true,
+ wantWritten: true,
+ wantHTTPCode: 499,
+ wantBody: Response{
+ Code: 499,
+ Message: "client closed request",
+ Reason: "CLIENT_CLOSED",
+ },
+ },
+ {
+ name: "wrapped_context_canceled_maps_to_499",
+ err: fmt.Errorf("query aborted: %w", context.Canceled),
+ cancelRequestContext: true,
+ wantWritten: true,
+ wantHTTPCode: 499,
+ wantBody: Response{
+ Code: 499,
+ Message: "client closed request",
+ Reason: "CLIENT_CLOSED",
+ },
+ },
+ {
+ name: "deadline_exceeded_without_request_cancel_remains_500",
+ err: context.DeadlineExceeded,
+ wantWritten: true,
+ wantHTTPCode: http.StatusInternalServerError,
+ wantBody: Response{
+ Code: http.StatusInternalServerError,
+ Message: errors2.UnknownMessage,
+ },
+ },
+ {
+ name: "deadline_exceeded_with_request_canceled_maps_to_499",
+ err: context.DeadlineExceeded,
+ cancelRequestContext: true,
+ wantWritten: true,
+ wantHTTPCode: 499,
+ wantBody: Response{
+ Code: 499,
+ Message: "client closed request",
+ Reason: "CLIENT_CLOSED",
+ },
+ },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
+ req := httptest.NewRequest(http.MethodGet, "/test", nil)
+ if tt.cancelRequestContext {
+ ctx, cancel := context.WithCancel(req.Context())
+ cancel()
+ req = req.WithContext(ctx)
+ }
+ c.Request = req
written := ErrorFrom(c, tt.err)
require.Equal(t, tt.wantWritten, written)
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go
index 4aa749284..13ff57769 100644
--- a/backend/internal/repository/account_repo.go
+++ b/backend/internal/repository/account_repo.go
@@ -1170,6 +1170,11 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates
args = append(args, *updates.Schedulable)
idx++
}
+ if updates.AutoPauseOnExpired != nil {
+ setClauses = append(setClauses, "auto_pause_on_expired = $"+itoa(idx))
+ args = append(args, *updates.AutoPauseOnExpired)
+ idx++
+ }
// JSONB 需要合并而非覆盖,使用 raw SQL 保持旧行为。
if len(updates.Credentials) > 0 {
payload, err := json.Marshal(updates.Credentials)
diff --git a/backend/internal/repository/billing_cache.go b/backend/internal/repository/billing_cache.go
index e753e1b86..baaaad502 100644
--- a/backend/internal/repository/billing_cache.go
+++ b/backend/internal/repository/billing_cache.go
@@ -53,9 +53,20 @@ var (
deductBalanceScript = redis.NewScript(`
local current = redis.call('GET', KEYS[1])
if current == false then
+ return 2
+ end
+ local cur = tonumber(current)
+ local delta = tonumber(ARGV[1])
+ if cur == nil or delta == nil then
+ return -1
+ end
+ if delta < 0 then
+ return -2
+ end
+ if cur < delta then
return 0
end
- local newVal = tonumber(current) - tonumber(ARGV[1])
+ local newVal = cur - delta
redis.call('SET', KEYS[1], newVal)
redis.call('EXPIRE', KEYS[1], ARGV[2])
return 1
@@ -99,12 +110,26 @@ func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
key := billingBalanceKey(userID)
- _, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
+ result, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Int64()
if err != nil && !errors.Is(err, redis.Nil) {
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
return err
}
- return nil
+ switch result {
+ case 1:
+ return nil
+ case 2:
+ // 缓存 key 不存在(已过期),返回特定错误让调用方区分处理
+ return service.ErrBalanceCacheNotFound
+ case 0:
+ return service.ErrInsufficientBalance
+ case -1:
+ return fmt.Errorf("invalid cached balance for user %d", userID)
+ case -2:
+ return fmt.Errorf("invalid deduct amount for user %d", userID)
+ default:
+ return fmt.Errorf("unexpected deduct balance cache result for user %d: %d", userID, result)
+ }
}
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go
index 4b7377b12..6a5983af7 100644
--- a/backend/internal/repository/billing_cache_integration_test.go
+++ b/backend/internal/repository/billing_cache_integration_test.go
@@ -278,8 +278,8 @@ func (s *BillingCacheSuite) TestSubscriptionCache() {
}
}
-// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
-// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
+// TestDeductUserBalance_ErrorPropagation 验证修复:
+// Redis 真实错误应传播,key 不存在应返回 ErrBalanceCacheNotFound。
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
tests := []struct {
name string
@@ -287,11 +287,11 @@ func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
expectErr bool
}{
{
- name: "key_not_exists_returns_nil",
+ name: "key_not_exists_returns_ErrBalanceCacheNotFound",
fn: func(ctx context.Context, cache service.BillingCache) {
- // key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
+ // key 不存在时,Lua 脚本返回 2,应返回 ErrBalanceCacheNotFound
err := cache.DeductUserBalance(ctx, 99999, 1.0)
- require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
+ require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound")
},
},
{
diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go
index 58291b665..0f193c7dc 100644
--- a/backend/internal/repository/gateway_cache.go
+++ b/backend/internal/repository/gateway_cache.go
@@ -2,14 +2,20 @@ package repository
import (
"context"
+ "encoding/json"
"fmt"
+ "strconv"
+ "strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/cespare/xxhash/v2"
"github.com/redis/go-redis/v9"
)
const stickySessionPrefix = "sticky_session:"
+const openAIWSSessionLastResponsePrefix = "openai_ws_session_last_response:"
+const openAIWSResponsePendingToolCallsPrefix = "openai_ws_response_pending_tool_calls:"
type gatewayCache struct {
rdb *redis.Client
@@ -25,6 +31,20 @@ func buildSessionKey(groupID int64, sessionHash string) string {
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
}
+func buildOpenAIWSSessionLastResponseKey(groupID int64, sessionHash string) string {
+ return fmt.Sprintf("%s%d:%s", openAIWSSessionLastResponsePrefix, groupID, sessionHash)
+}
+
+func buildOpenAIWSResponsePendingToolCallsKey(groupID int64, responseID string) string {
+ id := strings.TrimSpace(responseID)
+ if id == "" {
+ return ""
+ }
+ return openAIWSResponsePendingToolCallsPrefix +
+ strconv.FormatInt(groupID, 10) + ":" +
+ strconv.FormatUint(xxhash.Sum64String(id), 16)
+}
+
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Get(ctx, key).Int64()
@@ -51,3 +71,78 @@ func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64
key := buildSessionKey(groupID, sessionHash)
return c.rdb.Del(ctx, key).Err()
}
+
+func (c *gatewayCache) SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error {
+ key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash)
+ return c.rdb.Set(ctx, key, responseID, ttl).Err()
+}
+
+func (c *gatewayCache) GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error) {
+ key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash)
+ return c.rdb.Get(ctx, key).Result()
+}
+
+func (c *gatewayCache) DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error {
+ key := buildOpenAIWSSessionLastResponseKey(groupID, sessionHash)
+ return c.rdb.Del(ctx, key).Err()
+}
+
+func (c *gatewayCache) SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error {
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ if key == "" {
+ return nil
+ }
+ normalizedCallIDs := normalizeOpenAIWSResponsePendingToolCallIDs(callIDs)
+ if len(normalizedCallIDs) == 0 {
+ return c.rdb.Del(ctx, key).Err()
+ }
+ raw, err := json.Marshal(normalizedCallIDs)
+ if err != nil {
+ return err
+ }
+ return c.rdb.Set(ctx, key, raw, ttl).Err()
+}
+
+func (c *gatewayCache) GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error) {
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ if key == "" {
+ return nil, nil
+ }
+ raw, err := c.rdb.Get(ctx, key).Bytes()
+ if err != nil {
+ return nil, err
+ }
+ var callIDs []string
+ if err := json.Unmarshal(raw, &callIDs); err != nil {
+ return nil, err
+ }
+ return normalizeOpenAIWSResponsePendingToolCallIDs(callIDs), nil
+}
+
+func (c *gatewayCache) DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error {
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ if key == "" {
+ return nil
+ }
+ return c.rdb.Del(ctx, key).Err()
+}
+
+func normalizeOpenAIWSResponsePendingToolCallIDs(callIDs []string) []string {
+ if len(callIDs) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(callIDs))
+ normalized := make([]string, 0, len(callIDs))
+ for _, callID := range callIDs {
+ id := strings.TrimSpace(callID)
+ if id == "" {
+ continue
+ }
+ if _, ok := seen[id]; ok {
+ continue
+ }
+ seen[id] = struct{}{}
+ normalized = append(normalized, id)
+ }
+ return normalized
+}
diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go
index 0eebc33f6..093b45ca4 100644
--- a/backend/internal/repository/gateway_cache_integration_test.go
+++ b/backend/internal/repository/gateway_cache_integration_test.go
@@ -3,6 +3,7 @@
package repository
import (
+ "context"
"errors"
"testing"
"time"
@@ -104,6 +105,49 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
+func (s *GatewayCacheSuite) TestSetAndGetOpenAIWSResponsePendingToolCalls() {
+ type responsePendingToolCallsCache interface {
+ SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error
+ GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error)
+ }
+ cache, ok := s.cache.(responsePendingToolCallsCache)
+ require.True(s.T(), ok, "gateway cache should implement pending tool calls cache")
+
+ responseID := "resp_pending_integration_1"
+ groupID := int64(1)
+ ttl := 2 * time.Minute
+ require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_1", "call_2", "call_1", " "}, ttl))
+
+ callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID)
+ require.NoError(s.T(), err)
+ require.ElementsMatch(s.T(), []string{"call_1", "call_2"}, callIDs)
+ _, err = cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID+1, responseID)
+ require.True(s.T(), errors.Is(err, redis.Nil), "pending tool calls should be isolated by group")
+
+ key := buildOpenAIWSResponsePendingToolCallsKey(groupID, responseID)
+ remainingTTL, ttlErr := s.rdb.TTL(s.ctx, key).Result()
+ require.NoError(s.T(), ttlErr)
+ s.AssertTTLWithin(remainingTTL, 1*time.Second, ttl)
+}
+
+func (s *GatewayCacheSuite) TestDeleteOpenAIWSResponsePendingToolCalls() {
+ type responsePendingToolCallsCache interface {
+ SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error
+ GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error)
+ DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error
+ }
+ cache, ok := s.cache.(responsePendingToolCallsCache)
+ require.True(s.T(), ok, "gateway cache should implement pending tool calls cache")
+
+ responseID := "resp_pending_integration_2"
+ groupID := int64(1)
+ require.NoError(s.T(), cache.SetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID, []string{"call_3"}, time.Minute))
+ require.NoError(s.T(), cache.DeleteOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID))
+
+ _, err := cache.GetOpenAIWSResponsePendingToolCalls(s.ctx, groupID, responseID)
+ require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
+}
+
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}
diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go
index 4edc85340..e9b4902ac 100644
--- a/backend/internal/repository/group_repo.go
+++ b/backend/internal/repository/group_repo.go
@@ -127,38 +127,6 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
SetMcpXMLInject(groupIn.MCPXMLInject).
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes)
- // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
- if groupIn.DailyLimitUSD != nil {
- builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
- } else {
- builder = builder.ClearDailyLimitUsd()
- }
- if groupIn.WeeklyLimitUSD != nil {
- builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD)
- } else {
- builder = builder.ClearWeeklyLimitUsd()
- }
- if groupIn.MonthlyLimitUSD != nil {
- builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD)
- } else {
- builder = builder.ClearMonthlyLimitUsd()
- }
- if groupIn.ImagePrice1K != nil {
- builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K)
- } else {
- builder = builder.ClearImagePrice1k()
- }
- if groupIn.ImagePrice2K != nil {
- builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K)
- } else {
- builder = builder.ClearImagePrice2k()
- }
- if groupIn.ImagePrice4K != nil {
- builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K)
- } else {
- builder = builder.ClearImagePrice4k()
- }
-
// 处理 FallbackGroupID:nil 时清除,否则设置
if groupIn.FallbackGroupID != nil {
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go
index a4674c1a3..7ac424466 100644
--- a/backend/internal/repository/http_upstream.go
+++ b/backend/internal/repository/http_upstream.go
@@ -1,6 +1,7 @@
package repository
import (
+ "crypto/tls"
"errors"
"fmt"
"io"
@@ -45,6 +46,16 @@ const (
defaultMaxUpstreamClients = 5000
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
defaultClientIdleTTLSeconds = 900
+ // OpenAI HTTP/2 代理回退策略默认值
+ defaultOpenAIHTTP2FallbackErrorThreshold = 2
+ defaultOpenAIHTTP2FallbackWindow = 60 * time.Second
+ defaultOpenAIHTTP2FallbackTTL = 10 * time.Minute
+)
+
+const (
+ upstreamProtocolModeDefault = "default"
+ upstreamProtocolModeOpenAIH2 = "openai_h2"
+ upstreamProtocolModeOpenAIH1Fallback = "openai_h1_fallback"
)
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
@@ -59,14 +70,30 @@ type poolSettings struct {
responseHeaderTimeout time.Duration // 等待响应头超时时间
}
+type openAIHTTP2Settings struct {
+ enabled bool
+ allowProxyFallbackToHTTP1 bool
+ fallbackErrorThreshold int
+ fallbackWindow time.Duration
+ fallbackTTL time.Duration
+}
+
// upstreamClientEntry 上游客户端缓存条目
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
type upstreamClientEntry struct {
- client *http.Client // HTTP 客户端实例
- proxyKey string // 代理标识(用于检测代理变更)
- poolKey string // 连接池配置标识(用于检测配置变更)
- lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
- inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
+ client *http.Client // HTTP 客户端实例
+ proxyKey string // 代理标识(用于检测代理变更)
+ poolKey string // 连接池配置标识(用于检测配置变更)
+ protocolMode string // 协议模式(default/openai_h2/openai_h1_fallback)
+ lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
+ inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
+}
+
+type openAIHTTP2FallbackState struct {
+ mu sync.Mutex
+ windowStart time.Time
+ errorCount int
+ fallbackUntil time.Time
}
// httpUpstreamService 通用 HTTP 上游服务
@@ -90,6 +117,8 @@ type httpUpstreamService struct {
cfg *config.Config // 全局配置
mu sync.RWMutex // 保护 clients map 的读写锁
clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定
+ // OpenAI 走 HTTP 代理时的 H2->H1 回退状态(key=标准化 proxyKey)
+ openAIHTTP2Fallbacks sync.Map
}
// NewHTTPUpstream 创建通用 HTTP 上游服务
@@ -127,9 +156,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
if err := s.validateRequestHost(req); err != nil {
return nil, err
}
+ profile := service.HTTPUpstreamProfileDefault
+ if req != nil {
+ profile = service.HTTPUpstreamProfileFromContext(req.Context())
+ }
// 获取或创建对应的客户端,并标记请求占用
- entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
+ entry, err := s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, profile)
if err != nil {
return nil, err
}
@@ -137,11 +170,13 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i
// 执行请求
resp, err := entry.client.Do(req)
if err != nil {
+ s.recordOpenAIHTTP2Failure(profile, entry.protocolMode, entry.proxyKey, err)
// 请求失败,立即减少计数
atomic.AddInt64(&entry.inFlight, -1)
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
return nil, err
}
+ s.recordOpenAIHTTP2Success(profile, entry.protocolMode, entry.proxyKey)
// 包装响应体,在关闭时自动减少计数并更新时间戳
// 这确保了流式响应(如 SSE)在完全读取前不会被淘汰
@@ -241,8 +276,8 @@ func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID i
return nil, err
}
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
- cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
- poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
+ cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID, upstreamProtocolModeDefault)
+ poolKey := s.buildPoolKey(isolation, accountConcurrency, upstreamProtocolModeDefault) + ":tls"
now := time.Now()
nowUnix := now.UnixNano()
@@ -359,7 +394,12 @@ func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Req
// acquireClient 获取或创建客户端,并标记为进行中请求
// 用于请求路径,避免在获取后被淘汰
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
- return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
+ return s.acquireClientWithProfile(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault)
+}
+
+// acquireClientWithProfile 获取或创建客户端,并按请求 profile 选择协议策略。
+func (s *httpUpstreamService) acquireClientWithProfile(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile) (*upstreamClientEntry, error) {
+ return s.getClientEntry(proxyURL, accountID, accountConcurrency, profile, true, true)
}
// getOrCreateClient 获取或创建客户端
@@ -377,25 +417,25 @@ func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, ac
// - proxy: 按代理地址隔离,同一代理共享客户端
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
// - account_proxy: 按账户+代理组合隔离,最细粒度
-func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
- return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
+func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) *upstreamClientEntry {
+ entry, _ := s.getClientEntry(proxyURL, accountID, accountConcurrency, service.HTTPUpstreamProfileDefault, false, false)
+ return entry
}
// getClientEntry 获取或创建客户端条目
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
-func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
+func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, profile service.HTTPUpstreamProfile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
// 获取隔离模式
isolation := s.getIsolationMode()
// 标准化代理 URL 并解析
- proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
- if err != nil {
- return nil, err
- }
+ proxyKey, parsedProxy := normalizeProxyURL(proxyURL)
+ // 根据请求 profile(例如 OpenAI)选择协议模式
+ protocolMode := s.resolveProtocolMode(profile, proxyKey, parsedProxy)
// 构建缓存键(根据隔离策略不同)
- cacheKey := buildCacheKey(isolation, proxyKey, accountID)
+ cacheKey := buildCacheKey(isolation, proxyKey, accountID, protocolMode)
// 构建连接池配置键(用于检测配置变更)
- poolKey := s.buildPoolKey(isolation, accountConcurrency)
+ poolKey := s.buildPoolKey(isolation, accountConcurrency, protocolMode)
now := time.Now()
nowUnix := now.UnixNano()
@@ -439,7 +479,7 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
// 缓存未命中或需要重建,创建新客户端
settings := s.resolvePoolSettings(isolation, accountConcurrency)
- transport, err := buildUpstreamTransport(settings, parsedProxy)
+ transport, err := buildUpstreamTransport(settings, parsedProxy, protocolMode)
if err != nil {
s.mu.Unlock()
return nil, fmt.Errorf("build transport: %w", err)
@@ -449,9 +489,10 @@ func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, a
client.CheckRedirect = s.redirectChecker
}
entry := &upstreamClientEntry{
- client: client,
- proxyKey: proxyKey,
- poolKey: poolKey,
+ client: client,
+ proxyKey: proxyKey,
+ poolKey: poolKey,
+ protocolMode: protocolMode,
}
atomic.StoreInt64(&entry.lastUsed, nowUnix)
if markInFlight {
@@ -644,13 +685,17 @@ func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcu
//
// 返回:
// - string: 配置键
-func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
+func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int, protocolMode string) string {
+ base := "default"
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
if accountConcurrency > 0 {
- return fmt.Sprintf("account:%d", accountConcurrency)
+ base = fmt.Sprintf("account:%d", accountConcurrency)
}
}
- return "default"
+ if protocolMode == "" || protocolMode == upstreamProtocolModeDefault {
+ return base
+ }
+ return base + "|proto:" + protocolMode
}
// buildCacheKey 构建客户端缓存键
@@ -668,15 +713,20 @@ func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency
// - proxy 模式: "proxy:{proxyKey}"
// - account 模式: "account:{accountID}"
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
-func buildCacheKey(isolation, proxyKey string, accountID int64) string {
+func buildCacheKey(isolation, proxyKey string, accountID int64, protocolMode string) string {
+ var base string
switch isolation {
case config.ConnectionPoolIsolationAccount:
- return fmt.Sprintf("account:%d", accountID)
+ base = fmt.Sprintf("account:%d", accountID)
case config.ConnectionPoolIsolationAccountProxy:
- return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
+ base = fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
default:
- return fmt.Sprintf("proxy:%s", proxyKey)
+ base = fmt.Sprintf("proxy:%s", proxyKey)
+ }
+ if protocolMode != "" && protocolMode != upstreamProtocolModeDefault {
+ base += "|proto:" + protocolMode
}
+ return base
}
// normalizeProxyURL 标准化代理 URL
@@ -720,6 +770,392 @@ func normalizeProxyURL(raw string) (string, *url.URL, error) {
return parsed.String(), parsed, nil
}
+func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings {
+ settings := openAIHTTP2Settings{
+ enabled: true,
+ allowProxyFallbackToHTTP1: true,
+ fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold,
+ fallbackWindow: defaultOpenAIHTTP2FallbackWindow,
+ fallbackTTL: defaultOpenAIHTTP2FallbackTTL,
+ }
+ if s == nil || s.cfg == nil {
+ return settings
+ }
+ cfg := s.cfg.Gateway.OpenAIHTTP2
+ settings.enabled = cfg.Enabled
+ settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1
+ if cfg.FallbackErrorThreshold > 0 {
+ settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold
+ }
+ if cfg.FallbackWindowSeconds > 0 {
+ settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second
+ }
+ if cfg.FallbackTTLSeconds > 0 {
+ settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second
+ }
+ return settings
+}
+
+func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string {
+ if profile != service.HTTPUpstreamProfileOpenAI {
+ return upstreamProtocolModeDefault
+ }
+ settings := s.resolveOpenAIHTTP2Settings()
+ if !settings.enabled {
+ return upstreamProtocolModeDefault
+ }
+ if parsedProxy == nil {
+ return upstreamProtocolModeOpenAIH2
+ }
+ scheme := strings.ToLower(parsedProxy.Scheme)
+ if scheme != "http" && scheme != "https" {
+ return upstreamProtocolModeOpenAIH2
+ }
+ if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) {
+ return upstreamProtocolModeOpenAIH1Fallback
+ }
+ return upstreamProtocolModeOpenAIH2
+}
+
+func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool {
+ raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
+ if !ok {
+ return false
+ }
+ state, ok := raw.(*openAIHTTP2FallbackState)
+ if !ok || state == nil {
+ return false
+ }
+ return state.isFallbackActive(time.Now())
+}
+
+func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState {
+ state := &openAIHTTP2FallbackState{}
+ actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state)
+ cached, ok := actual.(*openAIHTTP2FallbackState)
+ if !ok || cached == nil {
+ return state
+ }
+ return cached
+}
+
+func isHTTPProxyKey(proxyKey string) bool {
+ return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://")
+}
+
+func isOpenAIHTTP2CompatibilityError(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ if msg == "" {
+ return false
+ }
+ markers := []string{
+ "http2",
+ "alpn",
+ "no application protocol",
+ "protocol error",
+ "stream error",
+ "goaway",
+ "refused_stream",
+ "frame too large",
+ }
+ for _, marker := range markers {
+ if strings.Contains(msg, marker) {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) {
+ if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
+ return
+ }
+ settings := s.resolveOpenAIHTTP2Settings()
+ if !settings.enabled || !settings.allowProxyFallbackToHTTP1 {
+ return
+ }
+ if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) {
+ return
+ }
+ state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
+ activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL)
+ if activated {
+ slog.Warn("openai_http2_proxy_fallback_activated",
+ "proxy", proxyKey,
+ "fallback_until", until.Format(time.RFC3339))
+ }
+}
+
+func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) {
+ if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
+ return
+ }
+ if !isHTTPProxyKey(proxyKey) {
+ return
+ }
+ raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
+ if !ok {
+ return
+ }
+ state, ok := raw.(*openAIHTTP2FallbackState)
+ if !ok || state == nil {
+ return
+ }
+ state.resetErrorWindow()
+}
+
+func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.fallbackUntil.IsZero() {
+ return false
+ }
+ if now.Before(s.fallbackUntil) {
+ return true
+ }
+ s.fallbackUntil = time.Time{}
+ return false
+}
+
+func (s *openAIHTTP2FallbackState) resetErrorWindow() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.windowStart = time.Time{}
+ s.errorCount = 0
+}
+
+func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) {
+ if threshold <= 0 {
+ threshold = defaultOpenAIHTTP2FallbackErrorThreshold
+ }
+ if window <= 0 {
+ window = defaultOpenAIHTTP2FallbackWindow
+ }
+ if ttl <= 0 {
+ ttl = defaultOpenAIHTTP2FallbackTTL
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) {
+ return false, s.fallbackUntil
+ }
+ if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) {
+ s.fallbackUntil = time.Time{}
+ }
+
+ if s.windowStart.IsZero() || now.Sub(s.windowStart) > window {
+ s.windowStart = now
+ s.errorCount = 0
+ }
+ s.errorCount++
+ if s.errorCount < threshold {
+ return false, time.Time{}
+ }
+
+ s.fallbackUntil = now.Add(ttl)
+ s.windowStart = time.Time{}
+ s.errorCount = 0
+ return true, s.fallbackUntil
+}
+
+func (s *httpUpstreamService) resolveOpenAIHTTP2Settings() openAIHTTP2Settings {
+ settings := openAIHTTP2Settings{
+ enabled: true,
+ allowProxyFallbackToHTTP1: true,
+ fallbackErrorThreshold: defaultOpenAIHTTP2FallbackErrorThreshold,
+ fallbackWindow: defaultOpenAIHTTP2FallbackWindow,
+ fallbackTTL: defaultOpenAIHTTP2FallbackTTL,
+ }
+ if s == nil || s.cfg == nil {
+ return settings
+ }
+ cfg := s.cfg.Gateway.OpenAIHTTP2
+ settings.enabled = cfg.Enabled
+ settings.allowProxyFallbackToHTTP1 = cfg.AllowProxyFallbackToHTTP1
+ if cfg.FallbackErrorThreshold > 0 {
+ settings.fallbackErrorThreshold = cfg.FallbackErrorThreshold
+ }
+ if cfg.FallbackWindowSeconds > 0 {
+ settings.fallbackWindow = time.Duration(cfg.FallbackWindowSeconds) * time.Second
+ }
+ if cfg.FallbackTTLSeconds > 0 {
+ settings.fallbackTTL = time.Duration(cfg.FallbackTTLSeconds) * time.Second
+ }
+ return settings
+}
+
+func (s *httpUpstreamService) resolveProtocolMode(profile service.HTTPUpstreamProfile, proxyKey string, parsedProxy *url.URL) string {
+ if profile != service.HTTPUpstreamProfileOpenAI {
+ return upstreamProtocolModeDefault
+ }
+ settings := s.resolveOpenAIHTTP2Settings()
+ if !settings.enabled {
+ return upstreamProtocolModeDefault
+ }
+ if parsedProxy == nil {
+ return upstreamProtocolModeOpenAIH2
+ }
+ scheme := strings.ToLower(parsedProxy.Scheme)
+ if scheme != "http" && scheme != "https" {
+ return upstreamProtocolModeOpenAIH2
+ }
+ if settings.allowProxyFallbackToHTTP1 && s.isOpenAIHTTP2FallbackActive(proxyKey) {
+ return upstreamProtocolModeOpenAIH1Fallback
+ }
+ return upstreamProtocolModeOpenAIH2
+}
+
+func (s *httpUpstreamService) isOpenAIHTTP2FallbackActive(proxyKey string) bool {
+ raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
+ if !ok {
+ return false
+ }
+ state, ok := raw.(*openAIHTTP2FallbackState)
+ if !ok || state == nil {
+ return false
+ }
+ return state.isFallbackActive(time.Now())
+}
+
+func (s *httpUpstreamService) getOrCreateOpenAIHTTP2FallbackState(proxyKey string) *openAIHTTP2FallbackState {
+ state := &openAIHTTP2FallbackState{}
+ actual, _ := s.openAIHTTP2Fallbacks.LoadOrStore(proxyKey, state)
+ cached, ok := actual.(*openAIHTTP2FallbackState)
+ if !ok || cached == nil {
+ return state
+ }
+ return cached
+}
+
+func isHTTPProxyKey(proxyKey string) bool {
+ return strings.HasPrefix(proxyKey, "http://") || strings.HasPrefix(proxyKey, "https://")
+}
+
+func isOpenAIHTTP2CompatibilityError(err error) bool {
+ if err == nil {
+ return false
+ }
+ msg := strings.ToLower(err.Error())
+ if msg == "" {
+ return false
+ }
+ markers := []string{
+ "http2",
+ "alpn",
+ "no application protocol",
+ "protocol error",
+ "stream error",
+ "goaway",
+ "refused_stream",
+ "frame too large",
+ }
+ for _, marker := range markers {
+ if strings.Contains(msg, marker) {
+ return true
+ }
+ }
+ return false
+}
+
+func (s *httpUpstreamService) recordOpenAIHTTP2Failure(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string, err error) {
+ if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
+ return
+ }
+ settings := s.resolveOpenAIHTTP2Settings()
+ if !settings.enabled || !settings.allowProxyFallbackToHTTP1 {
+ return
+ }
+ if !isHTTPProxyKey(proxyKey) || !isOpenAIHTTP2CompatibilityError(err) {
+ return
+ }
+ state := s.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
+ activated, until := state.recordFailure(time.Now(), settings.fallbackErrorThreshold, settings.fallbackWindow, settings.fallbackTTL)
+ if activated {
+ slog.Warn("openai_http2_proxy_fallback_activated",
+ "proxy", proxyKey,
+ "fallback_until", until.Format(time.RFC3339))
+ }
+}
+
+func (s *httpUpstreamService) recordOpenAIHTTP2Success(profile service.HTTPUpstreamProfile, protocolMode, proxyKey string) {
+ if profile != service.HTTPUpstreamProfileOpenAI || protocolMode != upstreamProtocolModeOpenAIH2 {
+ return
+ }
+ if !isHTTPProxyKey(proxyKey) {
+ return
+ }
+ raw, ok := s.openAIHTTP2Fallbacks.Load(proxyKey)
+ if !ok {
+ return
+ }
+ state, ok := raw.(*openAIHTTP2FallbackState)
+ if !ok || state == nil {
+ return
+ }
+ state.resetErrorWindow()
+}
+
+func (s *openAIHTTP2FallbackState) isFallbackActive(now time.Time) bool {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.fallbackUntil.IsZero() {
+ return false
+ }
+ if now.Before(s.fallbackUntil) {
+ return true
+ }
+ s.fallbackUntil = time.Time{}
+ return false
+}
+
+func (s *openAIHTTP2FallbackState) resetErrorWindow() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ s.windowStart = time.Time{}
+ s.errorCount = 0
+}
+
+func (s *openAIHTTP2FallbackState) recordFailure(now time.Time, threshold int, window, ttl time.Duration) (bool, time.Time) {
+ if threshold <= 0 {
+ threshold = defaultOpenAIHTTP2FallbackErrorThreshold
+ }
+ if window <= 0 {
+ window = defaultOpenAIHTTP2FallbackWindow
+ }
+ if ttl <= 0 {
+ ttl = defaultOpenAIHTTP2FallbackTTL
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.fallbackUntil.IsZero() && now.Before(s.fallbackUntil) {
+ return false, s.fallbackUntil
+ }
+ if !s.fallbackUntil.IsZero() && !now.Before(s.fallbackUntil) {
+ s.fallbackUntil = time.Time{}
+ }
+
+ if s.windowStart.IsZero() || now.Sub(s.windowStart) > window {
+ s.windowStart = now
+ s.errorCount = 0
+ }
+ s.errorCount++
+ if s.errorCount < threshold {
+ return false, time.Time{}
+ }
+
+ s.fallbackUntil = now.Add(ttl)
+ s.windowStart = time.Time{}
+ s.errorCount = 0
+ return true, s.fallbackUntil
+}
+
// defaultPoolSettings 获取默认连接池配置
// 从全局配置中读取,无效值使用常量默认值
//
@@ -779,7 +1215,7 @@ func defaultPoolSettings(cfg *config.Config) poolSettings {
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
-func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
+func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL, protocolMode string) (*http.Transport, error) {
transport := &http.Transport{
MaxIdleConns: settings.maxIdleConns,
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
@@ -787,6 +1223,14 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
IdleConnTimeout: settings.idleConnTimeout,
ResponseHeaderTimeout: settings.responseHeaderTimeout,
}
+ switch protocolMode {
+ case upstreamProtocolModeOpenAIH2:
+ transport.ForceAttemptHTTP2 = true
+ case upstreamProtocolModeOpenAIH1Fallback:
+ // 显式禁用 HTTP/2,确保代理不兼容场景回退到 HTTP/1.1。
+ transport.ForceAttemptHTTP2 = false
+ transport.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
+ }
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
return nil, err
}
diff --git a/backend/internal/repository/http_upstream_benchmark_test.go b/backend/internal/repository/http_upstream_benchmark_test.go
index 89892b3b6..a92105c1a 100644
--- a/backend/internal/repository/http_upstream_benchmark_test.go
+++ b/backend/internal/repository/http_upstream_benchmark_test.go
@@ -45,7 +45,7 @@ func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
settings := defaultPoolSettings(cfg)
for i := 0; i < b.N; i++ {
// 每次迭代都创建新客户端,包含 Transport 分配
- transport, err := buildUpstreamTransport(settings, parsedProxy)
+ transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeDefault)
if err != nil {
b.Fatalf("创建 Transport 失败: %v", err)
}
diff --git a/backend/internal/repository/http_upstream_test.go b/backend/internal/repository/http_upstream_test.go
index b3268463a..2df1dd29c 100644
--- a/backend/internal/repository/http_upstream_test.go
+++ b/backend/internal/repository/http_upstream_test.go
@@ -1,13 +1,19 @@
package repository
import (
+ "context"
+ "errors"
"io"
"net/http"
+ "net/url"
+ "strings"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
+ "github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@@ -116,6 +122,16 @@ func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
require.Equal(s.T(), "direct", string(b), "unexpected body")
}
+func (s *HTTPUpstreamSuite) TestDo_RequestErrorPath() {
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/unreachable", nil)
+ require.NoError(s.T(), err)
+
+ resp, doErr := svc.Do(req, "", 1, 1)
+ require.Nil(s.T(), resp)
+ require.Error(s.T(), doErr)
+}
+
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
// 验证请求通过代理服务器转发,使用绝对 URI 格式
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
@@ -276,6 +292,431 @@ func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
}
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_UsesHTTP2TransportForHTTPProxy() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+
+ entry, err := svc.getClientEntry("http://proxy.local:8080", 1, 1, service.HTTPUpstreamProfileOpenAI, false, false)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH2, entry.protocolMode)
+
+ transport, ok := entry.client.Transport.(*http.Transport)
+ require.True(s.T(), ok, "expected *http.Transport")
+ require.True(s.T(), transport.ForceAttemptHTTP2, "OpenAI profile should prefer HTTP/2")
+ require.Nil(s.T(), transport.TLSNextProto, "HTTP/2 mode should not force-disable TLSNextProto")
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_FallbackToHTTP11WhenProxyMarkedIncompatible() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+
+ state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL)
+ state.mu.Lock()
+ state.fallbackUntil = time.Now().Add(3 * time.Minute)
+ state.mu.Unlock()
+
+ entry, err := svc.getClientEntry(proxyURL, 1, 1, service.HTTPUpstreamProfileOpenAI, false, false)
+ require.NoError(s.T(), err)
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, entry.protocolMode)
+
+ transport, ok := entry.client.Transport.(*http.Transport)
+ require.True(s.T(), ok, "expected *http.Transport")
+ require.False(s.T(), transport.ForceAttemptHTTP2, "fallback mode must disable HTTP/2 force-attempt")
+ require.NotNil(s.T(), transport.TLSNextProto, "fallback mode must disable HTTP/2 negotiation")
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordHTTP2ErrorActivatesFallback() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+ h2Err := errors.New("http2: stream error")
+
+ svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err)
+ require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "first error should not activate fallback")
+
+ svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, h2Err)
+ require.True(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL), "second error in window should activate fallback")
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIProfile_RecordNonHTTP2ErrorDoesNotActivateFallback() {
+ s.cfg.Gateway = config.GatewayConfig{
+ OpenAIHTTP2: config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 1,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ },
+ }
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+
+ svc.recordOpenAIHTTP2Failure(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL, errors.New("dial tcp: i/o timeout"))
+ require.False(s.T(), svc.isOpenAIHTTP2FallbackActive(proxyURL))
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_DisabledDelegatesToDo() {
+ upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = io.WriteString(w, "ok")
+ }))
+ s.T().Cleanup(upstream.Close)
+
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-disabled", nil)
+ require.NoError(s.T(), err)
+
+ resp, err := svc.DoWithTLS(req, "", 1, 1, false)
+ require.NoError(s.T(), err)
+ defer func() { _ = resp.Body.Close() }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(s.T(), readErr)
+ require.Equal(s.T(), "ok", string(body))
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledHTTPRequestSuccess() {
+ upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, _ = io.WriteString(w, "tls-enabled")
+ }))
+ s.T().Cleanup(upstream.Close)
+
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, upstream.URL+"/tls-enabled", nil)
+ require.NoError(s.T(), err)
+
+ resp, err := svc.DoWithTLS(req, "", 9, 1, true)
+ require.NoError(s.T(), err)
+ defer func() { _ = resp.Body.Close() }()
+ body, readErr := io.ReadAll(resp.Body)
+ require.NoError(s.T(), readErr)
+ require.Equal(s.T(), "tls-enabled", string(body))
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_EnabledRequestError() {
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1:1/tls-error", nil)
+ require.NoError(s.T(), err)
+
+ resp, doErr := svc.DoWithTLS(req, "", 9, 1, true)
+ require.Nil(s.T(), resp)
+ require.Error(s.T(), doErr)
+}
+
+func (s *HTTPUpstreamSuite) TestDoWithTLS_ValidateRequestHostFailure() {
+ s.cfg.Security.URLAllowlist.Enabled = true
+ s.cfg.Security.URLAllowlist.AllowPrivateHosts = false
+ svc := s.newService()
+
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1/test", nil)
+ require.NoError(s.T(), err)
+
+ resp, doErr := svc.DoWithTLS(req, "", 1, 1, true)
+ require.Nil(s.T(), resp)
+ require.Error(s.T(), doErr)
+}
+
+func (s *HTTPUpstreamSuite) TestShouldValidateResolvedIPAndValidateRequestHost() {
+ svc := s.newService()
+ require.False(s.T(), svc.shouldValidateResolvedIP())
+ require.NoError(s.T(), svc.validateRequestHost(nil))
+
+ s.cfg.Security.URLAllowlist.Enabled = true
+ s.cfg.Security.URLAllowlist.AllowPrivateHosts = false
+ require.True(s.T(), svc.shouldValidateResolvedIP())
+ require.Error(s.T(), svc.validateRequestHost(nil))
+
+ req, err := http.NewRequest(http.MethodGet, "http:///nohost", nil)
+ require.NoError(s.T(), err)
+ require.Error(s.T(), svc.validateRequestHost(req))
+}
+
+func (s *HTTPUpstreamSuite) TestRedirectCheckerStopsAfterLimit() {
+ svc := s.newService()
+ req, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
+ require.NoError(s.T(), err)
+
+ via := make([]*http.Request, 10)
+ require.Error(s.T(), svc.redirectChecker(req, via))
+}
+
+func (s *HTTPUpstreamSuite) TestRedirectCheckerValidatesRequestHost() {
+ s.cfg.Security.URLAllowlist.Enabled = true
+ s.cfg.Security.URLAllowlist.AllowPrivateHosts = false
+ svc := s.newService()
+
+ req, err := http.NewRequest(http.MethodGet, "http://127.0.0.1", nil)
+ require.NoError(s.T(), err)
+ require.Error(s.T(), svc.redirectChecker(req, nil))
+}
+
+func (s *HTTPUpstreamSuite) TestShouldReuseEntryAndEvictBranches() {
+ svc := s.newService()
+ entry := &upstreamClientEntry{
+ proxyKey: "proxy-a",
+ poolKey: "pool-a",
+ }
+ require.False(s.T(), svc.shouldReuseEntry(nil, config.ConnectionPoolIsolationAccount, "proxy-a", "pool-a"))
+ require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationAccount, "proxy-b", "pool-a"))
+ require.False(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-a", "pool-b"))
+ require.True(s.T(), svc.shouldReuseEntry(entry, config.ConnectionPoolIsolationProxy, "proxy-x", "pool-a"))
+
+ s.cfg.Gateway.MaxUpstreamClients = 2
+ svc.clients["k1"] = &upstreamClientEntry{inFlight: 1}
+ svc.clients["k2"] = &upstreamClientEntry{inFlight: 1}
+ require.False(s.T(), svc.evictOldestIdleLocked())
+ require.False(s.T(), svc.evictOverLimitLocked())
+}
+
+func (s *HTTPUpstreamSuite) TestBuildCacheKeyAndIsolationMode() {
+ svc := s.newService()
+ require.Equal(s.T(), "account:1", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, ""))
+ require.Equal(s.T(), "account:2|proxy:px", buildCacheKey(config.ConnectionPoolIsolationAccountProxy, "px", 2, ""))
+ require.Equal(s.T(), "proxy:direct", buildCacheKey(config.ConnectionPoolIsolationProxy, "direct", 3, ""))
+ require.Equal(s.T(), "account:1|proto:openai_h2", buildCacheKey(config.ConnectionPoolIsolationAccount, "direct", 1, "openai_h2"))
+
+ s.cfg.Gateway.ConnectionPoolIsolation = "invalid"
+ require.Equal(s.T(), config.ConnectionPoolIsolationAccountProxy, svc.getIsolationMode())
+ s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationProxy
+ require.Equal(s.T(), config.ConnectionPoolIsolationProxy, svc.getIsolationMode())
+}
+
+func (s *HTTPUpstreamSuite) TestResolveProtocolModeAndSettingsBranches() {
+ svc := s.newService()
+ s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ }
+ parsedHTTPProxy, err := url.Parse("http://proxy.local:8080")
+ require.NoError(s.T(), err)
+ parsedSOCKSProxy, err := url.Parse("socks5://proxy.local:1080")
+ require.NoError(s.T(), err)
+
+ require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileDefault, "direct", nil))
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "direct", nil))
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH2, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "socks5://proxy.local:1080", parsedSOCKSProxy))
+
+ state := svc.getOrCreateOpenAIHTTP2FallbackState("http://proxy.local:8080")
+ state.mu.Lock()
+ state.fallbackUntil = time.Now().Add(10 * time.Second)
+ state.mu.Unlock()
+ require.Equal(s.T(), upstreamProtocolModeOpenAIH1Fallback, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy))
+
+ s.cfg.Gateway.OpenAIHTTP2.Enabled = false
+ require.Equal(s.T(), upstreamProtocolModeDefault, svc.resolveProtocolMode(service.HTTPUpstreamProfileOpenAI, "http://proxy.local:8080", parsedHTTPProxy))
+}
+
+func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_ReusesAndRebuildsOnProxyChange() {
+ s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccount
+ svc := s.newService()
+ profile := &tlsfingerprint.Profile{Name: "tls-profile"}
+
+ entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false)
+ require.NoError(s.T(), err)
+ entry2, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, false, false)
+ require.NoError(s.T(), err)
+ require.Same(s.T(), entry1, entry2)
+
+ entry3, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 1, 1, profile, false, false)
+ require.NoError(s.T(), err)
+ require.NotSame(s.T(), entry1, entry3)
+}
+
+func (s *HTTPUpstreamSuite) TestGetClientEntryWithTLS_OverLimitReturnsError() {
+ s.cfg.Gateway.ConnectionPoolIsolation = config.ConnectionPoolIsolationAccountProxy
+ s.cfg.Gateway.MaxUpstreamClients = 1
+ svc := s.newService()
+ profile := &tlsfingerprint.Profile{Name: "tls-profile"}
+
+ entry1, err := svc.getClientEntryWithTLS("http://proxy-a.local:8080", 1, 1, profile, true, true)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), entry1)
+
+ entry2, err := svc.getClientEntryWithTLS("http://proxy-b.local:8080", 2, 1, profile, true, true)
+ require.ErrorIs(s.T(), err, errUpstreamClientLimitReached)
+ require.Nil(s.T(), entry2)
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateHelpers() {
+ var state openAIHTTP2FallbackState
+ now := time.Now()
+
+ active, until := state.recordFailure(now, 1, time.Minute, time.Minute)
+ require.True(s.T(), active)
+ require.False(s.T(), until.IsZero())
+ require.True(s.T(), state.isFallbackActive(now))
+ require.False(s.T(), state.isFallbackActive(now.Add(2*time.Minute)))
+
+ state.recordFailure(now, 3, time.Minute, time.Minute)
+ state.recordFailure(now.Add(10*time.Second), 3, time.Minute, time.Minute)
+ state.resetErrorWindow()
+ require.Equal(s.T(), 0, state.errorCount)
+ require.True(s.T(), state.windowStart.IsZero())
+
+ // 在 fallback 活跃期间再次失败,不应重复激活。
+ state.fallbackUntil = now.Add(time.Minute)
+ activated, _ := state.recordFailure(now.Add(5*time.Second), 1, time.Minute, time.Minute)
+ require.False(s.T(), activated)
+}
+
+func (s *HTTPUpstreamSuite) TestRecordOpenAIHTTP2SuccessResetsWindow() {
+ svc := s.newService()
+ proxyURL := "http://proxy.local:8080"
+ state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyURL)
+ state.mu.Lock()
+ state.errorCount = 5
+ state.windowStart = time.Now()
+ state.mu.Unlock()
+
+ svc.recordOpenAIHTTP2Success(service.HTTPUpstreamProfileOpenAI, upstreamProtocolModeOpenAIH2, proxyURL)
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ require.Equal(s.T(), 0, state.errorCount)
+ require.True(s.T(), state.windowStart.IsZero())
+}
+
+func (s *HTTPUpstreamSuite) TestDo_OpenAIProxySuccessResetsHTTP2ErrorWindow() {
+ seen := make(chan struct{}, 1)
+ proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ select {
+ case seen <- struct{}{}:
+ default:
+ }
+ _, _ = io.WriteString(w, "proxied")
+ }))
+ s.T().Cleanup(proxySrv.Close)
+
+ s.cfg.Gateway.OpenAIHTTP2 = config.GatewayOpenAIHTTP2Config{
+ Enabled: true,
+ AllowProxyFallbackToHTTP1: true,
+ FallbackErrorThreshold: 2,
+ FallbackWindowSeconds: 60,
+ FallbackTTLSeconds: 600,
+ }
+ svc := s.newService()
+ proxyKey, _ := normalizeProxyURL(proxySrv.URL)
+ state := svc.getOrCreateOpenAIHTTP2FallbackState(proxyKey)
+ state.mu.Lock()
+ state.windowStart = time.Now()
+ state.errorCount = 3
+ state.fallbackUntil = time.Time{}
+ state.mu.Unlock()
+
+ req, err := http.NewRequest(http.MethodGet, "http://example.com/reset-window", nil)
+ require.NoError(s.T(), err)
+ req = req.WithContext(service.WithHTTPUpstreamProfile(context.Background(), service.HTTPUpstreamProfileOpenAI))
+ resp, doErr := svc.Do(req, proxySrv.URL, 1, 1)
+ require.NoError(s.T(), doErr)
+ defer func() { _ = resp.Body.Close() }()
+ _, _ = io.ReadAll(resp.Body)
+
+ select {
+ case <-seen:
+ default:
+ require.Fail(s.T(), "expected proxy to receive request")
+ }
+
+ state.mu.Lock()
+ defer state.mu.Unlock()
+ require.Equal(s.T(), 0, state.errorCount)
+ require.True(s.T(), state.windowStart.IsZero())
+}
+
+func (s *HTTPUpstreamSuite) TestOpenAIFallbackStateMapTypeSafety() {
+ svc := s.newService()
+ svc.openAIHTTP2Fallbacks.Store("x", "bad-type")
+ require.False(s.T(), svc.isOpenAIHTTP2FallbackActive("x"))
+ state := svc.getOrCreateOpenAIHTTP2FallbackState("x")
+ require.NotNil(s.T(), state)
+}
+
+func (s *HTTPUpstreamSuite) TestBuildUpstreamTransport_ModeSwitchingAndProxyErrors() {
+ settings := defaultPoolSettings(s.cfg)
+ parsedProxy, err := url.Parse("http://proxy.local:8080")
+ require.NoError(s.T(), err)
+
+ h2Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH2)
+ require.NoError(s.T(), err)
+ require.True(s.T(), h2Transport.ForceAttemptHTTP2)
+
+ h1Transport, err := buildUpstreamTransport(settings, parsedProxy, upstreamProtocolModeOpenAIH1Fallback)
+ require.NoError(s.T(), err)
+ require.False(s.T(), h1Transport.ForceAttemptHTTP2)
+ require.NotNil(s.T(), h1Transport.TLSNextProto)
+
+ badProxy, err := url.Parse("ftp://proxy.local:21")
+ require.NoError(s.T(), err)
+ _, badErr := buildUpstreamTransport(settings, badProxy, upstreamProtocolModeDefault)
+ require.Error(s.T(), badErr)
+}
+
+func (s *HTTPUpstreamSuite) TestBuildUpstreamTransportWithTLSFingerprintBranches() {
+ settings := defaultPoolSettings(s.cfg)
+ profile := &tlsfingerprint.Profile{Name: "test-profile"}
+
+ transportDirect, err := buildUpstreamTransportWithTLSFingerprint(settings, nil, profile)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), transportDirect.DialTLSContext)
+
+ httpProxy, err := url.Parse("http://proxy.local:8080")
+ require.NoError(s.T(), err)
+ transportHTTPProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, httpProxy, profile)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), transportHTTPProxy.DialTLSContext)
+
+ socksProxy, err := url.Parse("socks5://proxy.local:1080")
+ require.NoError(s.T(), err)
+ transportSOCKSProxy, err := buildUpstreamTransportWithTLSFingerprint(settings, socksProxy, profile)
+ require.NoError(s.T(), err)
+ require.NotNil(s.T(), transportSOCKSProxy.DialTLSContext)
+
+ unsupportedProxy, err := url.Parse("ftp://proxy.local:21")
+ require.NoError(s.T(), err)
+ _, unsupportedErr := buildUpstreamTransportWithTLSFingerprint(settings, unsupportedProxy, profile)
+ require.Error(s.T(), unsupportedErr)
+}
+
+func (s *HTTPUpstreamSuite) TestWrapTrackedBody_NilAndCloseOnce() {
+ require.Nil(s.T(), wrapTrackedBody(nil, nil))
+
+ closed := int32(0)
+ readCloser := io.NopCloser(strings.NewReader("x"))
+ wrapped := wrapTrackedBody(readCloser, func() {
+ atomic.AddInt32(&closed, 1)
+ })
+ require.NotNil(s.T(), wrapped)
+ _ = wrapped.Close()
+ _ = wrapped.Close()
+ require.Equal(s.T(), int32(1), atomic.LoadInt32(&closed))
+}
+
// TestHTTPUpstreamSuite 运行测试套件
func TestHTTPUpstreamSuite(t *testing.T) {
suite.Run(t, new(HTTPUpstreamSuite))
diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go
index dca0b612f..3e155971b 100644
--- a/backend/internal/repository/openai_oauth_service.go
+++ b/backend/internal/repository/openai_oauth_service.go
@@ -23,10 +23,7 @@ type openaiOAuthService struct {
}
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
- client, err := createOpenAIReqClient(proxyURL)
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
- }
+ client := createOpenAIReqClient(proxyURL)
if redirectURI == "" {
redirectURI = openai.DefaultRedirectURI
@@ -77,10 +74,7 @@ func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refre
}
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
- client, err := createOpenAIReqClient(proxyURL)
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
- }
+ client := createOpenAIReqClient(proxyURL)
formData := url.Values{}
formData.Set("grant_type", "refresh_token")
@@ -108,7 +102,7 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre
return &tokenResp, nil
}
-func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
+func createOpenAIReqClient(proxyURL string) *req.Client {
return getSharedReqClient(reqClientOptions{
ProxyURL: proxyURL,
Timeout: 120 * time.Second,
diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go
index d30cc7ddb..6956e36c8 100644
--- a/backend/internal/repository/usage_log_repo.go
+++ b/backend/internal/repository/usage_log_repo.go
@@ -218,6 +218,275 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog)
return true, nil
}
+func (r *usageLogRepository) usageSQLExecutor(ctx context.Context) sqlExecutor {
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return tx.Client()
+ }
+ return r.sql
+}
+
+func (r *usageLogRepository) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if fn == nil {
+ return nil
+ }
+ if tx := dbent.TxFromContext(ctx); tx != nil {
+ return fn(ctx)
+ }
+ tx, err := r.client.Tx(ctx)
+ if err != nil {
+ return err
+ }
+ txCtx := dbent.NewTxContext(ctx, tx)
+ if err := fn(txCtx); err != nil {
+ _ = tx.Rollback()
+ return err
+ }
+ return tx.Commit()
+}
+
+func (r *usageLogRepository) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*service.UsageBillingEntry, error) {
+ query := `
+ SELECT
+ id,
+ usage_log_id,
+ user_id,
+ api_key_id,
+ subscription_id,
+ billing_type,
+ applied,
+ delta_usd,
+ status,
+ attempt_count,
+ next_retry_at,
+ updated_at,
+ created_at,
+ last_error
+ FROM billing_usage_entries
+ WHERE usage_log_id = $1
+ `
+ rows, err := r.usageSQLExecutor(ctx).QueryContext(ctx, query, usageLogID)
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if !rows.Next() {
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return nil, service.ErrUsageBillingEntryNotFound
+ }
+ entry, err := scanUsageBillingEntry(rows)
+ if err != nil {
+ return nil, err
+ }
+ if err = rows.Err(); err != nil {
+ return nil, err
+ }
+ return entry, nil
+}
+
+func (r *usageLogRepository) UpsertUsageBillingEntry(ctx context.Context, entry *service.UsageBillingEntry) (*service.UsageBillingEntry, bool, error) {
+ if entry == nil {
+ return nil, false, nil
+ }
+
+ insertQuery := `
+ INSERT INTO billing_usage_entries (
+ usage_log_id,
+ user_id,
+ api_key_id,
+ subscription_id,
+ billing_type,
+ applied,
+ delta_usd,
+ status,
+ attempt_count,
+ next_retry_at,
+ updated_at
+ ) VALUES (
+ $1, $2, $3, $4, $5, FALSE, $6, $7, 0, NOW(), NOW()
+ )
+ ON CONFLICT (usage_log_id) DO NOTHING
+ RETURNING
+ id,
+ usage_log_id,
+ user_id,
+ api_key_id,
+ subscription_id,
+ billing_type,
+ applied,
+ delta_usd,
+ status,
+ attempt_count,
+ next_retry_at,
+ updated_at,
+ created_at,
+ last_error
+ `
+
+ exec := r.usageSQLExecutor(ctx)
+ rows, err := exec.QueryContext(
+ ctx,
+ insertQuery,
+ entry.UsageLogID,
+ entry.UserID,
+ entry.APIKeyID,
+ nullInt64(entry.SubscriptionID),
+ entry.BillingType,
+ entry.DeltaUSD,
+ service.UsageBillingEntryStatusPending,
+ )
+ if err != nil {
+ return nil, false, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ if rows.Next() {
+ created, scanErr := scanUsageBillingEntry(rows)
+ if scanErr != nil {
+ return nil, false, scanErr
+ }
+ return created, true, nil
+ }
+ if err = rows.Err(); err != nil {
+ return nil, false, err
+ }
+
+ existing, err := r.GetUsageBillingEntryByUsageLogID(ctx, entry.UsageLogID)
+ if err != nil {
+ return nil, false, err
+ }
+ return existing, false, nil
+}
+
+func (r *usageLogRepository) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error {
+ query := `
+ UPDATE billing_usage_entries
+ SET
+ applied = TRUE,
+ status = $2,
+ last_error = NULL,
+ next_retry_at = NOW(),
+ updated_at = NOW()
+ WHERE id = $1
+ `
+ res, err := r.usageSQLExecutor(ctx).ExecContext(ctx, query, entryID, service.UsageBillingEntryStatusApplied)
+ if err != nil {
+ return err
+ }
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrUsageBillingEntryNotFound
+ }
+ return nil
+}
+
+func (r *usageLogRepository) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error {
+ query := `
+ UPDATE billing_usage_entries
+ SET
+ applied = FALSE,
+ status = $2,
+ next_retry_at = $3,
+ last_error = $4,
+ updated_at = NOW()
+ WHERE id = $1
+ `
+ res, err := r.usageSQLExecutor(ctx).ExecContext(
+ ctx,
+ query,
+ entryID,
+ service.UsageBillingEntryStatusPending,
+ nextRetryAt,
+ nullString(&lastError),
+ )
+ if err != nil {
+ return err
+ }
+ affected, err := res.RowsAffected()
+ if err != nil {
+ return err
+ }
+ if affected == 0 {
+ return service.ErrUsageBillingEntryNotFound
+ }
+ return nil
+}
+
+func (r *usageLogRepository) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]service.UsageBillingEntry, error) {
+ if limit <= 0 {
+ return nil, nil
+ }
+ staleAt := time.Now().Add(-processingStaleAfter)
+ query := `
+ WITH candidates AS (
+ SELECT id
+ FROM billing_usage_entries
+ WHERE applied = FALSE
+ AND (
+ (status = $1 AND next_retry_at <= NOW())
+ OR (status = $2 AND updated_at <= $3)
+ )
+ ORDER BY id
+ LIMIT $4
+ FOR UPDATE SKIP LOCKED
+ )
+ UPDATE billing_usage_entries b
+ SET
+ status = $2,
+ attempt_count = b.attempt_count + 1,
+ updated_at = NOW(),
+ last_error = NULL
+ FROM candidates c
+ WHERE b.id = c.id
+ RETURNING
+ b.id,
+ b.usage_log_id,
+ b.user_id,
+ b.api_key_id,
+ b.subscription_id,
+ b.billing_type,
+ b.applied,
+ b.delta_usd,
+ b.status,
+ b.attempt_count,
+ b.next_retry_at,
+ b.updated_at,
+ b.created_at,
+ b.last_error
+ `
+
+ rows, err := r.usageSQLExecutor(ctx).QueryContext(
+ ctx,
+ query,
+ service.UsageBillingEntryStatusPending,
+ service.UsageBillingEntryStatusProcessing,
+ staleAt,
+ limit,
+ )
+ if err != nil {
+ return nil, err
+ }
+ defer func() { _ = rows.Close() }()
+
+ entries := make([]service.UsageBillingEntry, 0, limit)
+ for rows.Next() {
+ item, scanErr := scanUsageBillingEntry(rows)
+ if scanErr != nil {
+ return nil, scanErr
+ }
+ entries = append(entries, *item)
+ }
+ if err := rows.Err(); err != nil {
+ return nil, err
+ }
+ return entries, nil
+}
+
func (r *usageLogRepository) GetByID(ctx context.Context, id int64) (log *service.UsageLog, err error) {
query := "SELECT " + usageLogSelectColumns + " FROM usage_logs WHERE id = $1"
rows, err := r.sql.QueryContext(ctx, query, id)
@@ -2511,6 +2780,52 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e
return log, nil
}
+func scanUsageBillingEntry(scanner interface{ Scan(...any) error }) (*service.UsageBillingEntry, error) {
+ var (
+ subscriptionID sql.NullInt64
+ nextRetryAt time.Time
+ updatedAt time.Time
+ createdAt time.Time
+ lastError sql.NullString
+ status int16
+ entry service.UsageBillingEntry
+ )
+
+ if err := scanner.Scan(
+ &entry.ID,
+ &entry.UsageLogID,
+ &entry.UserID,
+ &entry.APIKeyID,
+ &subscriptionID,
+ &entry.BillingType,
+ &entry.Applied,
+ &entry.DeltaUSD,
+ &status,
+ &entry.AttemptCount,
+ &nextRetryAt,
+ &updatedAt,
+ &createdAt,
+ &lastError,
+ ); err != nil {
+ return nil, err
+ }
+
+ if subscriptionID.Valid {
+ v := subscriptionID.Int64
+ entry.SubscriptionID = &v
+ }
+ entry.Status = service.UsageBillingEntryStatus(status)
+ entry.NextRetryAt = nextRetryAt
+ entry.UpdatedAt = updatedAt
+ entry.CreatedAt = createdAt
+ if lastError.Valid {
+ msg := lastError.String
+ entry.LastError = &msg
+ }
+
+ return &entry, nil
+}
+
func scanTrendRows(rows *sql.Rows) ([]TrendDataPoint, error) {
results := make([]TrendDataPoint, 0)
for rows.Next() {
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index a8845d9b2..c688b081e 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -626,7 +626,7 @@ func newContractDeps(t *testing.T) *contractDeps {
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil)
- adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
+ adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil)
jwtAuth := func(c *gin.Context) {
c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{
@@ -1573,10 +1573,6 @@ func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTi
return nil, errors.New("not implemented")
}
-func (r *stubUsageLogRepo) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
- return nil, errors.New("not implemented")
-}
-
func (r *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
return nil, errors.New("not implemented")
}
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index c36c36a0a..ba9956a6f 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -26,6 +26,9 @@ func RegisterAdminRoutes(
// 分组管理
registerGroupRoutes(admin, h)
+ // API Key 管理
+ registerAPIKeyRoutes(admin, h)
+
// 账号管理
registerAccountRoutes(admin, h)
@@ -227,6 +230,13 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ apiKeys := admin.Group("/api-keys")
+ {
+ apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
+ }
+}
+
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts")
{
@@ -386,6 +396,18 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
+ // 批量编辑模板库(服务端共享)
+ adminSettings.GET("/bulk-edit-templates", h.Admin.Setting.ListBulkEditTemplates)
+ adminSettings.POST("/bulk-edit-templates", h.Admin.Setting.UpsertBulkEditTemplate)
+ adminSettings.DELETE("/bulk-edit-templates/:template_id", h.Admin.Setting.DeleteBulkEditTemplate)
+ adminSettings.GET(
+ "/bulk-edit-templates/:template_id/versions",
+ h.Admin.Setting.ListBulkEditTemplateVersions,
+ )
+ adminSettings.POST(
+ "/bulk-edit-templates/:template_id/rollback",
+ h.Admin.Setting.RollbackBulkEditTemplate,
+ )
// Sora S3 存储配置
adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings)
adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings)
@@ -398,29 +420,6 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
-func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
- dataManagement := admin.Group("/data-management")
- {
- dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
- dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
- dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
- dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
- dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
- dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
- dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
- dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
- dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
- dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
- dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
- dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
- dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
- dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
- dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
- dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
- dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
- }
-}
-
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{
diff --git a/backend/internal/server/routes/admin_routes_test.go b/backend/internal/server/routes/admin_routes_test.go
new file mode 100644
index 000000000..2f365e535
--- /dev/null
+++ b/backend/internal/server/routes/admin_routes_test.go
@@ -0,0 +1,34 @@
+package routes
+
+import (
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler"
+ "github.com/Wei-Shaw/sub2api/internal/server/middleware"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func TestRegisterAdminRoutes_RegistersAPIKeyGroupUpdateRoute(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ v1 := router.Group("/api/v1")
+ handlers := &handler.Handlers{
+ Admin: &handler.AdminHandlers{},
+ }
+ adminAuth := middleware.AdminAuthMiddleware(func(c *gin.Context) { c.Next() })
+
+ require.NotPanics(t, func() {
+ RegisterAdminRoutes(v1, handlers, adminAuth)
+ })
+ require.True(t, hasRoute(router, "PUT", "/api/v1/admin/api-keys/:id"))
+}
+
+func hasRoute(router *gin.Engine, method, path string) bool {
+ for _, route := range router.Routes() {
+ if route.Method == method && route.Path == path {
+ return true
+ }
+ }
+ return false
+}
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index c76c817e7..90c5026dc 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -855,16 +855,18 @@ const (
OpenAIWSIngressModeOff = "off"
OpenAIWSIngressModeShared = "shared"
OpenAIWSIngressModeDedicated = "dedicated"
+ OpenAIWSIngressModeCtxPool = "ctx_pool"
)
func normalizeOpenAIWSIngressMode(mode string) string {
switch strings.ToLower(strings.TrimSpace(mode)) {
case OpenAIWSIngressModeOff:
return OpenAIWSIngressModeOff
- case OpenAIWSIngressModeShared:
- return OpenAIWSIngressModeShared
- case OpenAIWSIngressModeDedicated:
- return OpenAIWSIngressModeDedicated
+ case OpenAIWSIngressModeCtxPool:
+ return OpenAIWSIngressModeCtxPool
+ case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+ // Deprecated: shared/dedicated 已废弃,平滑迁移到 ctx_pool
+ return OpenAIWSIngressModeCtxPool
default:
return ""
}
@@ -874,16 +876,16 @@ func normalizeOpenAIWSIngressDefaultMode(mode string) string {
if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" {
return normalized
}
- return OpenAIWSIngressModeShared
+ return OpenAIWSIngressModeOff
}
-// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。
+// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/ctx_pool)。
//
// 优先级:
// 1. 分类型 mode 新字段(string)
// 2. 分类型 enabled 旧字段(bool)
// 3. 兼容 enabled 旧字段(bool)
-// 4. defaultMode(非法时回退 shared)
+// 4. defaultMode(非法时回退 off)
func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string {
resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode)
if a == nil || !a.IsOpenAI() {
@@ -918,7 +920,8 @@ func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) stri
return "", false
}
if enabled {
- return OpenAIWSIngressModeShared, true
+ // 兼容旧 enabled 字段:开启时至少落到 ctx_pool。
+ return OpenAIWSIngressModeCtxPool, true
}
return OpenAIWSIngressModeOff, true
}
@@ -1137,80 +1140,6 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
return 5
}
-// GetBaseRPM 获取基础 RPM 限制
-// 返回 0 表示未启用(负数视为无效配置,按 0 处理)
-func (a *Account) GetBaseRPM() int {
- if a.Extra == nil {
- return 0
- }
- if v, ok := a.Extra["base_rpm"]; ok {
- val := parseExtraInt(v)
- if val > 0 {
- return val
- }
- }
- return 0
-}
-
-// GetRPMStrategy 获取 RPM 策略
-// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免
-func (a *Account) GetRPMStrategy() string {
- if a.Extra == nil {
- return "tiered"
- }
- if v, ok := a.Extra["rpm_strategy"]; ok {
- if s, ok := v.(string); ok && s == "sticky_exempt" {
- return "sticky_exempt"
- }
- }
- return "tiered"
-}
-
-// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量
-// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1)
-func (a *Account) GetRPMStickyBuffer() int {
- if a.Extra == nil {
- return 0
- }
- if v, ok := a.Extra["rpm_sticky_buffer"]; ok {
- val := parseExtraInt(v)
- if val > 0 {
- return val
- }
- }
- base := a.GetBaseRPM()
- buffer := base / 5
- if buffer < 1 && base > 0 {
- buffer = 1
- }
- return buffer
-}
-
-// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态
-// 复用 WindowCostSchedulability 三态:Schedulable / StickyOnly / NotSchedulable
-func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
- baseRPM := a.GetBaseRPM()
- if baseRPM <= 0 {
- return WindowCostSchedulable
- }
-
- if currentRPM < baseRPM {
- return WindowCostSchedulable
- }
-
- strategy := a.GetRPMStrategy()
- if strategy == "sticky_exempt" {
- return WindowCostStickyOnly // 粘性豁免无红区
- }
-
- // tiered: 黄区 + 红区
- buffer := a.GetRPMStickyBuffer()
- if currentRPM < baseRPM+buffer {
- return WindowCostStickyOnly
- }
- return WindowCostNotSchedulable
-}
-
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)
@@ -1274,12 +1203,6 @@ func parseExtraFloat64(value any) float64 {
}
// parseExtraInt 从 extra 字段解析 int 值
-// ParseExtraInt 从 extra 字段的 any 值解析为 int。
-// 支持 int, int64, float64, json.Number, string 类型,无法解析时返回 0。
-func ParseExtraInt(value any) int {
- return parseExtraInt(value)
-}
-
func parseExtraInt(value any) int {
switch v := value.(type) {
case int:
diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go
index a85c68ec5..ea4f08990 100644
--- a/backend/internal/service/account_openai_passthrough_test.go
+++ b/backend/internal/service/account_openai_passthrough_test.go
@@ -206,30 +206,52 @@ func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
- t.Run("default fallback to shared", func(t *testing.T) {
+ t.Run("default fallback to off", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
- require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
- require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
- t.Run("oauth mode field has highest priority", func(t *testing.T) {
+ t.Run("unsupported mode field falls back to enabled flag", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
- "openai_oauth_responses_websockets_v2_enabled": false,
+ "openai_oauth_responses_websockets_v2_enabled": true,
"responses_websockets_v2_enabled": false,
},
}
- require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ })
+
+ t.Run("ctx_pool mode field is recognized", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ })
+
+ t.Run("legacy enabled maps to ctx_pool when default is off", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Extra: map[string]any{
+ "responses_websockets_v2_enabled": true,
+ },
+ }
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
- t.Run("legacy enabled maps to shared", func(t *testing.T) {
+ t.Run("legacy enabled ignores unsupported default and maps to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -237,7 +259,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
- require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
+ require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
@@ -249,7 +271,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"responses_websockets_v2_enabled": true,
},
}
- require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("non openai always off", func(t *testing.T) {
@@ -260,7 +282,7 @@ func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
- require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
+ require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
}
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go
index a37071842..22b2d93ac 100644
--- a/backend/internal/service/account_service.go
+++ b/backend/internal/service/account_service.go
@@ -71,15 +71,16 @@ type AccountRepository interface {
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
- Name *string
- ProxyID *int64
- Concurrency *int
- Priority *int
- RateMultiplier *float64
- Status *string
- Schedulable *bool
- Credentials map[string]any
- Extra map[string]any
+ Name *string
+ ProxyID *int64
+ Concurrency *int
+ Priority *int
+ RateMultiplier *float64
+ Status *string
+ Schedulable *bool
+ AutoPauseOnExpired *bool
+ Credentials map[string]any
+ Extra map[string]any
}
// CreateAccountRequest 创建账号请求
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go
index 6dee6c133..13a138567 100644
--- a/backend/internal/service/account_usage_service.go
+++ b/backend/internal/service/account_usage_service.go
@@ -37,7 +37,6 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
- GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error)
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 7e6982d36..9a6f11da1 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -9,7 +9,6 @@ import (
"strings"
"time"
- dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -224,17 +223,18 @@ type UpdateAccountInput struct {
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
type BulkUpdateAccountsInput struct {
- AccountIDs []int64
- Name string
- ProxyID *int64
- Concurrency *int
- Priority *int
- RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
- Status string
- Schedulable *bool
- GroupIDs *[]int64
- Credentials map[string]any
- Extra map[string]any
+ AccountIDs []int64
+ Name string
+ ProxyID *int64
+ Concurrency *int
+ Priority *int
+ RateMultiplier *float64 // 账号计费倍率(>=0,允许 0)
+ Status string
+ Schedulable *bool
+ AutoPauseOnExpired *bool
+ GroupIDs *[]int64
+ Credentials map[string]any
+ Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
@@ -432,6 +432,14 @@ type groupExistenceBatchReader interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
+type userGroupRateBatchReader interface {
+ GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error)
+}
+
+type groupExistenceBatchReader interface {
+ ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
+}
+
// NewAdminService creates a new AdminService
func NewAdminService(
userRepo UserRepository,
@@ -1543,6 +1551,70 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return updated, nil
}
+var openAIBulkScopedExtraKeys = map[string]struct{}{
+ "openai_passthrough": {},
+ "openai_oauth_passthrough": {},
+ "openai_oauth_responses_websockets_v2_mode": {},
+ "openai_oauth_responses_websockets_v2_enabled": {},
+ "openai_apikey_responses_websockets_v2_mode": {},
+ "openai_apikey_responses_websockets_v2_enabled": {},
+ "codex_cli_only": {},
+}
+
+func hasOpenAIBulkScopedExtraField(extra map[string]any) bool {
+ if len(extra) == 0 {
+ return false
+ }
+ for key := range extra {
+ if _, ok := openAIBulkScopedExtraKeys[key]; ok {
+ return true
+ }
+ }
+ return false
+}
+
+func validateOpenAIBulkScopedAccounts(accountsByID map[int64]*Account, accountIDs []int64) error {
+ var expectedType string
+
+ for _, accountID := range accountIDs {
+ account := accountsByID[accountID]
+ if account == nil {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_ACCOUNT_MISSING",
+ fmt.Sprintf("account %d not found for OpenAI scoped bulk update", accountID),
+ )
+ }
+
+ if account.Platform != PlatformOpenAI {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_PLATFORM_MISMATCH",
+ "OpenAI scoped bulk fields require all selected accounts to be OpenAI",
+ )
+ }
+
+ if account.Type != AccountTypeOAuth && account.Type != AccountTypeAPIKey {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_TYPE_UNSUPPORTED",
+ "OpenAI scoped bulk fields only support oauth or apikey account types",
+ )
+ }
+
+ if expectedType == "" {
+ expectedType = account.Type
+ continue
+ }
+
+ if account.Type != expectedType {
+ return infraerrors.BadRequest(
+ "BULK_OPENAI_SCOPE_TYPE_MISMATCH",
+ "OpenAI scoped bulk fields require all selected accounts to have the same type",
+ )
+ }
+ }
+
+ return nil
+}
+
// BulkUpdateAccounts updates multiple accounts in one request.
// It merges credentials/extra keys instead of overwriting the whole object.
func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error) {
@@ -1562,32 +1634,41 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
}
needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck
+ needOpenAIScopeCheck := hasOpenAIBulkScopedExtraField(input.Extra)
+ needAccountSnapshot := needMixedChannelCheck || needOpenAIScopeCheck
- // 预加载账号平台信息(混合渠道检查需要)。
+ accountsByID := map[int64]*Account{}
platformByID := map[int64]string{}
- if needMixedChannelCheck {
+ accountByID := map[int64]*Account{}
+ groupAccountsByID := map[int64][]Account{}
+ groupNameByID := map[int64]string{}
+ if needAccountSnapshot {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
- }
- for _, account := range accounts {
- if account != nil {
- platformByID[account.ID] = account.Platform
+ } else {
+ for _, account := range accounts {
+ if account != nil {
+ accountByID[account.ID] = account
+ platformByID[account.ID] = account.Platform
+ }
}
}
}
- // 预检查混合渠道风险:在任何写操作之前,若发现风险立即返回错误。
+ if needOpenAIScopeCheck {
+ if err := validateOpenAIBulkScopedAccounts(accountByID, input.AccountIDs); err != nil {
+ return nil, err
+ }
+ }
+
if needMixedChannelCheck {
- for _, accountID := range input.AccountIDs {
- platform := platformByID[accountID]
- if platform == "" {
- continue
- }
- if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
- return nil, err
- }
+ loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs)
+ if err != nil {
+ return nil, err
}
+ groupAccountsByID = loadedAccounts
+ groupNameByID = loadedNames
}
if input.RateMultiplier != nil {
@@ -1622,6 +1703,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
if input.Schedulable != nil {
repoUpdates.Schedulable = input.Schedulable
}
+ if input.AutoPauseOnExpired != nil {
+ repoUpdates.AutoPauseOnExpired = input.AutoPauseOnExpired
+ }
// Run bulk update for column/jsonb fields first.
if _, err := s.accountRepo.BulkUpdate(ctx, input.AccountIDs, repoUpdates); err != nil {
@@ -1631,8 +1715,34 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
// Handle group bindings per account (requires individual operations).
for _, accountID := range input.AccountIDs {
entry := BulkUpdateAccountResult{AccountID: accountID}
+ platform := ""
if input.GroupIDs != nil {
+ // 检查混合渠道风险(除非用户已确认)
+ if !input.SkipMixedChannelCheck {
+ platform = platformByID[accountID]
+ if platform == "" {
+ account, err := s.accountRepo.GetByID(ctx, accountID)
+ if err != nil {
+ entry.Success = false
+ entry.Error = err.Error()
+ result.Failed++
+ result.FailedIDs = append(result.FailedIDs, accountID)
+ result.Results = append(result.Results, entry)
+ continue
+ }
+ platform = account.Platform
+ }
+ if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil {
+ entry.Success = false
+ entry.Error = err.Error()
+ result.Failed++
+ result.FailedIDs = append(result.FailedIDs, accountID)
+ result.Results = append(result.Results, entry)
+ continue
+ }
+ }
+
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -1641,6 +1751,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
result.Results = append(result.Results, entry)
continue
}
+ if !input.SkipMixedChannelCheck && platform != "" {
+ updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform)
+ }
}
entry.Success = true
@@ -2311,6 +2424,41 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc
return nil
}
+func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) (map[int64][]Account, map[int64]string, error) {
+ accountsByGroup := make(map[int64][]Account)
+ groupNameByID := make(map[int64]string)
+ if len(groupIDs) == 0 {
+ return accountsByGroup, groupNameByID, nil
+ }
+
+ seen := make(map[int64]struct{}, len(groupIDs))
+ for _, groupID := range groupIDs {
+ if groupID <= 0 {
+ continue
+ }
+ if _, ok := seen[groupID]; ok {
+ continue
+ }
+ seen[groupID] = struct{}{}
+
+ accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
+ if err != nil {
+ return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err)
+ }
+ accountsByGroup[groupID] = accounts
+
+ group, err := s.groupRepo.GetByID(ctx, groupID)
+ if err != nil {
+ continue
+ }
+ if group != nil {
+ groupNameByID[groupID] = group.Name
+ }
+ }
+
+ return accountsByGroup, groupNameByID, nil
+}
+
func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
@@ -2340,6 +2488,71 @@ func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs [
return nil
}
+func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error {
+ currentPlatform := getAccountPlatform(currentAccountPlatform)
+ if currentPlatform == "" {
+ return nil
+ }
+
+ for _, groupID := range groupIDs {
+ accounts := accountsByGroup[groupID]
+ for _, account := range accounts {
+ if currentAccountID > 0 && account.ID == currentAccountID {
+ continue
+ }
+
+ otherPlatform := getAccountPlatform(account.Platform)
+ if otherPlatform == "" {
+ continue
+ }
+
+ if currentPlatform != otherPlatform {
+ groupName := fmt.Sprintf("Group %d", groupID)
+ if name := strings.TrimSpace(groupNameByID[groupID]); name != "" {
+ groupName = name
+ }
+
+ return &MixedChannelError{
+ GroupID: groupID,
+ GroupName: groupName,
+ CurrentPlatform: currentPlatform,
+ OtherPlatform: otherPlatform,
+ }
+ }
+ }
+ }
+
+ return nil
+}
+
+func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) {
+ if len(groupIDs) == 0 || accountID <= 0 || platform == "" {
+ return
+ }
+ for _, groupID := range groupIDs {
+ if groupID <= 0 {
+ continue
+ }
+ accounts := accountsByGroup[groupID]
+ found := false
+ for i := range accounts {
+ if accounts[i].ID != accountID {
+ continue
+ }
+ accounts[i].Platform = platform
+ found = true
+ break
+ }
+ if !found {
+ accounts = append(accounts, Account{
+ ID: accountID,
+ Platform: platform,
+ })
+ }
+ accountsByGroup[groupID] = accounts
+ }
+}
+
// CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform.
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs)
diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go
index 4845d87c1..3fb14cae7 100644
--- a/backend/internal/service/admin_service_bulk_update_test.go
+++ b/backend/internal/service/admin_service_bulk_update_test.go
@@ -12,23 +12,25 @@ import (
type accountRepoStubForBulkUpdate struct {
accountRepoStub
- bulkUpdateErr error
- bulkUpdateIDs []int64
- bindGroupErrByID map[int64]error
- bindGroupsCalls []int64
- getByIDsAccounts []*Account
- getByIDsErr error
- getByIDsCalled bool
- getByIDsIDs []int64
- getByIDAccounts map[int64]*Account
- getByIDErrByID map[int64]error
- getByIDCalled []int64
- listByGroupData map[int64][]Account
- listByGroupErr map[int64]error
+ bulkUpdateErr error
+ bulkUpdateIDs []int64
+ bulkUpdatePayload AccountBulkUpdate
+ bindGroupErrByID map[int64]error
+ bindGroupsCalls []int64
+ getByIDsAccounts []*Account
+ getByIDsErr error
+ getByIDsCalled bool
+ getByIDsIDs []int64
+ getByIDAccounts map[int64]*Account
+ getByIDErrByID map[int64]error
+ getByIDCalled []int64
+ listByGroupData map[int64][]Account
+ listByGroupErr map[int64]error
}
-func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
+func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
+ s.bulkUpdatePayload = updates
if s.bulkUpdateErr != nil {
return 0, s.bulkUpdateErr
}
@@ -139,34 +141,134 @@ func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T)
require.Contains(t, err.Error(), "group repository not configured")
}
-// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
-// that the global pre-check detects a conflict with existing group members and returns an
-// error before any DB write is performed.
-func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) {
+func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
- {ID: 1, Platform: PlatformAntigravity},
+ {ID: 1, Platform: PlatformAnthropic},
+ {ID: 2, Platform: PlatformAntigravity},
},
- // Group 10 already contains an Anthropic account.
listByGroupData: map[int64][]Account{
- 10: {{ID: 99, Platform: PlatformAnthropic}},
+ 10: {},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
- groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}},
+ groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}},
}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
- AccountIDs: []int64{1},
+ AccountIDs: []int64{1, 2},
GroupIDs: &groupIDs,
}
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.Equal(t, 1, result.Success)
+ require.Equal(t, 1, result.Failed)
+ require.ElementsMatch(t, []int64{1}, result.SuccessIDs)
+ require.ElementsMatch(t, []int64{2}, result.FailedIDs)
+ require.Len(t, result.Results, 2)
+ require.Contains(t, result.Results[1].Error, "mixed channel")
+ require.Equal(t, []int64{1}, repo.bindGroupsCalls)
+}
+
+func TestAdminService_BulkUpdateAccounts_ForwardsAutoPauseOnExpired(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{}
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ autoPause := true
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{101},
+ AutoPauseOnExpired: &autoPause,
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.Equal(t, 1, result.Success)
+ require.NotNil(t, repo.bulkUpdatePayload.AutoPauseOnExpired)
+ require.True(t, *repo.bulkUpdatePayload.AutoPauseOnExpired)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMixedTypes(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"openai_passthrough": true},
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "same type")
+ require.Empty(t, repo.bulkUpdateIDs)
+ require.True(t, repo.getByIDsCalled)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsNonOpenAIPlatform(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 2, Platform: PlatformAnthropic, Type: AccountTypeOAuth},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"openai_oauth_responses_websockets_v2_mode": "shared"},
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.Nil(t, result)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "OpenAI")
+ require.Empty(t, repo.bulkUpdateIDs)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraAllowsSameTypeOpenAI(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"codex_cli_only": true},
+ }
+
+ result, err := svc.BulkUpdateAccounts(context.Background(), input)
+ require.NoError(t, err)
+ require.Equal(t, 2, result.Success)
+ require.ElementsMatch(t, []int64{1, 2}, repo.bulkUpdateIDs)
+}
+
+func TestAdminService_BulkUpdateAccounts_OpenAIScopedExtraRejectsMissingAccount(t *testing.T) {
+ repo := &accountRepoStubForBulkUpdate{
+ getByIDsAccounts: []*Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth},
+ },
+ }
+ svc := &adminServiceImpl{accountRepo: repo}
+
+ input := &BulkUpdateAccountsInput{
+ AccountIDs: []int64{1, 2},
+ Extra: map[string]any{"openai_passthrough": true},
+ }
+
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
- require.Contains(t, err.Error(), "mixed channel")
- // No BindGroups should have been called since the check runs before any write.
- require.Empty(t, repo.bindGroupsCalls)
+ require.Contains(t, err.Error(), "not found")
+ require.Empty(t, repo.bulkUpdateIDs)
}
diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go
index 36cb1e065..7dd9edca8 100644
--- a/backend/internal/service/auth_service_turnstile_register_test.go
+++ b/backend/internal/service/auth_service_turnstile_register_test.go
@@ -52,7 +52,6 @@ func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier
turnstileService,
nil, // emailQueueService
nil, // promoService
- nil, // defaultSubAssigner
)
}
diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go
index 1a76f5f69..eea4b505c 100644
--- a/backend/internal/service/billing_cache_service.go
+++ b/backend/internal/service/billing_cache_service.go
@@ -2,6 +2,7 @@ package service
import (
"context"
+ "errors"
"fmt"
"strconv"
"sync"
@@ -184,7 +185,7 @@ func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
}
case cacheWriteDeductBalance:
if s.cache != nil {
- if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
+ if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil && !errors.Is(err, ErrBalanceCacheNotFound) {
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
@@ -318,7 +319,13 @@ func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int
if s.cache == nil {
return nil
}
- return s.cache.DeductUserBalance(ctx, userID, amount)
+ err := s.cache.DeductUserBalance(ctx, userID, amount)
+ if errors.Is(err, ErrBalanceCacheNotFound) {
+ // 缓存 key 不存在(已过期),无法原子扣减,不阻塞主流程。
+ // 下次 GetUserBalance 将从数据库回源重建缓存。
+ return nil
+ }
+ return err
}
// QueueDeductBalance 异步扣减余额缓存
diff --git a/backend/internal/service/claude_token_provider.go b/backend/internal/service/claude_token_provider.go
index f6cab204d..996c829cd 100644
--- a/backend/internal/service/claude_token_provider.go
+++ b/backend/internal/service/claude_token_provider.go
@@ -15,6 +15,20 @@ const (
claudeLockWaitTime = 200 * time.Millisecond
)
+func waitClaudeLockRetry(ctx context.Context, wait time.Duration) error {
+ if wait <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(wait)
+ defer timer.Stop()
+ select {
+ case <-timer.C:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
// ClaudeTokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type ClaudeTokenCache = GeminiTokenCache
@@ -168,7 +182,9 @@ func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
- time.Sleep(claudeLockWaitTime)
+ if waitErr := waitClaudeLockRetry(ctx, claudeLockWaitTime); waitErr != nil {
+ return "", waitErr
+ }
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
diff --git a/backend/internal/service/claude_token_provider_test.go b/backend/internal/service/claude_token_provider_test.go
index 3e21f6f4a..09f3a31e8 100644
--- a/backend/internal/service/claude_token_provider_test.go
+++ b/backend/internal/service/claude_token_provider_test.go
@@ -800,6 +800,34 @@ func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
require.NotEmpty(t, token)
}
+func TestClaudeTokenProvider_Real_LockFailedWait_ContextCanceled(t *testing.T) {
+ cache := newClaudeTokenCacheStub()
+ cache.lockAcquired = false
+
+ expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
+ account := &Account{
+ ID: 3001,
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "access_token": "fallback-token",
+ "expires_at": expiresAt,
+ },
+ }
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+
+ provider := NewClaudeTokenProvider(nil, cache, nil)
+ start := time.Now()
+ token, err := provider.GetAccessToken(ctx, account)
+ elapsed := time.Since(start)
+
+ require.ErrorIs(t, err, context.Canceled)
+ require.Empty(t, token)
+ require.Less(t, elapsed, claudeLockWaitTime/2, "context canceled should short-circuit lock wait")
+}
+
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go
index b525c0fae..83e939f45 100644
--- a/backend/internal/service/data_management_service.go
+++ b/backend/internal/service/data_management_service.go
@@ -55,7 +55,8 @@ type DataManagementAgentInfo struct {
}
type DataManagementService struct {
- socketPath string
+ socketPath string
+ dialTimeout time.Duration
}
func NewDataManagementService() *DataManagementService {
@@ -63,13 +64,16 @@ func NewDataManagementService() *DataManagementService {
}
func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService {
- _ = dialTimeout
path := strings.TrimSpace(socketPath)
if path == "" {
path = DefaultDataManagementAgentSocketPath
}
+ if dialTimeout <= 0 {
+ dialTimeout = 500 * time.Millisecond
+ }
return &DataManagementService{
- socketPath: path,
+ socketPath: path,
+ dialTimeout: dialTimeout,
}
}
diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go
index b304bc9fb..27441fbbd 100644
--- a/backend/internal/service/domain_constants.go
+++ b/backend/internal/service/domain_constants.go
@@ -127,6 +127,9 @@ const (
// Gemini 配额策略(JSON)
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
+ // Bulk edit template library(JSON)
+ SettingKeyBulkEditTemplateLibrary = "bulk_edit_template_library_v1"
+
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
@@ -193,13 +196,6 @@ const (
// =========================
SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节)
-
- // =========================
- // Claude Code Version Check
- // =========================
-
- // SettingKeyMinClaudeCodeVersion 最低 Claude Code 版本号要求 (semver, 如 "2.1.0",空值=不检查)
- SettingKeyMinClaudeCodeVersion = "min_claude_code_version"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go
index 3323f8685..be15fc1b8 100644
--- a/backend/internal/service/gateway_service.go
+++ b/backend/internal/service/gateway_service.go
@@ -520,7 +520,6 @@ type GatewayService struct {
concurrencyService *ConcurrencyService
claudeTokenProvider *ClaudeTokenProvider
sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken)
- rpmCache RPMCache // RPM 计数缓存(仅 Anthropic OAuth/SetupToken)
userGroupRateCache *gocache.Cache
userGroupRateSF singleflight.Group
modelsListCache *gocache.Cache
@@ -550,7 +549,6 @@ func NewGatewayService(
deferredService *DeferredService,
claudeTokenProvider *ClaudeTokenProvider,
sessionLimitCache SessionLimitCache,
- rpmCache RPMCache,
digestStore *DigestSessionStore,
) *GatewayService {
userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg)
@@ -576,7 +574,6 @@ func NewGatewayService(
deferredService: deferredService,
claudeTokenProvider: claudeTokenProvider,
sessionLimitCache: sessionLimitCache,
- rpmCache: rpmCache,
userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute),
modelsListCache: gocache.New(modelsListTTL, time.Minute),
modelsListCacheTTL: modelsListTTL,
@@ -1157,7 +1154,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
return nil, errors.New("no available accounts")
}
ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
isExcluded := func(accountID int64) bool {
if excludedIDs == nil {
@@ -1233,10 +1229,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
filteredWindowCost++
continue
}
- // RPM 检查(非粘性会话路径)
- if !s.isAccountSchedulableForRPM(ctx, account, false) {
- continue
- }
routingCandidates = append(routingCandidates, account)
}
@@ -1260,9 +1252,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
- s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
-
- s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
+ s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1416,9 +1406,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) &&
s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) &&
- s.isAccountSchedulableForWindowCost(ctx, account, true) &&
-
- s.isAccountSchedulableForRPM(ctx, account, true) { // 粘性会话窗口费用+RPM 检查
+ s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if err == nil && result.Acquired {
// 会话数量限制检查
@@ -1484,10 +1472,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
continue
}
- // RPM 检查(非粘性会话路径)
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
candidates = append(candidates, acc)
}
@@ -2171,88 +2155,6 @@ checkSchedulability:
return true
}
-// rpmPrefetchContextKey is the context key for prefetched RPM counts.
-type rpmPrefetchContextKeyType struct{}
-
-var rpmPrefetchContextKey = rpmPrefetchContextKeyType{}
-
-func rpmFromPrefetchContext(ctx context.Context, accountID int64) (int, bool) {
- if v, ok := ctx.Value(rpmPrefetchContextKey).(map[int64]int); ok {
- count, found := v[accountID]
- return count, found
- }
- return 0, false
-}
-
-// withRPMPrefetch 批量预取所有候选账号的 RPM 计数
-func (s *GatewayService) withRPMPrefetch(ctx context.Context, accounts []Account) context.Context {
- if s.rpmCache == nil {
- return ctx
- }
-
- var ids []int64
- for i := range accounts {
- if accounts[i].IsAnthropicOAuthOrSetupToken() && accounts[i].GetBaseRPM() > 0 {
- ids = append(ids, accounts[i].ID)
- }
- }
- if len(ids) == 0 {
- return ctx
- }
-
- counts, err := s.rpmCache.GetRPMBatch(ctx, ids)
- if err != nil {
- return ctx // 失败开放
- }
- return context.WithValue(ctx, rpmPrefetchContextKey, counts)
-}
-
-// isAccountSchedulableForRPM 检查账号是否可根据 RPM 进行调度
-// 仅适用于 Anthropic OAuth/SetupToken 账号
-func (s *GatewayService) isAccountSchedulableForRPM(ctx context.Context, account *Account, isSticky bool) bool {
- if !account.IsAnthropicOAuthOrSetupToken() {
- return true
- }
- baseRPM := account.GetBaseRPM()
- if baseRPM <= 0 {
- return true
- }
-
- // 尝试从预取缓存获取
- var currentRPM int
- if count, ok := rpmFromPrefetchContext(ctx, account.ID); ok {
- currentRPM = count
- } else if s.rpmCache != nil {
- if count, err := s.rpmCache.GetRPM(ctx, account.ID); err == nil {
- currentRPM = count
- }
- // 失败开放:GetRPM 错误时允许调度
- }
-
- schedulability := account.CheckRPMSchedulability(currentRPM)
- switch schedulability {
- case WindowCostSchedulable:
- return true
- case WindowCostStickyOnly:
- return isSticky
- case WindowCostNotSchedulable:
- return false
- }
- return true
-}
-
-// IncrementAccountRPM increments the RPM counter for the given account.
-// 已知 TOCTOU 竞态:调度时读取 RPM 计数与此处递增之间存在时间窗口,
-// 高并发下可能短暂超出 RPM 限制。这是与 WindowCost 一致的 soft-limit
-// 设计权衡——可接受的少量超额优于加锁带来的延迟和复杂度。
-func (s *GatewayService) IncrementAccountRPM(ctx context.Context, accountID int64) error {
- if s.rpmCache == nil {
- return nil
- }
- _, err := s.rpmCache.IncrementRPM(ctx, accountID)
- return err
-}
-
// checkAndRegisterSession 检查并注册会话,用于会话数量限制
// 仅适用于 Anthropic OAuth/SetupToken 账号
// sessionID: 会话标识符(使用粘性会话的 hash)
@@ -2447,7 +2349,7 @@ func sameAccountWithLoadGroup(a, b accountWithLoad) bool {
// shuffleWithinPriorityAndLastUsed 对排序后的 []*Account 切片,按 (Priority, LastUsedAt) 分组后组内随机打乱。
//
// 注意:当 preferOAuth=true 时,需要保证 OAuth 账号在同组内仍然优先,否则会把排序时的偏好打散掉。
-// 因此这里采用"组内分区 + 分区内 shuffle"的方式:
+// 因此这里采用“组内分区 + 分区内 shuffle”的方式:
// - 先把同组账号按 (OAuth / 非 OAuth) 拆成两段,保持 OAuth 段在前;
// - 再分别在各段内随机打散,避免热点。
func shuffleWithinPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
@@ -2587,7 +2489,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
}
@@ -2610,10 +2512,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
accountsLoaded = true
- // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2641,12 +2539,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -2697,7 +2589,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
return account, nil
}
}
@@ -2718,10 +2610,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
}
- // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
// 3. 按优先级+最久未用选择(考虑模型支持)
var selected *Account
for i := range accounts {
@@ -2740,12 +2628,6 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -2815,7 +2697,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
if s.debugModelRoutingEnabled() {
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
@@ -2836,10 +2718,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
accountsLoaded = true
- // 提前预取窗口费用+RPM 计数,确保 routing 段内的调度检查调用能命中缓存
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
for _, id := range routingAccountIDs {
if id > 0 {
@@ -2871,12 +2749,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -2927,7 +2799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if clearSticky {
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
}
- if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
+ if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) {
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
return account, nil
}
@@ -2946,10 +2818,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
}
- // 批量预取窗口费用+RPM 计数,避免逐个账号查询(N+1)
- ctx = s.withWindowCostPrefetch(ctx, accounts)
- ctx = s.withRPMPrefetch(ctx, accounts)
-
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
var selected *Account
for i := range accounts {
@@ -2972,12 +2840,6 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) {
continue
}
- if !s.isAccountSchedulableForWindowCost(ctx, acc, false) {
- continue
- }
- if !s.isAccountSchedulableForRPM(ctx, acc, false) {
- continue
- }
if selected == nil {
selected = acc
continue
@@ -5323,7 +5185,7 @@ func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
- // 只对"可能是兼容性差异导致"的 400 允许切换,避免无意义重试。
+ // 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
@@ -6512,7 +6374,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
- logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
+ return fmt.Errorf("create usage log: %w", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
@@ -6521,7 +6383,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
- shouldBill := inserted || err != nil
+ shouldBill := inserted
// 根据计费类型执行扣费
if isSubscriptionBilling {
@@ -6702,7 +6564,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
- logger.LegacyPrintf("service.gateway", "Create usage log failed: %v", err)
+ return fmt.Errorf("create usage log: %w", err)
}
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
@@ -6711,7 +6573,7 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input *
return nil
}
- shouldBill := inserted || err != nil
+ shouldBill := inserted
// 根据计费类型执行扣费
if isSubscriptionBilling {
diff --git a/backend/internal/service/http_upstream_profile.go b/backend/internal/service/http_upstream_profile.go
new file mode 100644
index 000000000..fef1dc191
--- /dev/null
+++ b/backend/internal/service/http_upstream_profile.go
@@ -0,0 +1,41 @@
+package service
+
+import "context"
+
+// HTTPUpstreamProfile 标识上游 HTTP 请求的协议策略分类。
+type HTTPUpstreamProfile string
+
+const (
+ HTTPUpstreamProfileDefault HTTPUpstreamProfile = ""
+ HTTPUpstreamProfileOpenAI HTTPUpstreamProfile = "openai"
+)
+
+type httpUpstreamProfileContextKey struct{}
+
+// WithHTTPUpstreamProfile 在请求上下文中注入上游协议策略分类。
+func WithHTTPUpstreamProfile(ctx context.Context, profile HTTPUpstreamProfile) context.Context {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ if profile == HTTPUpstreamProfileDefault {
+ return ctx
+ }
+ return context.WithValue(ctx, httpUpstreamProfileContextKey{}, profile)
+}
+
+// HTTPUpstreamProfileFromContext 从请求上下文中解析上游协议策略分类。
+func HTTPUpstreamProfileFromContext(ctx context.Context) HTTPUpstreamProfile {
+ if ctx == nil {
+ return HTTPUpstreamProfileDefault
+ }
+ profile, ok := ctx.Value(httpUpstreamProfileContextKey{}).(HTTPUpstreamProfile)
+ if !ok {
+ return HTTPUpstreamProfileDefault
+ }
+ switch profile {
+ case HTTPUpstreamProfileOpenAI:
+ return profile
+ default:
+ return HTTPUpstreamProfileDefault
+ }
+}
diff --git a/backend/internal/service/http_upstream_profile_test.go b/backend/internal/service/http_upstream_profile_test.go
new file mode 100644
index 000000000..446bf93c4
--- /dev/null
+++ b/backend/internal/service/http_upstream_profile_test.go
@@ -0,0 +1,39 @@
+package service
+
+import (
+ "context"
+ "testing"
+)
+
+func TestWithHTTPUpstreamProfile_DefaultKeepsContext(t *testing.T) {
+ ctx := context.Background()
+ got := WithHTTPUpstreamProfile(ctx, HTTPUpstreamProfileDefault)
+ if got != ctx {
+ t.Fatalf("expected default profile to keep original context")
+ }
+}
+
+func TestWithHTTPUpstreamProfile_NilContextCreatesBackground(t *testing.T) {
+ ctx := WithHTTPUpstreamProfile(nil, HTTPUpstreamProfileOpenAI)
+ if ctx == nil {
+ t.Fatalf("expected non-nil context")
+ }
+ if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileOpenAI {
+ t.Fatalf("expected profile %q, got %q", HTTPUpstreamProfileOpenAI, profile)
+ }
+}
+
+func TestHTTPUpstreamProfileFromContext_UnknownValueFallsBackDefault(t *testing.T) {
+ type badKey struct{}
+ ctx := context.WithValue(context.Background(), httpUpstreamProfileContextKey{}, HTTPUpstreamProfile("unknown"))
+ ctx = context.WithValue(ctx, badKey{}, "x")
+ if profile := HTTPUpstreamProfileFromContext(ctx); profile != HTTPUpstreamProfileDefault {
+ t.Fatalf("expected default profile, got %q", profile)
+ }
+}
+
+func TestHTTPUpstreamProfileFromContext_NilContext(t *testing.T) {
+ if profile := HTTPUpstreamProfileFromContext(nil); profile != HTTPUpstreamProfileDefault {
+ t.Fatalf("expected default profile, got %q", profile)
+ }
+}
diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go
index 99013ce55..de2000bab 100644
--- a/backend/internal/service/openai_account_scheduler.go
+++ b/backend/internal/service/openai_account_scheduler.go
@@ -43,34 +43,42 @@ type OpenAIAccountScheduleDecision struct {
}
type OpenAIAccountSchedulerMetricsSnapshot struct {
- SelectTotal int64
- StickyPreviousHitTotal int64
- StickySessionHitTotal int64
- LoadBalanceSelectTotal int64
- AccountSwitchTotal int64
- SchedulerLatencyMsTotal int64
- SchedulerLatencyMsAvg float64
- StickyHitRatio float64
- AccountSwitchRate float64
- LoadSkewAvg float64
- RuntimeStatsAccountCount int
+ SelectTotal int64
+ StickyPreviousHitTotal int64
+ StickySessionHitTotal int64
+ LoadBalanceSelectTotal int64
+ AccountSwitchTotal int64
+ SchedulerLatencyMsTotal int64
+ SchedulerLatencyMsAvg float64
+ StickyHitRatio float64
+ AccountSwitchRate float64
+ LoadSkewAvg float64
+ RuntimeStatsAccountCount int
+ CircuitBreakerOpenTotal int64
+ CircuitBreakerRecoverTotal int64
+ StickyReleaseErrorTotal int64
+ StickyReleaseCircuitOpenTotal int64
}
type OpenAIAccountScheduler interface {
Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error)
- ReportResult(accountID int64, success bool, firstTokenMs *int)
+ ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64)
ReportSwitch()
SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot
}
type openAIAccountSchedulerMetrics struct {
- selectTotal atomic.Int64
- stickyPreviousHitTotal atomic.Int64
- stickySessionHitTotal atomic.Int64
- loadBalanceSelectTotal atomic.Int64
- accountSwitchTotal atomic.Int64
- latencyMsTotal atomic.Int64
- loadSkewMilliTotal atomic.Int64
+ selectTotal atomic.Int64
+ stickyPreviousHitTotal atomic.Int64
+ stickySessionHitTotal atomic.Int64
+ loadBalanceSelectTotal atomic.Int64
+ accountSwitchTotal atomic.Int64
+ latencyMsTotal atomic.Int64
+ loadSkewMilliTotal atomic.Int64
+ circuitBreakerOpenTotal atomic.Int64
+ circuitBreakerRecoverTotal atomic.Int64
+ stickyReleaseErrorTotal atomic.Int64
+ stickyReleaseCircuitOpenTotal atomic.Int64
}
func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) {
@@ -99,19 +107,439 @@ func (m *openAIAccountSchedulerMetrics) recordSwitch() {
}
type openAIAccountRuntimeStats struct {
- accounts sync.Map
- accountCount atomic.Int64
+ accounts sync.Map
+ circuitBreakers sync.Map // accountID → *accountCircuitBreaker
+ accountCount atomic.Int64
+ cleanupCounter atomic.Int64 // report call counter for periodic cleanup
+}
+
+// ---------------------------------------------------------------------------
+// Account-level Circuit Breaker (three-state: CLOSED → OPEN → HALF_OPEN)
+// ---------------------------------------------------------------------------
+
+const (
+ circuitBreakerStateClosed int32 = 0
+ circuitBreakerStateOpen int32 = 1
+ circuitBreakerStateHalfOpen int32 = 2
+
+ // Defaults (used when config values are zero/unset)
+ defaultCircuitBreakerFailThreshold = 5
+ defaultCircuitBreakerCooldownSec = 30
+ defaultCircuitBreakerHalfOpenMax = 2
+)
+
+type accountCircuitBreaker struct {
+ state atomic.Int32 // circuitBreakerState*
+ consecutiveFails atomic.Int32
+ lastFailureNano atomic.Int64 // time.Now().UnixNano()
+ halfOpenInFlight atomic.Int32 // current in-flight probes (decremented by release)
+ halfOpenAdmitted atomic.Int32 // total probes admitted this half-open cycle (never decremented by release)
+ halfOpenSuccess atomic.Int32
+}
+
+// allow returns true if the circuit breaker allows a request to pass through.
+func (cb *accountCircuitBreaker) allow(cooldown time.Duration, halfOpenMax int) bool {
+ switch cb.state.Load() {
+ case circuitBreakerStateClosed:
+ return true
+ case circuitBreakerStateOpen:
+ lastFail := time.Unix(0, cb.lastFailureNano.Load())
+ if time.Since(lastFail) <= cooldown {
+ return false
+ }
+ // Cooldown elapsed — attempt transition to HALF_OPEN.
+ // Reset counters before CAS to avoid a window where another goroutine
+ // sees HALF_OPEN but stale counter values.
+ cb.halfOpenInFlight.Store(0)
+ cb.halfOpenAdmitted.Store(0)
+ cb.halfOpenSuccess.Store(0)
+ cb.state.CompareAndSwap(circuitBreakerStateOpen, circuitBreakerStateHalfOpen)
+ // Either we transitioned or another goroutine did; fall through to
+ // HALF_OPEN gate below.
+ return cb.allowHalfOpen(halfOpenMax)
+ case circuitBreakerStateHalfOpen:
+ return cb.allowHalfOpen(halfOpenMax)
+ default:
+ return true
+ }
+}
+
+func (cb *accountCircuitBreaker) isHalfOpen() bool {
+ if cb == nil {
+ return false
+ }
+ return cb.state.Load() == circuitBreakerStateHalfOpen
+}
+
+// releaseHalfOpenPermit releases one HALF_OPEN probe permit when a candidate
+// passed filtering but was not actually selected to execute a request.
+func (cb *accountCircuitBreaker) releaseHalfOpenPermit() {
+ if cb == nil || cb.state.Load() != circuitBreakerStateHalfOpen {
+ return
+ }
+ for {
+ cur := cb.halfOpenInFlight.Load()
+ if cur <= 0 {
+ return
+ }
+ if cb.halfOpenInFlight.CompareAndSwap(cur, cur-1) {
+ return
+ }
+ }
+}
+
+func (cb *accountCircuitBreaker) allowHalfOpen(halfOpenMax int) bool {
+ for {
+ cur := cb.halfOpenInFlight.Load()
+ if int(cur) >= halfOpenMax {
+ return false
+ }
+ if cb.halfOpenInFlight.CompareAndSwap(cur, cur+1) {
+ cb.halfOpenAdmitted.Add(1)
+ return true
+ }
+ }
+}
+
+// recordSuccess is called when a request succeeds.
+func (cb *accountCircuitBreaker) recordSuccess() {
+ cb.consecutiveFails.Store(0)
+ if cb.state.Load() == circuitBreakerStateHalfOpen {
+ newSucc := cb.halfOpenSuccess.Add(1)
+ // Compare against halfOpenAdmitted (total probes ever admitted in
+ // this half-open cycle). Unlike halfOpenInFlight, this is never
+ // decremented by releaseHalfOpenPermit, so the recovery threshold
+ // remains stable regardless of candidate filtering outcomes.
+ admitted := cb.halfOpenAdmitted.Load()
+ if newSucc >= admitted && admitted > 0 {
+ if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateClosed) {
+ cb.halfOpenInFlight.Store(0)
+ cb.halfOpenAdmitted.Store(0)
+ cb.halfOpenSuccess.Store(0)
+ }
+ }
+ }
+}
+
+// recordFailure is called when a request fails.
+func (cb *accountCircuitBreaker) recordFailure(threshold int) {
+ cb.lastFailureNano.Store(time.Now().UnixNano())
+ newFails := cb.consecutiveFails.Add(1)
+
+ switch cb.state.Load() {
+ case circuitBreakerStateClosed:
+ if int(newFails) >= threshold {
+ cb.state.CompareAndSwap(circuitBreakerStateClosed, circuitBreakerStateOpen)
+ }
+ case circuitBreakerStateHalfOpen:
+ if cb.state.CompareAndSwap(circuitBreakerStateHalfOpen, circuitBreakerStateOpen) {
+ cb.halfOpenInFlight.Store(0)
+ cb.halfOpenAdmitted.Store(0)
+ cb.halfOpenSuccess.Store(0)
+ }
+ }
+}
+
+// isOpen returns true if the circuit breaker is currently in OPEN state.
+func (cb *accountCircuitBreaker) isOpen() bool {
+ return cb.state.Load() == circuitBreakerStateOpen
+}
+
+// stateString returns a human-readable state name.
+func (cb *accountCircuitBreaker) stateString() string {
+ switch cb.state.Load() {
+ case circuitBreakerStateClosed:
+ return "CLOSED"
+ case circuitBreakerStateOpen:
+ return "OPEN"
+ case circuitBreakerStateHalfOpen:
+ return "HALF_OPEN"
+ default:
+ return "UNKNOWN"
+ }
+}
+
+// loadCircuitBreaker returns the CB for accountID if it exists, or nil.
+// Use this on hot paths (e.g. candidate filtering) to avoid allocating CB
+// objects for accounts that have never received a report.
+func (s *openAIAccountRuntimeStats) loadCircuitBreaker(accountID int64) *accountCircuitBreaker {
+ if val, ok := s.circuitBreakers.Load(accountID); ok {
+ if cb, _ := val.(*accountCircuitBreaker); cb != nil {
+ return cb
+ }
+ }
+ return nil
+}
+
+func (s *openAIAccountRuntimeStats) getCircuitBreaker(accountID int64) *accountCircuitBreaker {
+ if val, ok := s.circuitBreakers.Load(accountID); ok {
+ if cb, _ := val.(*accountCircuitBreaker); cb != nil {
+ return cb
+ }
+ }
+ cb := &accountCircuitBreaker{}
+ actual, _ := s.circuitBreakers.LoadOrStore(accountID, cb)
+ if existing, _ := actual.(*accountCircuitBreaker); existing != nil {
+ return existing
+ }
+ return cb
+}
+
+func (s *openAIAccountRuntimeStats) isCircuitOpen(accountID int64) bool {
+ val, ok := s.circuitBreakers.Load(accountID)
+ if !ok {
+ return false
+ }
+ cb, _ := val.(*accountCircuitBreaker)
+ if cb == nil {
+ return false
+ }
+ return cb.isOpen()
+}
+
+// ---------------------------------------------------------------------------
+// Dual-EWMA: fast (α=0.5) reacts quickly to degradation; slow (α=0.1)
+// stabilises over many samples. The pessimistic envelope max(fast,slow) means
+// we *sense* errors fast but only *confirm* recovery slowly.
+// ---------------------------------------------------------------------------
+
+const (
+ dualEWMAAlphaFast = 0.5
+ dualEWMAAlphaSlow = 0.1
+
+ // Per-model TTFT defaults
+ defaultPerModelTTFTMaxModels = 100
+ defaultPerModelTTFTTTL = 30 * time.Minute
+)
+
+// dualEWMA tracks a [0,1] signal (e.g. error rate) with two speeds.
+type dualEWMA struct {
+ fastBits atomic.Uint64 // α = dualEWMAAlphaFast, reacts in ~3 requests
+ slowBits atomic.Uint64 // α = dualEWMAAlphaSlow, stabilises over ~50 requests
+ sampleCount atomic.Int64 // total samples received; used for cold-start guard
+}
+
+// dualEWMAMinSamples is the minimum number of samples required before the
+// EWMA error rate is considered reliable for decision-making (e.g. sticky
+// release). This prevents a single failure on a fresh account from yielding
+// an artificially high error rate.
+const dualEWMAMinSamples = 5
+
+func (d *dualEWMA) update(sample float64) {
+ updateEWMAAtomic(&d.fastBits, sample, dualEWMAAlphaFast)
+ updateEWMAAtomic(&d.slowBits, sample, dualEWMAAlphaSlow)
+ d.sampleCount.Add(1)
+}
+
+// isWarmedUp returns true when enough samples have been collected for the
+// EWMA value to be meaningful.
+func (d *dualEWMA) isWarmedUp() bool {
+ return d.sampleCount.Load() >= dualEWMAMinSamples
+}
+
+// value returns the pessimistic envelope: max(fast, slow).
+func (d *dualEWMA) value() float64 {
+ fast := math.Float64frombits(d.fastBits.Load())
+ slow := math.Float64frombits(d.slowBits.Load())
+ if fast >= slow {
+ return fast
+ }
+ return slow
+}
+
+func (d *dualEWMA) fastValue() float64 {
+ return math.Float64frombits(d.fastBits.Load())
+}
+
+func (d *dualEWMA) slowValue() float64 {
+ return math.Float64frombits(d.slowBits.Load())
+}
+
+// dualEWMATTFT is like dualEWMA but handles the NaN-initialised first-sample
+// case required by TTFT tracking.
+type dualEWMATTFT struct {
+ fastBits atomic.Uint64 // α = dualEWMAAlphaFast
+ slowBits atomic.Uint64 // α = dualEWMAAlphaSlow
+}
+
+// initNaN stores NaN into both channels. Called once at allocation time.
+func (d *dualEWMATTFT) initNaN() {
+ nanBits := math.Float64bits(math.NaN())
+ d.fastBits.Store(nanBits)
+ d.slowBits.Store(nanBits)
+}
+
+func (d *dualEWMATTFT) update(sample float64) {
+ sampleBits := math.Float64bits(sample)
+ // fast channel
+ for {
+ oldBits := d.fastBits.Load()
+ oldValue := math.Float64frombits(oldBits)
+ if math.IsNaN(oldValue) {
+ if d.fastBits.CompareAndSwap(oldBits, sampleBits) {
+ break
+ }
+ continue
+ }
+ newValue := dualEWMAAlphaFast*sample + (1-dualEWMAAlphaFast)*oldValue
+ if d.fastBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
+ break
+ }
+ }
+ // slow channel
+ for {
+ oldBits := d.slowBits.Load()
+ oldValue := math.Float64frombits(oldBits)
+ if math.IsNaN(oldValue) {
+ if d.slowBits.CompareAndSwap(oldBits, sampleBits) {
+ break
+ }
+ continue
+ }
+ newValue := dualEWMAAlphaSlow*sample + (1-dualEWMAAlphaSlow)*oldValue
+ if d.slowBits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
+ break
+ }
+ }
+}
+
+// value returns (pessimistic TTFT, hasTTFT). If both channels are still NaN
+// the caller gets (0, false).
+func (d *dualEWMATTFT) value() (float64, bool) {
+ fast := math.Float64frombits(d.fastBits.Load())
+ slow := math.Float64frombits(d.slowBits.Load())
+ fastOK := !math.IsNaN(fast)
+ slowOK := !math.IsNaN(slow)
+ switch {
+ case fastOK && slowOK:
+ if fast >= slow {
+ return fast, true
+ }
+ return slow, true
+ case fastOK:
+ return fast, true
+ case slowOK:
+ return slow, true
+ default:
+ return 0, false
+ }
+}
+
+func (d *dualEWMATTFT) fastValue() float64 {
+ return math.Float64frombits(d.fastBits.Load())
+}
+
+func (d *dualEWMATTFT) slowValue() float64 {
+ return math.Float64frombits(d.slowBits.Load())
+}
+
+// ---------------------------------------------------------------------------
+// Load Trend Tracker (ring-buffer linear regression)
+// ---------------------------------------------------------------------------
+
+const loadTrendRingSize = 10
+
+// loadTrendTracker maintains a fixed-size ring buffer of (timestamp, loadRate)
+// samples and computes the ordinary-least-squares slope to predict whether
+// an account's load is rising, falling, or stable.
+type loadTrendTracker struct {
+ mu sync.Mutex
+ samples [loadTrendRingSize]float64 // ring buffer of loadRate values
+ times [loadTrendRingSize]int64 // timestamps in UnixNano
+ head int // next write position
+ count int // number of valid samples (0..loadTrendRingSize)
+}
+
+// record pushes a loadRate sample with the current wall-clock timestamp.
+func (t *loadTrendTracker) record(loadRate float64) {
+ t.recordAt(loadRate, time.Now().UnixNano())
+}
+
+// recordAt pushes a loadRate sample with an explicit timestamp (for testing).
+func (t *loadTrendTracker) recordAt(loadRate float64, tsNano int64) {
+ t.mu.Lock()
+ t.samples[t.head] = loadRate
+ t.times[t.head] = tsNano
+ t.head = (t.head + 1) % loadTrendRingSize
+ if t.count < loadTrendRingSize {
+ t.count++
+ }
+ t.mu.Unlock()
+}
+
+// slope computes the simple linear regression slope of loadRate over time.
+//
+// slope = (N*Sigma(xi*yi) - Sigma(xi)*Sigma(yi)) / (N*Sigma(xi^2) - (Sigma(xi))^2)
+//
+// where xi = seconds elapsed since the oldest sample, yi = loadRate.
+// Returns 0 if fewer than 2 samples are available or if all timestamps are
+// identical (degenerate case).
+func (t *loadTrendTracker) slope() float64 {
+ t.mu.Lock()
+ n := t.count
+ if n < 2 {
+ t.mu.Unlock()
+ return 0
+ }
+
+ // Copy data under lock; computation happens outside.
+ var localSamples [loadTrendRingSize]float64
+ var localTimes [loadTrendRingSize]int64
+ copy(localSamples[:], t.samples[:])
+ copy(localTimes[:], t.times[:])
+ head := t.head
+ t.mu.Unlock()
+
+ // Identify oldest entry index.
+ oldest := head // head points to the next write pos; for a full ring it's the oldest entry.
+ if n < loadTrendRingSize {
+ oldest = 0
+ }
+ t0 := localTimes[oldest]
+
+ var sumX, sumY, sumXY, sumX2 float64
+ for i := 0; i < n; i++ {
+ idx := (oldest + i) % loadTrendRingSize
+ xi := float64(localTimes[idx]-t0) / 1e9 // relative seconds
+ yi := localSamples[idx]
+ sumX += xi
+ sumY += yi
+ sumXY += xi * yi
+ sumX2 += xi * xi
+ }
+
+ nf := float64(n)
+ denom := nf*sumX2 - sumX*sumX
+ if denom == 0 {
+ // All timestamps identical (or single sample) — no meaningful slope.
+ return 0
+ }
+ return (nf*sumXY - sumX*sumY) / denom
}
type openAIAccountRuntimeStat struct {
- errorRateEWMABits atomic.Uint64
- ttftEWMABits atomic.Uint64
+ errorRate dualEWMA
+ ttft dualEWMATTFT
+ modelTTFT sync.Map // key = model name (string), value = *dualEWMATTFT
+ modelTTFTLastUpdate sync.Map // key = model name (string), value = int64 (unix nano)
+ loadTrend loadTrendTracker
+ lastReportNano atomic.Int64 // last report timestamp for GC
}
func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats {
return &openAIAccountRuntimeStats{}
}
+// loadExisting returns the stat for accountID if it exists, or nil.
+// Unlike loadOrCreate, this never allocates a new stat.
+func (s *openAIAccountRuntimeStats) loadExisting(accountID int64) *openAIAccountRuntimeStat {
+ if value, ok := s.accounts.Load(accountID); ok {
+ stat, _ := value.(*openAIAccountRuntimeStat)
+ return stat
+ }
+ return nil
+}
+
func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat {
if value, ok := s.accounts.Load(accountID); ok {
stat, _ := value.(*openAIAccountRuntimeStat)
@@ -121,7 +549,7 @@ func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccount
}
stat := &openAIAccountRuntimeStat{}
- stat.ttftEWMABits.Store(math.Float64bits(math.NaN()))
+ stat.ttft.initNaN()
actual, loaded := s.accounts.LoadOrStore(accountID, stat)
if !loaded {
s.accountCount.Add(1)
@@ -134,6 +562,103 @@ func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccount
return stat
}
+// getOrCreateModelTTFT returns the per-model TTFT tracker, creating it if
+// it does not exist yet. Uses the LoadOrStore pattern for thread safety.
+func (stat *openAIAccountRuntimeStat) getOrCreateModelTTFT(model string) *dualEWMATTFT {
+ if val, ok := stat.modelTTFT.Load(model); ok {
+ if d, _ := val.(*dualEWMATTFT); d != nil {
+ return d
+ }
+ }
+ d := &dualEWMATTFT{}
+ d.initNaN()
+ actual, _ := stat.modelTTFT.LoadOrStore(model, d)
+ if existing, _ := actual.(*dualEWMATTFT); existing != nil {
+ return existing
+ }
+ return d
+}
+
+// reportModelTTFT updates both the per-model and global TTFT trackers.
+func (stat *openAIAccountRuntimeStat) reportModelTTFT(model string, sampleMs float64) {
+ if model == "" || sampleMs <= 0 {
+ return
+ }
+ d := stat.getOrCreateModelTTFT(model)
+ d.update(sampleMs)
+ stat.modelTTFTLastUpdate.Store(model, time.Now().UnixNano())
+ // Also update the global TTFT so that callers without a model still
+ // see a reasonable aggregate.
+ stat.ttft.update(sampleMs)
+}
+
+// modelTTFTValue returns the per-model TTFT value if a tracker exists and has
+// received at least one sample. Otherwise returns (0, false).
+func (stat *openAIAccountRuntimeStat) modelTTFTValue(model string) (float64, bool) {
+ if model == "" {
+ return 0, false
+ }
+ val, ok := stat.modelTTFT.Load(model)
+ if !ok {
+ return 0, false
+ }
+ d, _ := val.(*dualEWMATTFT)
+ if d == nil {
+ return 0, false
+ }
+ return d.value()
+}
+
+// cleanupStaleTTFT removes per-model TTFT entries that have not been updated
+// within ttl, and enforces a hard cap of maxModels entries. Oldest entries
+// are evicted first when the cap is exceeded.
+func (stat *openAIAccountRuntimeStat) cleanupStaleTTFT(ttl time.Duration, maxModels int) {
+ now := time.Now().UnixNano()
+ cutoff := now - int64(ttl)
+
+ // First pass: delete stale entries.
+ stat.modelTTFTLastUpdate.Range(func(key, value any) bool {
+ model, _ := key.(string)
+ ts, _ := value.(int64)
+ if ts < cutoff {
+ stat.modelTTFT.Delete(model)
+ stat.modelTTFTLastUpdate.Delete(model)
+ }
+ return true
+ })
+
+ if maxModels <= 0 {
+ return
+ }
+
+ // Second pass: count remaining entries and evict oldest if over limit.
+ type modelEntry struct {
+ model string
+ ts int64
+ }
+ var entries []modelEntry
+ stat.modelTTFTLastUpdate.Range(func(key, value any) bool {
+ model, _ := key.(string)
+ ts, _ := value.(int64)
+ entries = append(entries, modelEntry{model: model, ts: ts})
+ return true
+ })
+
+ if len(entries) <= maxModels {
+ return
+ }
+
+ // Sort by timestamp ascending (oldest first) and evict surplus.
+ sort.Slice(entries, func(i, j int) bool {
+ return entries[i].ts < entries[j].ts
+ })
+ evictCount := len(entries) - maxModels
+ for i := 0; i < evictCount; i++ {
+ stat.modelTTFT.Delete(entries[i].model)
+ stat.modelTTFTLastUpdate.Delete(entries[i].model)
+ }
+}
+
func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
for {
oldBits := target.Load()
@@ -145,40 +670,77 @@ func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) {
}
}
-func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) {
+func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
+ s.reportWithOptions(
+ accountID,
+ success,
+ firstTokenMs,
+ defaultCircuitBreakerFailThreshold,
+ true,
+ model,
+ ttftMs,
+ true,
+ defaultPerModelTTFTMaxModels,
+ )
+}
+
+func (s *openAIAccountRuntimeStats) reportWithOptions(
+ accountID int64,
+ success bool,
+ firstTokenMs *int,
+ cbThreshold int,
+ updateCircuitBreaker bool,
+ model string,
+ ttftMs float64,
+ perModelTTFTEnabled bool,
+ perModelTTFTMaxModels int,
+) {
if s == nil || accountID <= 0 {
return
}
- const alpha = 0.2
stat := s.loadOrCreate(accountID)
+ stat.lastReportNano.Store(time.Now().UnixNano())
errorSample := 1.0
if success {
errorSample = 0.0
}
- updateEWMAAtomic(&stat.errorRateEWMABits, errorSample, alpha)
+ stat.errorRate.update(errorSample)
- if firstTokenMs != nil && *firstTokenMs > 0 {
- ttft := float64(*firstTokenMs)
- ttftBits := math.Float64bits(ttft)
- for {
- oldBits := stat.ttftEWMABits.Load()
- oldValue := math.Float64frombits(oldBits)
- if math.IsNaN(oldValue) {
- if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) {
- break
- }
- continue
- }
- newValue := alpha*ttft + (1-alpha)*oldValue
- if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) {
- break
- }
+ // Per-model TTFT tracking: reportModelTTFT updates both per-model and
+ // global TTFT, so skip the separate global update to avoid double-counting.
+ if perModelTTFTEnabled && model != "" && ttftMs > 0 {
+ stat.reportModelTTFT(model, ttftMs)
+ } else if firstTokenMs != nil && *firstTokenMs > 0 {
+ stat.ttft.update(float64(*firstTokenMs))
+ }
+
+ // Update circuit breaker state only when feature is enabled.
+ if updateCircuitBreaker {
+ cb := s.getCircuitBreaker(accountID)
+ if success {
+ cb.recordSuccess()
+ } else {
+ cb.recordFailure(cbThreshold)
+ }
+ }
+
+ // Periodic cleanup: every 100 reports.
+ cnt := s.cleanupCounter.Add(1)
+ if cnt%100 == 0 {
+ maxModels := defaultPerModelTTFTMaxModels
+ if perModelTTFTMaxModels > 0 {
+ maxModels = perModelTTFTMaxModels
}
+ stat.cleanupStaleTTFT(defaultPerModelTTFTTTL, maxModels)
+ }
+ // GC inactive accounts and orphaned circuit breakers: every 1000 reports.
+ if cnt%1000 == 0 {
+ s.gcInactiveAccounts(time.Hour)
}
}
-func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) {
+func (s *openAIAccountRuntimeStats) snapshot(accountID int64, model ...string) (errorRate float64, ttft float64, hasTTFT bool) {
if s == nil || accountID <= 0 {
return 0, 0, false
}
@@ -190,12 +752,17 @@ func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64
if stat == nil {
return 0, 0, false
}
- errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load()))
- ttftValue := math.Float64frombits(stat.ttftEWMABits.Load())
- if math.IsNaN(ttftValue) {
- return errorRate, 0, false
+ errorRate = clamp01(stat.errorRate.value())
+
+ // Try per-model TTFT first; fallback to global.
+ if len(model) > 0 && model[0] != "" {
+ if mTTFT, mOK := stat.modelTTFTValue(model[0]); mOK {
+ return errorRate, mTTFT, true
+ }
}
- return errorRate, ttftValue, true
+
+ ttft, hasTTFT = stat.ttft.value()
+ return errorRate, ttft, hasTTFT
}
func (s *openAIAccountRuntimeStats) size() int {
@@ -205,6 +772,25 @@ func (s *openAIAccountRuntimeStats) size() int {
return int(s.accountCount.Load())
}
+// gcInactiveAccounts removes account stats and circuit breakers that have not
+// received any report for longer than maxIdle. This prevents unbounded growth
+// of the sync.Maps when accounts are created and then deleted/deactivated.
+func (s *openAIAccountRuntimeStats) gcInactiveAccounts(maxIdle time.Duration) {
+ if s == nil {
+ return
+ }
+ cutoff := time.Now().UnixNano() - int64(maxIdle)
+ s.accounts.Range(func(key, value any) bool {
+ stat, _ := value.(*openAIAccountRuntimeStat)
+ if stat == nil || stat.lastReportNano.Load() < cutoff {
+ s.accounts.Delete(key)
+ s.circuitBreakers.Delete(key)
+ s.accountCount.Add(-1)
+ }
+ return true
+ })
+}
+
type defaultOpenAIAccountScheduler struct {
service *OpenAIGatewayService
metrics openAIAccountSchedulerMetrics
@@ -331,6 +917,12 @@ func (s *defaultOpenAIAccountScheduler) selectBySessionHash(
return nil, nil
}
+ // Conditional sticky: release binding if account is unhealthy or overloaded.
+ if s.shouldReleaseStickySession(accountID) {
+ _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash)
+ return nil, nil // Fall through to load balance
+ }
+
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
if acquireErr == nil && result.Acquired {
_ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL())
@@ -557,6 +1149,175 @@ func buildOpenAIWeightedSelectionOrder(
return order
}
+// selectP2COpenAICandidates selects candidates using Power-of-Two-Choices:
+// randomly pick 2 candidates, return the one with the higher score.
+// Repeat to build a full selection order for fallback.
+func selectP2COpenAICandidates(
+ candidates []openAIAccountCandidateScore,
+ req OpenAIAccountScheduleRequest,
+) []openAIAccountCandidateScore {
+ if len(candidates) <= 1 {
+ return append([]openAIAccountCandidateScore(nil), candidates...)
+ }
+
+ rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
+ pool := append([]openAIAccountCandidateScore(nil), candidates...)
+ order := make([]openAIAccountCandidateScore, 0, len(pool))
+
+ for len(pool) > 1 {
+ n := uint64(len(pool))
+ // Pick first random index.
+ idx1 := int(rng.nextUint64() % n)
+ // Pick second random index, distinct from the first.
+ idx2 := int(rng.nextUint64() % (n - 1))
+ if idx2 >= idx1 {
+ idx2++
+ }
+
+ // Compare: take the candidate with the higher score.
+ winner := idx1
+ if isOpenAIAccountCandidateBetter(pool[idx2], pool[idx1]) {
+ winner = idx2
+ }
+
+ order = append(order, pool[winner])
+ // Remove winner from pool (swap with last element for O(1) removal).
+ pool[winner] = pool[len(pool)-1]
+ pool = pool[:len(pool)-1]
+ }
+ // Append the last remaining candidate.
+ order = append(order, pool[0])
+ return order
+}
+
+// ---------------------------------------------------------------------------
+// Softmax Temperature Sampling
+// ---------------------------------------------------------------------------
+
+const defaultSoftmaxTemperature = 0.3
+
+type softmaxConfig struct {
+ enabled bool
+ temperature float64
+}
+
+// softmaxConfigRead reads softmax scheduler config with fallback defaults.
+func (s *defaultOpenAIAccountScheduler) softmaxConfigRead() softmaxConfig {
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return softmaxConfig{}
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ temp := wsCfg.SchedulerSoftmaxTemperature
+ if temp <= 0 {
+ temp = defaultSoftmaxTemperature
+ }
+ return softmaxConfig{
+ enabled: wsCfg.SchedulerSoftmaxEnabled,
+ temperature: temp,
+ }
+}
+
+// selectSoftmaxOpenAICandidates applies softmax temperature sampling to select
+// one candidate probabilistically, then returns the full list with the selected
+// candidate first and the rest sorted by descending probability.
+//
+// Algorithm (numerically stable):
+//
+// maxScore = max(score[i])
+// weights[i] = exp((score[i] - maxScore) / temperature)
+// probability[i] = weights[i] / sum(weights)
+//
+// A higher temperature yields more uniform selection (exploration); a lower
+// temperature concentrates probability on the highest-scored candidates
+// (exploitation).
+func selectSoftmaxOpenAICandidates(
+ candidates []openAIAccountCandidateScore,
+ temperature float64,
+ rng *openAISelectionRNG,
+) []openAIAccountCandidateScore {
+ if len(candidates) == 0 {
+ return nil
+ }
+ if len(candidates) == 1 {
+ return append([]openAIAccountCandidateScore(nil), candidates...)
+ }
+ if temperature <= 0 {
+ temperature = defaultSoftmaxTemperature
+ }
+
+ // Step 1: find max score for numerical stability.
+ maxScore := candidates[0].score
+ for i := 1; i < len(candidates); i++ {
+ if candidates[i].score > maxScore {
+ maxScore = candidates[i].score
+ }
+ }
+
+ // Step 2: compute softmax weights.
+ type indexedProb struct {
+ index int
+ prob float64
+ }
+ probs := make([]indexedProb, len(candidates))
+ sumWeights := 0.0
+ for i := range candidates {
+ w := math.Exp((candidates[i].score - maxScore) / temperature)
+ // Guard against NaN/Inf from degenerate inputs.
+ if math.IsNaN(w) || math.IsInf(w, 0) {
+ w = 0
+ }
+ probs[i] = indexedProb{index: i, prob: w}
+ sumWeights += w
+ }
+
+ // Normalise to probabilities. If sumWeights is zero (all weights collapsed
+ // to zero, which can happen with extreme negative scores), fall back to
+ // uniform distribution.
+ if sumWeights > 0 {
+ for i := range probs {
+ probs[i].prob /= sumWeights
+ }
+ } else {
+ uniform := 1.0 / float64(len(probs))
+ for i := range probs {
+ probs[i].prob = uniform
+ }
+ }
+
+ // Step 3: sample ONE candidate via CDF.
+ r := rng.nextFloat64()
+ selectedIdx := probs[len(probs)-1].index // default to last if rounding issues
+ cumulative := 0.0
+ for _, ip := range probs {
+ cumulative += ip.prob
+ if cumulative >= r {
+ selectedIdx = ip.index
+ break
+ }
+ }
+
+ // Step 4: build result — selected candidate first, rest sorted by
+ // descending probability.
+ result := make([]openAIAccountCandidateScore, 0, len(candidates))
+ result = append(result, candidates[selectedIdx])
+
+ // Sort remaining by probability descending for fallback order.
+ remaining := make([]indexedProb, 0, len(probs)-1)
+ for _, ip := range probs {
+ if ip.index != selectedIdx {
+ remaining = append(remaining, ip)
+ }
+ }
+ sort.Slice(remaining, func(i, j int) bool {
+ return remaining[i].prob > remaining[j].prob
+ })
+ for _, ip := range remaining {
+ result = append(result, candidates[ip.index])
+ }
+
+ return result
+}
+
func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ctx context.Context,
req OpenAIAccountScheduleRequest,
@@ -597,6 +1358,44 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
return nil, 0, 0, 0, errors.New("no available OpenAI accounts")
}
+ // Circuit breaker filtering: remove accounts with open CBs, but if that
+ // would empty the candidate pool, keep all accounts (graceful degradation).
+ cbEnabled, _, cbCooldown, cbHalfOpenMax := s.schedulerCircuitBreakerConfig()
+ heldHalfOpenPermits := make(map[int64]*accountCircuitBreaker)
+ releaseHalfOpenPermit := func(accountID int64) {
+ cb, ok := heldHalfOpenPermits[accountID]
+ if !ok || cb == nil {
+ return
+ }
+ cb.releaseHalfOpenPermit()
+ delete(heldHalfOpenPermits, accountID)
+ }
+ defer func() {
+ for accountID := range heldHalfOpenPermits {
+ releaseHalfOpenPermit(accountID)
+ }
+ }()
+ if cbEnabled {
+ healthy := make([]*Account, 0, len(filtered))
+ healthyLoadReq := make([]AccountWithConcurrency, 0, len(loadReq))
+ for i, account := range filtered {
+ cb := s.stats.loadCircuitBreaker(account.ID)
+ if cb == nil || cb.allow(cbCooldown, cbHalfOpenMax) {
+ healthy = append(healthy, account)
+ healthyLoadReq = append(healthyLoadReq, loadReq[i])
+ if cb.isHalfOpen() {
+ heldHalfOpenPermits[account.ID] = cb
+ }
+ }
+ }
+ if len(healthy) > 0 {
+ filtered = healthy
+ loadReq = healthyLoadReq
+ }
+ // else: all accounts are circuit-open; fall through with the
+ // original set to avoid returning "no accounts".
+ }
+
loadMap := map[int64]*AccountLoadInfo{}
if s.service.concurrencyService != nil {
if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil {
@@ -604,8 +1403,16 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
}
}
+ trendEnabled, trendMaxSlope := s.service.openAIWSSchedulerTrendConfig()
+ perModelTTFTEnabled, _ := s.schedulerPerModelTTFTConfig()
+ requestedModelForStats := ""
+ if perModelTTFTEnabled {
+ requestedModelForStats = req.RequestedModel
+ }
+
minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority
maxWaiting := 1
+ maxConcurrency := 0
loadRateSum := 0.0
loadRateSumSquares := 0.0
minTTFT, maxTTFT := 0.0, 0.0
@@ -625,7 +1432,10 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if loadInfo.WaitingCount > maxWaiting {
maxWaiting = loadInfo.WaitingCount
}
- errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID)
+ if account.Concurrency > maxConcurrency {
+ maxConcurrency = account.Concurrency
+ }
+ errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID, requestedModelForStats)
if hasTTFT && ttft > 0 {
if !hasTTFTSample {
minTTFT, maxTTFT = ttft, ttft
@@ -642,6 +1452,13 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
loadRate := float64(loadInfo.LoadRate)
loadRateSum += loadRate
loadRateSumSquares += loadRate * loadRate
+
+ // Record current load rate sample for trend tracking.
+ if trendEnabled {
+ stat := s.stats.loadOrCreate(account.ID)
+ stat.loadTrend.record(loadRate)
+ }
+
candidates = append(candidates, openAIAccountCandidateScore{
account: account,
loadInfo: loadInfo,
@@ -659,8 +1476,33 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
if maxPriority > minPriority {
priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority)
}
+ // Base load factor from percentage utilization.
loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0)
+ // Capacity-aware adjustment: accounts with more absolute headroom get a bonus.
+ if maxConcurrency > 0 && item.account.Concurrency > 0 {
+ remainingSlots := float64(item.account.Concurrency) * (1 - float64(item.loadInfo.LoadRate)/100.0)
+ capacityBonus := clamp01(remainingSlots / float64(maxConcurrency))
+ // Blend: 70% relative load + 30% capacity bonus
+ loadFactor = 0.7*loadFactor + 0.3*capacityBonus
+ }
+
+ // Trend adjustment: penalise accounts whose load is rising, reward those declining.
+ // trendAdj ranges [0, 1] where 0 = max rising slope, 1 = max falling/flat slope.
+ // loadFactor is blended: 70% base load + 30% trend influence.
+ if trendEnabled {
+ stat := s.stats.loadOrCreate(item.account.ID)
+ slope := stat.loadTrend.slope()
+ trendAdj := 1.0 - clamp01(slope/trendMaxSlope)
+ loadFactor *= (0.7 + 0.3*trendAdj)
+ }
+
queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting))
+ // Queue depth relative to account's own capacity for capacity-aware blending.
+ if item.account.Concurrency > 0 {
+ relativeQueue := clamp01(float64(item.loadInfo.WaitingCount) / float64(item.account.Concurrency))
+ // Blend: 60% cross-account normalized + 40% self-relative
+ queueFactor = 0.6*queueFactor + 0.4*(1-relativeQueue)
+ }
errorFactor := 1 - clamp01(item.errorRate)
ttftFactor := 0.5
if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT {
@@ -674,23 +1516,40 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
weights.TTFT*ttftFactor
}
- topK := s.service.openAIWSLBTopK()
- if topK > len(candidates) {
- topK = len(candidates)
- }
- if topK <= 0 {
- topK = 1
+ var selectionOrder []openAIAccountCandidateScore
+ topK := 0
+ rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req))
+ smCfg := s.softmaxConfigRead()
+ p2cEnabled := s.service.openAIWSSchedulerP2CEnabled()
+ if smCfg.enabled && len(candidates) > 3 {
+ selectionOrder = selectSoftmaxOpenAICandidates(candidates, smCfg.temperature, &rng)
+ // topK = 0 signals softmax mode in metrics / decision struct.
+ } else if p2cEnabled {
+ selectionOrder = selectP2COpenAICandidates(candidates, req)
+ // topK = 0 signals P2C mode in metrics / decision struct.
+ } else {
+ topK = s.service.openAIWSLBTopK()
+ if topK > len(candidates) {
+ topK = len(candidates)
+ }
+ if topK <= 0 {
+ topK = 1
+ }
+ rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
+ selectionOrder = buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
}
- rankedCandidates := selectTopKOpenAICandidates(candidates, topK)
- selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req)
for i := 0; i < len(selectionOrder); i++ {
candidate := selectionOrder[i]
result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency)
if acquireErr != nil {
+ releaseHalfOpenPermit(candidate.account.ID)
return nil, len(candidates), topK, loadSkew, acquireErr
}
if result != nil && result.Acquired {
+ // Keep HALF_OPEN permit for the selected account; the outcome will be
+ // settled by ReportResult(success/failure) after the request finishes.
+ delete(heldHalfOpenPermits, candidate.account.ID)
if req.SessionHash != "" {
_ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID)
}
@@ -700,10 +1559,12 @@ func (s *defaultOpenAIAccountScheduler) selectByLoadBalance(
ReleaseFunc: result.ReleaseFunc,
}, len(candidates), topK, loadSkew, nil
}
+ releaseHalfOpenPermit(candidate.account.ID)
}
cfg := s.service.schedulingConfig()
candidate := selectionOrder[0]
+ releaseHalfOpenPermit(candidate.account.ID)
return &AccountSelectionResult{
Account: candidate.account,
WaitPlan: &AccountWaitPlan{
@@ -726,11 +1587,54 @@ func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Ac
return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport
}
-func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) {
+func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
if s == nil || s.stats == nil {
return
}
- s.stats.report(accountID, success, firstTokenMs)
+ perModelTTFTEnabled, perModelTTFTMaxModels := s.schedulerPerModelTTFTConfig()
+ enabled, threshold, _, _ := s.schedulerCircuitBreakerConfig()
+ if !enabled {
+ // Circuit breaker disabled: only update runtime signals (error-rate/TTFT),
+ // do not mutate circuit breaker state.
+ s.stats.reportWithOptions(
+ accountID,
+ success,
+ firstTokenMs,
+ 0,
+ false,
+ model,
+ ttftMs,
+ perModelTTFTEnabled,
+ perModelTTFTMaxModels,
+ )
+ return
+ }
+
+ // Snapshot state before the update for metrics tracking.
+ cb := s.stats.getCircuitBreaker(accountID)
+ stateBefore := cb.state.Load()
+
+ s.stats.reportWithOptions(
+ accountID,
+ success,
+ firstTokenMs,
+ threshold,
+ true,
+ model,
+ ttftMs,
+ perModelTTFTEnabled,
+ perModelTTFTMaxModels,
+ )
+
+ stateAfter := cb.state.Load()
+ // CLOSED/HALF_OPEN → OPEN: circuit tripped.
+ if stateBefore != circuitBreakerStateOpen && stateAfter == circuitBreakerStateOpen {
+ s.metrics.circuitBreakerOpenTotal.Add(1)
+ }
+ // OPEN/HALF_OPEN → CLOSED: circuit recovered.
+ if stateBefore != circuitBreakerStateClosed && stateAfter == circuitBreakerStateClosed {
+ s.metrics.circuitBreakerRecoverTotal.Add(1)
+ }
}
func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
@@ -740,6 +1644,107 @@ func (s *defaultOpenAIAccountScheduler) ReportSwitch() {
s.metrics.recordSwitch()
}
+// schedulerCircuitBreakerConfig reads CB config with fallback defaults.
+func (s *defaultOpenAIAccountScheduler) schedulerCircuitBreakerConfig() (enabled bool, threshold int, cooldown time.Duration, halfOpenMax int) {
+ threshold = defaultCircuitBreakerFailThreshold
+ cooldown = time.Duration(defaultCircuitBreakerCooldownSec) * time.Second
+ halfOpenMax = defaultCircuitBreakerHalfOpenMax
+
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return false, threshold, cooldown, halfOpenMax
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ enabled = wsCfg.SchedulerCircuitBreakerEnabled
+ if wsCfg.SchedulerCircuitBreakerFailThreshold > 0 {
+ threshold = wsCfg.SchedulerCircuitBreakerFailThreshold
+ }
+ if wsCfg.SchedulerCircuitBreakerCooldownSec > 0 {
+ cooldown = time.Duration(wsCfg.SchedulerCircuitBreakerCooldownSec) * time.Second
+ }
+ if wsCfg.SchedulerCircuitBreakerHalfOpenMax > 0 {
+ halfOpenMax = wsCfg.SchedulerCircuitBreakerHalfOpenMax
+ }
+ return enabled, threshold, cooldown, halfOpenMax
+}
+
+func (s *defaultOpenAIAccountScheduler) schedulerPerModelTTFTConfig() (enabled bool, maxModels int) {
+ maxModels = defaultPerModelTTFTMaxModels
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return false, maxModels
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ enabled = wsCfg.SchedulerPerModelTTFTEnabled
+ if wsCfg.SchedulerPerModelTTFTMaxModels > 0 {
+ maxModels = wsCfg.SchedulerPerModelTTFTMaxModels
+ }
+ return enabled, maxModels
+}
+
+// ---------------------------------------------------------------------------
+// Conditional Sticky Session Release
+// ---------------------------------------------------------------------------
+
+const defaultStickyReleaseErrorThreshold = 0.3
+
+type stickyReleaseConfig struct {
+ enabled bool
+ errorThreshold float64
+}
+
+// stickyReleaseConfigRead reads conditional sticky release config with defaults.
+func (s *defaultOpenAIAccountScheduler) stickyReleaseConfigRead() stickyReleaseConfig {
+ if s == nil || s.service == nil || s.service.cfg == nil {
+ return stickyReleaseConfig{}
+ }
+ wsCfg := s.service.cfg.Gateway.OpenAIWS
+ threshold := wsCfg.StickyReleaseErrorThreshold
+ if threshold <= 0 {
+ threshold = defaultStickyReleaseErrorThreshold
+ }
+ return stickyReleaseConfig{
+ enabled: wsCfg.StickyReleaseEnabled,
+ errorThreshold: threshold,
+ }
+}
+
+// shouldReleaseStickySession checks whether a sticky binding should be
+// released because the account is unhealthy (circuit breaker open) or has a
+// high error rate. This runs BEFORE slot acquisition to avoid wasting
+// concurrency capacity on degraded accounts.
+func (s *defaultOpenAIAccountScheduler) shouldReleaseStickySession(accountID int64) bool {
+ if s == nil || s.stats == nil || s.service == nil {
+ return false
+ }
+
+ cfg := s.stickyReleaseConfigRead()
+ if !cfg.enabled {
+ return false
+ }
+
+ // Check 1: Circuit breaker is open -> immediate release.
+ // Only check if CB feature is actually enabled, because the default CB
+ // threshold (5) is very aggressive and may trip unexpectedly.
+ cbEnabled, _, _, _ := s.schedulerCircuitBreakerConfig()
+ if cbEnabled && s.stats.isCircuitOpen(accountID) {
+ s.metrics.stickyReleaseCircuitOpenTotal.Add(1)
+ return true
+ }
+
+ // Check 2: Error rate exceeds threshold -> immediate release.
+ // Guard against cold-start: the EWMA error rate is unreliable when
+ // fewer than dualEWMAMinSamples have been collected.
+ stat := s.stats.loadExisting(accountID)
+ if stat != nil && stat.errorRate.isWarmedUp() {
+ errorRate, _, _ := s.stats.snapshot(accountID)
+ if errorRate > cfg.errorThreshold {
+ s.metrics.stickyReleaseErrorTotal.Add(1)
+ return true
+ }
+ }
+
+ return false
+}
+
func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot {
if s == nil {
return OpenAIAccountSchedulerMetricsSnapshot{}
@@ -753,13 +1758,17 @@ func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountScheduler
loadSkewTotal := s.metrics.loadSkewMilliTotal.Load()
snapshot := OpenAIAccountSchedulerMetricsSnapshot{
- SelectTotal: selectTotal,
- StickyPreviousHitTotal: prevHit,
- StickySessionHitTotal: sessionHit,
- LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
- AccountSwitchTotal: switchTotal,
- SchedulerLatencyMsTotal: latencyTotal,
- RuntimeStatsAccountCount: s.stats.size(),
+ SelectTotal: selectTotal,
+ StickyPreviousHitTotal: prevHit,
+ StickySessionHitTotal: sessionHit,
+ LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(),
+ AccountSwitchTotal: switchTotal,
+ SchedulerLatencyMsTotal: latencyTotal,
+ RuntimeStatsAccountCount: s.stats.size(),
+ CircuitBreakerOpenTotal: s.metrics.circuitBreakerOpenTotal.Load(),
+ CircuitBreakerRecoverTotal: s.metrics.circuitBreakerRecoverTotal.Load(),
+ StickyReleaseErrorTotal: s.metrics.stickyReleaseErrorTotal.Load(),
+ StickyReleaseCircuitOpenTotal: s.metrics.stickyReleaseCircuitOpenTotal.Load(),
}
if selectTotal > 0 {
snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal)
@@ -820,12 +1829,12 @@ func (s *OpenAIGatewayService) SelectAccountWithScheduler(
})
}
-func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) {
+func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int, model string, ttftMs float64) {
scheduler := s.getOpenAIAccountScheduler()
if scheduler == nil {
return
}
- scheduler.ReportResult(accountID, success, firstTokenMs)
+ scheduler.ReportResult(accountID, success, firstTokenMs, model, ttftMs)
}
func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() {
@@ -858,6 +1867,13 @@ func (s *OpenAIGatewayService) openAIWSLBTopK() int {
return 7
}
+func (s *OpenAIGatewayService) openAIWSSchedulerP2CEnabled() bool {
+ if s != nil && s.cfg != nil {
+ return s.cfg.Gateway.OpenAIWS.SchedulerP2CEnabled
+ }
+ return false
+}
+
func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView {
if s != nil && s.cfg != nil {
return GatewayOpenAIWSSchedulerScoreWeightsView{
@@ -885,6 +1901,24 @@ type GatewayOpenAIWSSchedulerScoreWeightsView struct {
TTFT float64
}
+// defaultSchedulerTrendMaxSlope is the normalization ceiling for the trend
+// slope. A slope of 5.0 means the account's load rate is increasing at 5
+// percentage points per second — a very steep rise.
+const defaultSchedulerTrendMaxSlope = 5.0
+
+// openAIWSSchedulerTrendConfig reads trend-prediction config with defaults.
+func (s *OpenAIGatewayService) openAIWSSchedulerTrendConfig() (enabled bool, maxSlope float64) {
+ maxSlope = defaultSchedulerTrendMaxSlope
+ if s == nil || s.cfg == nil {
+ return false, maxSlope
+ }
+ enabled = s.cfg.Gateway.OpenAIWS.SchedulerTrendEnabled
+ if s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope > 0 {
+ maxSlope = s.cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope
+ }
+ return enabled, maxSlope
+}
+
func clamp01(value float64) float64 {
switch {
case value < 0:
diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go
index 7f6f1b66a..ce895c5ec 100644
--- a/backend/internal/service/openai_account_scheduler_test.go
+++ b/backend/internal/service/openai_account_scheduler_test.go
@@ -1,6 +1,7 @@
package service
import (
+ "container/heap"
"context"
"fmt"
"math"
@@ -447,7 +448,7 @@ func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) {
selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny)
require.NoError(t, err)
require.NotNil(t, selection)
- svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120))
+ svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120), "", 0)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
@@ -465,19 +466,120 @@ func intPtrForTest(v int) *int {
func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) {
stats := newOpenAIAccountRuntimeStats()
- stats.report(1001, true, nil)
+ stats.report(1001, true, nil, "", 0) // error: fast 0→0, slow 0→0
firstTTFT := 100
- stats.report(1001, false, &firstTTFT)
+ stats.report(1001, false, &firstTTFT, "", 0) // error: fast 0→0.5, slow 0→0.1; ttft: NaN→100 (both)
secondTTFT := 200
- stats.report(1001, false, &secondTTFT)
+ stats.report(1001, false, &secondTTFT, "", 0) // error: fast 0.5→0.75, slow 0.1→0.19; ttft: fast 100→150, slow 100→110
errorRate, ttft, hasTTFT := stats.snapshot(1001)
require.True(t, hasTTFT)
- require.InDelta(t, 0.36, errorRate, 1e-9)
- require.InDelta(t, 120.0, ttft, 1e-9)
+ // errorRate = max(fast=0.75, slow=0.19) = 0.75
+ require.InDelta(t, 0.75, errorRate, 1e-9)
+ // ttft = max(fast=150, slow=110) = 150
+ require.InDelta(t, 150.0, ttft, 1e-9)
require.Equal(t, 1, stats.size())
}
+func TestDualEWMA_UpdateAndValue(t *testing.T) {
+ var d dualEWMA
+
+ // Initial state: both channels are 0.
+ require.Equal(t, 0.0, d.fastValue())
+ require.Equal(t, 0.0, d.slowValue())
+ require.Equal(t, 0.0, d.value())
+
+ // First sample = 1.0
+ d.update(1.0)
+ // fast: 0.5*1 + 0.5*0 = 0.5
+ require.InDelta(t, 0.5, d.fastValue(), 1e-12)
+ // slow: 0.1*1 + 0.9*0 = 0.1
+ require.InDelta(t, 0.1, d.slowValue(), 1e-12)
+ // value = max(0.5, 0.1) = 0.5
+ require.InDelta(t, 0.5, d.value(), 1e-12)
+
+ // Second sample = 0.0 (recovery)
+ d.update(0.0)
+ // fast: 0.5*0 + 0.5*0.5 = 0.25
+ require.InDelta(t, 0.25, d.fastValue(), 1e-12)
+ // slow: 0.1*0 + 0.9*0.1 = 0.09
+ require.InDelta(t, 0.09, d.slowValue(), 1e-12)
+ // value = max(0.25, 0.09) = 0.25
+ require.InDelta(t, 0.25, d.value(), 1e-12)
+}
+
+func TestDualEWMA_SlowDominatesAfterRecovery(t *testing.T) {
+ var d dualEWMA
+
+ // Spike: several failures.
+ for i := 0; i < 10; i++ {
+ d.update(1.0)
+ }
+ // Now fast is close to 1, slow is also rising.
+
+ // Recovery: many successes.
+ for i := 0; i < 20; i++ {
+ d.update(0.0)
+ }
+ // Fast should have dropped close to 0, slow should still be > fast.
+ require.Greater(t, d.slowValue(), d.fastValue(),
+ "after recovery, slow channel should dominate the pessimistic envelope")
+ require.Equal(t, d.slowValue(), d.value())
+}
+
+func TestDualEWMATTFT_NaNInitAndFirstSample(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+
+ // Before any sample, value should report no data.
+ v, ok := d.value()
+ require.False(t, ok)
+ require.Equal(t, 0.0, v)
+
+ // First sample seeds both channels.
+ d.update(100.0)
+ require.InDelta(t, 100.0, d.fastValue(), 1e-12)
+ require.InDelta(t, 100.0, d.slowValue(), 1e-12)
+ v, ok = d.value()
+ require.True(t, ok)
+ require.InDelta(t, 100.0, v, 1e-12)
+
+ // Second sample.
+ d.update(200.0)
+ // fast: 0.5*200 + 0.5*100 = 150
+ require.InDelta(t, 150.0, d.fastValue(), 1e-12)
+ // slow: 0.1*200 + 0.9*100 = 110
+ require.InDelta(t, 110.0, d.slowValue(), 1e-12)
+ v, ok = d.value()
+ require.True(t, ok)
+ require.InDelta(t, 150.0, v, 1e-12)
+}
+
+func TestDualEWMATTFT_SlowDominatesWhenLatencyDrops(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+
+ // Warm up with high latency.
+ for i := 0; i < 20; i++ {
+ d.update(500.0)
+ }
+ // Now push many low-latency samples.
+ for i := 0; i < 20; i++ {
+ d.update(100.0)
+ }
+ // Fast should have adapted down quickly; slow should still be higher.
+ require.Greater(t, d.slowValue(), d.fastValue(),
+ "after latency improvement, slow channel should dominate the pessimistic TTFT")
+ v, ok := d.value()
+ require.True(t, ok)
+ require.InDelta(t, d.slowValue(), v, 1e-12)
+}
+
+func TestDualEWMAConstants(t *testing.T) {
+ require.Equal(t, 0.5, dualEWMAAlphaFast)
+ require.Equal(t, 0.1, dualEWMAAlphaSlow)
+}
+
func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
stats := newOpenAIAccountRuntimeStats()
@@ -496,7 +598,7 @@ func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) {
accountID := int64(i%accountCount + 1)
success := (i+worker)%3 != 0
ttft := 80 + (i+worker)%40
- stats.report(accountID, success, &ttft)
+ stats.report(accountID, success, &ttft, "", 0)
}
}()
}
@@ -751,7 +853,7 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
require.True(t, ok)
ttft := 100
- scheduler.ReportResult(1001, true, &ttft)
+ scheduler.ReportResult(1001, true, &ttft, "", 0)
scheduler.ReportSwitch()
scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{
Layer: openAIAccountScheduleLayerLoadBalance,
@@ -780,7 +882,7 @@ func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) {
func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) {
svc := &OpenAIGatewayService{}
ttft := 120
- svc.ReportOpenAIAccountScheduleResult(10, true, &ttft)
+ svc.ReportOpenAIAccountScheduleResult(10, true, &ttft, "", 0)
svc.RecordOpenAIAccountSwitch()
snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics()
require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1))
@@ -836,6 +938,3249 @@ func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *
require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2))
}
+func TestLoadFactorCapacityAwareness(t *testing.T) {
+ // Test that accounts with higher absolute capacity get better scores
+ // when percentage load is equal.
+ //
+ // Setup:
+ // Account A: Concurrency=100, LoadRate=50 (50 free slots)
+ // Account B: Concurrency=10, LoadRate=50 (5 free slots)
+ // Both at 50% load, but A should score higher due to more headroom.
+
+ ctx := context.Background()
+ groupID := int64(20)
+ accounts := []Account{
+ {
+ ID: 6001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 100,
+ Priority: 0,
+ },
+ {
+ ID: 6002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 10,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ // Use only Load weight to isolate the capacity-aware loadFactor effect.
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 6001: {AccountID: 6001, LoadRate: 50, WaitingCount: 0},
+ 6002: {AccountID: 6002, LoadRate: 50, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ // Verify account A (high capacity) is always selected first by score.
+ // Because weighted selection has randomness, we run multiple iterations
+ // and verify A is selected more often than B.
+ countA := 0
+ countB := 0
+ iterations := 100
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("cap_aware_test_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ if selection.Account.ID == 6001 {
+ countA++
+ } else {
+ countB++
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // Account A (100 concurrency) should be selected significantly more often
+ // than Account B (10 concurrency) because A has 50 free slots vs 5 free slots.
+ require.Greater(t, countA, countB,
+ "high-capacity account (50 free slots) should be selected more often than low-capacity (5 free slots) at equal load percentage; got A=%d B=%d", countA, countB)
+
+ // -----------------------------------------------------------------------
+ // Verify score math directly via the capacity-aware loadFactor formula.
+ // -----------------------------------------------------------------------
+ // maxConcurrency = 100 (from account A)
+ //
+ // Account A (Concurrency=100, LoadRate=50):
+ // base loadFactor = 1 - 50/100 = 0.5
+ // remainingSlots = 100 * 0.5 = 50
+ // capacityBonus = 50 / 100 = 0.5
+ // loadFactor = 0.7*0.5 + 0.3*0.5 = 0.5
+ //
+ // Account B (Concurrency=10, LoadRate=50):
+ // base loadFactor = 1 - 50/100 = 0.5
+ // remainingSlots = 10 * 0.5 = 5
+ // capacityBonus = 5 / 100 = 0.05
+ // loadFactor = 0.7*0.5 + 0.3*0.05 = 0.365
+ //
+ // With Load weight = 1.0 and all others 0.0, score = loadFactor.
+ expectedScoreA := 0.7*0.5 + 0.3*0.5 // 0.5
+ expectedScoreB := 0.7*0.5 + 0.3*(5.0/100.0) // 0.365
+ require.Greater(t, expectedScoreA, expectedScoreB, "score sanity check")
+ require.InDelta(t, 0.5, expectedScoreA, 1e-9)
+ require.InDelta(t, 0.365, expectedScoreB, 1e-9)
+}
+
+func TestQueueFactorCapacityAwareness(t *testing.T) {
+ // Test that the capacity-aware queue factor penalises accounts
+ // whose queue depth is high relative to their own concurrency.
+ //
+ // Account A: Concurrency=100, WaitingCount=10 (10% of capacity)
+ // Account B: Concurrency=10, WaitingCount=10 (100% of capacity)
+ // Both have same absolute waiting count, but B should score lower.
+
+ ctx := context.Background()
+ groupID := int64(21)
+ accounts := []Account{
+ {
+ ID: 7001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 100,
+ Priority: 0,
+ },
+ {
+ ID: 7002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 10,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ // Use only Queue weight to isolate the capacity-aware queueFactor effect.
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 7001: {AccountID: 7001, LoadRate: 30, WaitingCount: 10},
+ 7002: {AccountID: 7002, LoadRate: 30, WaitingCount: 10},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ countA := 0
+ countB := 0
+ iterations := 100
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("queue_aware_test_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ if selection.Account.ID == 7001 {
+ countA++
+ } else {
+ countB++
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ require.Greater(t, countA, countB,
+ "account with lower relative queue depth should be selected more often; got A=%d B=%d", countA, countB)
+
+ // -----------------------------------------------------------------------
+ // Verify score math for the capacity-aware queueFactor.
+ // -----------------------------------------------------------------------
+ // maxWaiting = 10 (both accounts have WaitingCount=10)
+ //
+ // Account A (Concurrency=100, WaitingCount=10):
+ // base queueFactor = 1 - 10/10 = 0.0
+ // relativeQueue = 10/100 = 0.1
+ // queueFactor = 0.6*0.0 + 0.4*(1-0.1) = 0.36
+ //
+ // Account B (Concurrency=10, WaitingCount=10):
+ // base queueFactor = 1 - 10/10 = 0.0
+ // relativeQueue = clamp01(10/10) = 1.0
+ // queueFactor = 0.6*0.0 + 0.4*(1-1.0) = 0.0
+ expectedQueueA := 0.6*0.0 + 0.4*(1-0.1)
+ expectedQueueB := 0.6*0.0 + 0.4*(1-1.0)
+ require.Greater(t, expectedQueueA, expectedQueueB)
+ require.InDelta(t, 0.36, expectedQueueA, 1e-9)
+ require.InDelta(t, 0.0, expectedQueueB, 1e-9)
+}
+
+func TestLoadFactorCapacityAwareness_ZeroConcurrencyFallback(t *testing.T) {
+ // When Concurrency is 0, the capacity-aware blending should be skipped
+ // and loadFactor should fall back to the simple loadRate/100 formula.
+
+ ctx := context.Background()
+ groupID := int64(22)
+ accounts := []Account{
+ {
+ ID: 8001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 0, // unset / zero
+ Priority: 0,
+ },
+ {
+ ID: 8002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 0,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.0
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 8001: {AccountID: 8001, LoadRate: 30, WaitingCount: 0},
+ 8002: {AccountID: 8002, LoadRate: 70, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ // With Concurrency=0, maxConcurrency=0, so the capacity-aware path is skipped.
+ // Account 8001 (LoadRate=30) should have higher loadFactor than 8002 (LoadRate=70).
+ countLow := 0
+ iterations := 60
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("zero_conc_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ if selection.Account.ID == 8001 {
+ countLow++
+ }
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // 8001 (lower load) should be picked more often.
+ require.Greater(t, countLow, iterations/2,
+ "account with lower load should be selected more often when concurrency is 0; got %d/%d", countLow, iterations)
+}
+
func int64PtrForTest(v int64) *int64 {
return &v
}
+
+// ---------------------------------------------------------------------------
+// Circuit Breaker Tests
+// ---------------------------------------------------------------------------
+
+func TestAccountCircuitBreaker_ClosedToOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 30 * time.Second
+ halfOpenMax := 2
+
+ // Initially CLOSED — should allow.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.False(t, cb.isOpen())
+
+ // Record 4 failures — should still be CLOSED (threshold is 5).
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+
+ // 5th failure trips the breaker to OPEN.
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ require.Equal(t, "OPEN", cb.stateString())
+ require.True(t, cb.isOpen())
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+}
+
+func TestAccountCircuitBreaker_OpenToHalfOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 2
+
+ // Trip the breaker.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+
+ // Wait for cooldown to elapse.
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Next allow() should transition to HALF_OPEN and admit the request.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+}
+
+func TestAccountCircuitBreaker_HalfOpenToClose(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 2
+
+ // Trip the breaker.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+
+ // Wait for cooldown.
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Allow first probe — transitions to HALF_OPEN.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Allow second probe.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+
+ // Third probe should be rejected (halfOpenMax=2).
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+
+ // Both probes succeed — should close the circuit.
+ cb.recordSuccess()
+ // After first success, still HALF_OPEN (need both to succeed).
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+ cb.recordSuccess()
+ // Both probes succeeded — circuit should be CLOSED now.
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.False(t, cb.isOpen())
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+}
+
+func TestAccountCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 10 * time.Millisecond
+ halfOpenMax := 2
+
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+
+ time.Sleep(cooldown + 5*time.Millisecond)
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+ require.Equal(t, int32(1), cb.halfOpenInFlight.Load())
+
+ cb.releaseHalfOpenPermit()
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+
+ // Idempotent release should not underflow.
+ cb.releaseHalfOpenPermit()
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+func TestAccountCircuitBreaker_HalfOpenToOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 2
+
+ // Trip the breaker.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+
+ // Wait for cooldown.
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Allow a probe — transitions to HALF_OPEN.
+ require.True(t, cb.allow(cooldown, halfOpenMax))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Failure in HALF_OPEN should trip back to OPEN.
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ require.Equal(t, "OPEN", cb.stateString())
+ require.True(t, cb.isOpen())
+ require.False(t, cb.allow(cooldown, halfOpenMax))
+}
+
+func TestAccountCircuitBreaker_ResetOnSuccess(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+
+ // 4 failures followed by a success should reset the counter.
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, int32(4), cb.consecutiveFails.Load())
+
+ cb.recordSuccess()
+ require.Equal(t, int32(0), cb.consecutiveFails.Load())
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ // 4 more failures — still not tripped because counter was reset.
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ // 5th consecutive failure trips it.
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ require.Equal(t, "OPEN", cb.stateString())
+}
+
+func TestAccountCircuitBreaker_IntegrationWithScheduler(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(30)
+ accounts := []Account{
+ {
+ ID: 9001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ {
+ ID: 9002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ // Enable circuit breaker with low threshold for testing.
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9001: {AccountID: 9001, LoadRate: 10, WaitingCount: 0},
+ 9002: {AccountID: 9002, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler := svc.getOpenAIAccountScheduler()
+
+ // Report 3 consecutive failures for account 9001 — trips the circuit breaker.
+ for i := 0; i < 3; i++ {
+ scheduler.ReportResult(9001, false, nil, "", 0)
+ }
+
+ // Now all selections should avoid account 9001 and pick 9002.
+ for i := 0; i < 20; i++ {
+ sessionHash := fmt.Sprintf("cb_integration_%d", i)
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(9002), selection.Account.ID,
+ "circuit-open account 9001 should be skipped, got %d on iteration %d", selection.Account.ID, i)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // Verify metrics tracked the trip.
+ snapshot := scheduler.SnapshotMetrics()
+ require.GreaterOrEqual(t, snapshot.CircuitBreakerOpenTotal, int64(1))
+}
+
+func TestAccountCircuitBreaker_AllOpenFallback(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(31)
+ accounts := []Account{
+ {
+ ID: 9101,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9101: {AccountID: 9101, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler := svc.getOpenAIAccountScheduler()
+
+ // Trip the only account.
+ for i := 0; i < 3; i++ {
+ scheduler.ReportResult(9101, false, nil, "", 0)
+ }
+
+ // Even though the only account is circuit-open, the scheduler should
+ // still return it (graceful degradation — never return "no accounts").
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "cb_fallback_test",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, int64(9101), selection.Account.ID)
+}
+
+func TestAccountCircuitBreaker_SelectReleasesUnselectedHalfOpenPermit(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(311)
+ accounts := []Account{
+ {
+ ID: 9111,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ {
+ ID: 9112,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 1
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9111: {AccountID: 9111, LoadRate: 10, WaitingCount: 0},
+ 9112: {AccountID: 9112, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler, ok := svc.getOpenAIAccountScheduler().(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ // Trip both accounts to OPEN so next select will transition both to HALF_OPEN.
+ scheduler.ReportResult(9111, false, nil, "", 0)
+ scheduler.ReportResult(9112, false, nil, "", 0)
+ scheduler.stats.getCircuitBreaker(9111).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano())
+ scheduler.stats.getCircuitBreaker(9112).lastFailureNano.Store(time.Now().Add(-2 * time.Second).UnixNano())
+
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "cb_release_unselected",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+
+ selectedID := selection.Account.ID
+ otherID := int64(9111)
+ if selectedID == otherID {
+ otherID = 9112
+ }
+ otherCB := scheduler.stats.getCircuitBreaker(otherID)
+ require.Equal(t, int32(0), otherCB.halfOpenInFlight.Load(),
+ "unselected HALF_OPEN candidate should release probe permit")
+
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestAccountCircuitBreaker_DisabledByConfig(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(32)
+ accounts := []Account{
+ {
+ ID: 9201,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ {
+ ID: 9202,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1
+ // Circuit breaker explicitly DISABLED.
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = false
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 9201: {AccountID: 9201, LoadRate: 10, WaitingCount: 0},
+ 9202: {AccountID: 9202, LoadRate: 10, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ scheduler := svc.getOpenAIAccountScheduler()
+ internalScheduler, ok := scheduler.(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ // Report many failures — should NOT affect scheduling when disabled.
+ for i := 0; i < 10; i++ {
+ scheduler.ReportResult(9201, false, nil, "", 0)
+ }
+
+ // Both accounts should still be eligible.
+ selected := map[int64]int{}
+ for i := 0; i < 40; i++ {
+ sessionHash := fmt.Sprintf("cb_disabled_%d", i)
+ selection, _, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ selected[selection.Account.ID]++
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+ // When disabled, 9201 should still appear as a candidate.
+ require.Greater(t, selected[int64(9201)]+selected[int64(9202)], 0)
+ require.Len(t, selected, 2, "both accounts should be selectable when CB is disabled")
+ cb := internalScheduler.stats.getCircuitBreaker(9201)
+ require.False(t, cb.isOpen(), "circuit breaker should not transition to OPEN when feature is disabled")
+ require.Equal(t, int64(0), internalScheduler.metrics.circuitBreakerOpenTotal.Load())
+}
+
+func TestAccountCircuitBreaker_RecoveryMetrics(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ svc := &OpenAIGatewayService{}
+ schedulerAny := newDefaultOpenAIAccountScheduler(svc, stats)
+ scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler)
+ require.True(t, ok)
+
+ // Manually enable CB by setting config on the service.
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 0 // immediate cooldown for test
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 1
+ scheduler.service.cfg = cfg
+
+ // Trip the breaker: 3 consecutive failures.
+ for i := 0; i < 3; i++ {
+ scheduler.ReportResult(5001, false, nil, "", 0)
+ }
+ require.Equal(t, int64(1), scheduler.metrics.circuitBreakerOpenTotal.Load())
+
+ // Let the cooldown expire (0 seconds) and call allow to trigger HALF_OPEN.
+ cb := stats.getCircuitBreaker(5001)
+ require.Equal(t, "OPEN", cb.stateString())
+ allowed := cb.allow(0, 1)
+ require.True(t, allowed)
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Report success — should transition HALF_OPEN → CLOSED.
+ scheduler.ReportResult(5001, true, nil, "", 0)
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.Equal(t, int64(1), scheduler.metrics.circuitBreakerRecoverTotal.Load())
+}
+
+func TestAccountCircuitBreaker_StateString(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ cb.state.Store(circuitBreakerStateOpen)
+ require.Equal(t, "OPEN", cb.stateString())
+
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ cb.state.Store(99)
+ require.Equal(t, "UNKNOWN", cb.stateString())
+}
+
+func TestAccountCircuitBreaker_GetAndIsCircuitOpen(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+
+ // isCircuitOpen on a non-existent account should return false.
+ require.False(t, stats.isCircuitOpen(1234))
+
+ // getCircuitBreaker should create on first access.
+ cb := stats.getCircuitBreaker(1234)
+ require.NotNil(t, cb)
+ require.Equal(t, "CLOSED", cb.stateString())
+ require.False(t, stats.isCircuitOpen(1234))
+
+ // Trip it and verify isCircuitOpen returns true.
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.True(t, stats.isCircuitOpen(1234))
+
+ // Second call to getCircuitBreaker should return same instance.
+ cb2 := stats.getCircuitBreaker(1234)
+ require.True(t, cb == cb2, "should return same pointer")
+}
+
+func TestAccountCircuitBreaker_ConcurrentAllowAndRecord(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cooldown := 50 * time.Millisecond
+ halfOpenMax := 4
+
+ var wg sync.WaitGroup
+ const workers = 16
+ const iterations = 200
+
+ wg.Add(workers)
+ for w := 0; w < workers; w++ {
+ w := w
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ _ = cb.allow(cooldown, halfOpenMax)
+ if (i+w)%3 == 0 {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ } else {
+ cb.recordSuccess()
+ }
+ }
+ }()
+ }
+ wg.Wait()
+
+ // Just verify it doesn't panic or deadlock, and state is valid.
+ state := cb.state.Load()
+ require.True(t, state == circuitBreakerStateClosed ||
+ state == circuitBreakerStateOpen ||
+ state == circuitBreakerStateHalfOpen,
+ "unexpected state: %d", state)
+}
+
+// ---------------------------------------------------------------------------
+// P2C (Power-of-Two-Choices) Tests
+// ---------------------------------------------------------------------------
+
+func TestSelectP2COpenAICandidates_BasicSelection(t *testing.T) {
+ // P2C should return all candidates in some order, and higher-scored
+ // candidates should tend to appear earlier in the selection order.
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 1}, score: 0.9},
+ {account: &Account{ID: 2, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 2}, score: 0.5},
+ {account: &Account{ID: 3, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 3}, score: 0.1},
+ {account: &Account{ID: 4, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 4}, score: 0.7},
+ {account: &Account{ID: 5, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 5}, score: 0.3},
+ }
+
+ req := OpenAIAccountScheduleRequest{
+ SessionHash: "p2c_basic_test",
+ }
+
+ result := selectP2COpenAICandidates(candidates, req)
+
+ // All candidates must be present exactly once.
+ require.Len(t, result, len(candidates))
+ seen := map[int64]bool{}
+ for _, c := range result {
+ require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID)
+ seen[c.account.ID] = true
+ }
+ for _, c := range candidates {
+ require.True(t, seen[c.account.ID], "missing account ID %d", c.account.ID)
+ }
+
+ // Statistical check: over many runs the highest-scored candidate (ID=1,
+ // score=0.9) should appear in position 0 more often than the lowest-scored
+ // candidate (ID=3, score=0.1).
+ topCount := map[int64]int{}
+ iterations := 500
+ for i := 0; i < iterations; i++ {
+ iterReq := OpenAIAccountScheduleRequest{
+ SessionHash: fmt.Sprintf("p2c_stat_%d", i),
+ }
+ order := selectP2COpenAICandidates(candidates, iterReq)
+ topCount[order[0].account.ID]++
+ }
+ require.Greater(t, topCount[int64(1)], topCount[int64(3)],
+ "highest-scored candidate should appear first more often than lowest-scored; got best=%d worst=%d",
+ topCount[int64(1)], topCount[int64(3)])
+}
+
+func TestSelectP2COpenAICandidates_SingleCandidate(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 42, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 42}, score: 1.0},
+ }
+ req := OpenAIAccountScheduleRequest{SessionHash: "single"}
+
+ result := selectP2COpenAICandidates(candidates, req)
+ require.Len(t, result, 1)
+ require.Equal(t, int64(42), result[0].account.ID)
+}
+
+func TestSelectP2COpenAICandidates_DeterministicWithSameSeed(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 10, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 10}, score: 0.8},
+ {account: &Account{ID: 20, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 20}, score: 0.6},
+ {account: &Account{ID: 30, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 30}, score: 0.4},
+ {account: &Account{ID: 40, Priority: 0}, loadInfo: &AccountLoadInfo{AccountID: 40}, score: 0.2},
+ }
+ // Use a session hash to ensure the seed is deterministic (no time entropy).
+ req := OpenAIAccountScheduleRequest{
+ SessionHash: "deterministic_p2c_seed",
+ }
+
+ first := selectP2COpenAICandidates(candidates, req)
+ for i := 0; i < 10; i++ {
+ again := selectP2COpenAICandidates(candidates, req)
+ require.Len(t, again, len(first))
+ for j := range first {
+ require.Equal(t, first[j].account.ID, again[j].account.ID,
+ "iteration %d position %d mismatch", i, j)
+ }
+ }
+}
+
+func TestP2CLoadBalanceIntegration(t *testing.T) {
+ // End-to-end test: enable P2C via config, verify it distributes across
+ // accounts and that decision.TopK == 0 (P2C mode indicator).
+ ctx := context.Background()
+ groupID := int64(50)
+ accounts := []Account{
+ {
+ ID: 5001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+ Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0,
+ },
+ {
+ ID: 5002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+ Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0,
+ },
+ {
+ ID: 5003, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+ Status: StatusActive, Schedulable: true, Concurrency: 10, Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 5001: {AccountID: 5001, LoadRate: 20, WaitingCount: 0},
+ 5002: {AccountID: 5002, LoadRate: 30, WaitingCount: 0},
+ 5003: {AccountID: 5003, LoadRate: 40, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selected := map[int64]int{}
+ iterations := 100
+ for i := 0; i < iterations; i++ {
+ sessionHash := fmt.Sprintf("p2c_integration_%d", i)
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ sessionHash,
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ // P2C mode: TopK should be 0.
+ require.Equal(t, 0, decision.TopK, "P2C mode should set TopK=0")
+ selected[selection.Account.ID]++
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ // P2C with 3 candidates: the two better-scored accounts (5001, 5002)
+ // should be selected, while the weakest (5003) may rarely or never win
+ // a P2C tournament. Verify at least 2 distinct accounts are picked and
+ // the lowest-loaded account dominates.
+ require.GreaterOrEqual(t, len(selected), 2,
+ "P2C should distribute across at least 2 accounts; got %v", selected)
+ require.Greater(t, selected[int64(5001)], 0,
+ "lowest-loaded account 5001 should be selected at least once")
+ require.Greater(t, selected[int64(5001)], selected[int64(5003)],
+ "lowest-loaded account should be favored over highest-loaded; got 5001=%d 5003=%d",
+ selected[int64(5001)], selected[int64(5003)])
+}
+
+func TestP2CFallbackToTopK(t *testing.T) {
+ // When P2C is disabled (default), the Top-K path should be used.
+ // Verify topK > 0 in decision.
+ ctx := context.Background()
+ groupID := int64(51)
+ accounts := []Account{
+ {
+ ID: 5101, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+ Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
+ },
+ {
+ ID: 5102, Platform: PlatformOpenAI, Type: AccountTypeAPIKey,
+ Status: StatusActive, Schedulable: true, Concurrency: 5, Priority: 0,
+ },
+ }
+
+ cfg := &config.Config{}
+ // Explicitly disable P2C (or leave at default false).
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false
+ cfg.Gateway.OpenAIWS.LBTopK = 2
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.7
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.8
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.5
+
+ concurrencyCache := stubConcurrencyCache{
+ loadMap: map[int64]*AccountLoadInfo{
+ 5101: {AccountID: 5101, LoadRate: 10, WaitingCount: 0},
+ 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 0},
+ },
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: &stubGatewayCache{sessionBindings: map[string]int64{}},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(concurrencyCache),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx,
+ &groupID,
+ "",
+ "topk_fallback_test",
+ "gpt-5.1",
+ nil,
+ OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ // Top-K mode: TopK should be > 0.
+ require.Greater(t, decision.TopK, 0, "Top-K mode should set TopK > 0")
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+
+ // Also verify P2C helper returns false when disabled.
+ require.False(t, svc.openAIWSSchedulerP2CEnabled())
+}
+
+// ---------------------------------------------------------------------------
+// Conditional Sticky Session Release Tests
+// ---------------------------------------------------------------------------
+
+// buildConditionalStickyTestService creates a minimal OpenAIGatewayService and
+// scheduler with injectable runtime stats for conditional sticky tests.
+func buildConditionalStickyTestService(
+ accounts []Account,
+ stickyKey string,
+ stickyAccountID int64,
+ stickyReleaseEnabled bool,
+ cbEnabled bool,
+) (*OpenAIGatewayService, *defaultOpenAIAccountScheduler) {
+ cache := &stubGatewayCache{
+ sessionBindings: map[string]int64{
+ stickyKey: stickyAccountID,
+ },
+ }
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = stickyReleaseEnabled
+ // Leave StickyReleaseErrorThreshold at 0 to use the default (0.3).
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = cbEnabled
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 3
+ cfg.Gateway.Scheduling.StickySessionMaxWaiting = 5
+ cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := &defaultOpenAIAccountScheduler{
+ service: svc,
+ stats: stats,
+ }
+ // Wire the scheduler into the service so that SelectAccountWithScheduler
+ // uses it via getOpenAIAccountScheduler.
+ svc.openaiScheduler = scheduler
+ svc.openaiAccountStats = stats
+ return svc, scheduler
+}
+
+func TestConditionalSticky_ReleaseOnHighErrorRate(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(30001)
+ stickyAccount := Account{
+ ID: 5001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+ fallbackAccount := Account{
+ ID: 5002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+
+ svc, scheduler := buildConditionalStickyTestService(
+ []Account{stickyAccount, fallbackAccount},
+ fmt.Sprintf("openai:sticky_err_%d", groupID),
+ stickyAccount.ID,
+ true, // stickyReleaseEnabled
+ false, // cbEnabled (not needed for error rate test)
+ )
+
+ // Pump the error rate above the 0.3 default threshold.
+ // With alpha=0.5 (fast EWMA), after ~5 consecutive failures the rate
+ // converges well above 0.3.
+ for i := 0; i < 10; i++ {
+ scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+ }
+ errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+ require.Greater(t, errRate, 0.3, "error rate should exceed threshold before test")
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "",
+ fmt.Sprintf("sticky_err_%d", groupID),
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ // The sticky account should have been released; the scheduler should
+ // have fallen through to load balance and selected one of the accounts.
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer,
+ "should fall through to load balance after sticky release")
+ require.False(t, decision.StickySessionHit, "sticky hit should be false")
+}
+
+func TestConditionalSticky_ReleaseOnCircuitOpen(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(30002)
+ stickyAccount := Account{
+ ID: 5011,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+ fallbackAccount := Account{
+ ID: 5012,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+
+ svc, scheduler := buildConditionalStickyTestService(
+ []Account{stickyAccount, fallbackAccount},
+ fmt.Sprintf("openai:sticky_cb_%d", groupID),
+ stickyAccount.ID,
+ true, // stickyReleaseEnabled
+ true, // cbEnabled
+ )
+
+ // Trip the circuit breaker by reporting consecutive failures beyond
+ // the configured threshold (3).
+ for i := 0; i < 5; i++ {
+ scheduler.ReportResult(stickyAccount.ID, false, nil, "", 0)
+ }
+ require.True(t, scheduler.stats.isCircuitOpen(stickyAccount.ID),
+ "circuit breaker should be OPEN before test")
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "",
+ fmt.Sprintf("sticky_cb_%d", groupID),
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer,
+ "should fall through to load balance after sticky release due to CB open")
+ require.False(t, decision.StickySessionHit)
+}
+
+func TestConditionalSticky_KeepsHealthySticky(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(30003)
+ stickyAccount := Account{
+ ID: 5021,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+
+ svc, scheduler := buildConditionalStickyTestService(
+ []Account{stickyAccount},
+ fmt.Sprintf("openai:sticky_ok_%d", groupID),
+ stickyAccount.ID,
+ true, // stickyReleaseEnabled
+ false, // cbEnabled
+ )
+
+ // Report some successes so error rate stays at 0.
+ for i := 0; i < 5; i++ {
+ scheduler.stats.report(stickyAccount.ID, true, nil, "", 0)
+ }
+ errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+ require.Less(t, errRate, 0.3, "error rate should be below threshold")
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "",
+ fmt.Sprintf("sticky_ok_%d", groupID),
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, stickyAccount.ID, selection.Account.ID,
+ "healthy sticky account should be kept")
+ require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+ require.True(t, decision.StickySessionHit)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestConditionalSticky_DisabledByConfig(t *testing.T) {
+ ctx := context.Background()
+ groupID := int64(30004)
+ stickyAccount := Account{
+ ID: 5031,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+
+ svc, scheduler := buildConditionalStickyTestService(
+ []Account{stickyAccount},
+ fmt.Sprintf("openai:sticky_off_%d", groupID),
+ stickyAccount.ID,
+ false, // stickyReleaseEnabled = OFF
+ false, // cbEnabled
+ )
+
+ // Pump error rate very high, but since sticky release is disabled,
+ // the sticky binding should still hold.
+ for i := 0; i < 10; i++ {
+ scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+ }
+ errRate, _, _ := scheduler.stats.snapshot(stickyAccount.ID)
+ require.Greater(t, errRate, 0.3, "error rate should exceed threshold")
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "",
+ fmt.Sprintf("sticky_off_%d", groupID),
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, stickyAccount.ID, selection.Account.ID,
+ "sticky should be kept when feature is disabled")
+ require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer)
+ require.True(t, decision.StickySessionHit)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestConditionalSticky_Metrics(t *testing.T) {
+ groupID := int64(30005)
+ ctx := context.Background()
+
+ // --- Error rate release metric ---
+ stickyAccount := Account{
+ ID: 5041,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+ fallbackAccount := Account{
+ ID: 5042,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+
+ svc, scheduler := buildConditionalStickyTestService(
+ []Account{stickyAccount, fallbackAccount},
+ fmt.Sprintf("openai:sticky_m1_%d", groupID),
+ stickyAccount.ID,
+ true, // stickyReleaseEnabled
+ true, // cbEnabled
+ )
+
+ // Trigger error-rate release. With CB also enabled and threshold=3,
+ // the CB will be OPEN after 3 failures via stats.report (which uses
+ // the default CB threshold of 5). Send enough to ensure both are
+ // triggered.
+ for i := 0; i < 10; i++ {
+ scheduler.stats.report(stickyAccount.ID, false, nil, "", 0)
+ }
+
+ _, _, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "",
+ fmt.Sprintf("sticky_m1_%d", groupID),
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+
+ snap := scheduler.SnapshotMetrics()
+ // At least one of the two release metrics should have been incremented.
+ totalReleases := snap.StickyReleaseErrorTotal + snap.StickyReleaseCircuitOpenTotal
+ require.Greater(t, totalReleases, int64(0),
+ "at least one sticky release metric should be incremented")
+
+ // --- Circuit breaker release metric (clean setup) ---
+ groupID2 := int64(30006)
+ svc2, scheduler2 := buildConditionalStickyTestService(
+ []Account{stickyAccount, fallbackAccount},
+ fmt.Sprintf("openai:sticky_m2_%d", groupID2),
+ stickyAccount.ID,
+ true, // stickyReleaseEnabled
+ true, // cbEnabled
+ )
+
+ // Trip CB via ReportResult (which checks the configured threshold=3).
+ for i := 0; i < 5; i++ {
+ scheduler2.ReportResult(stickyAccount.ID, false, nil, "", 0)
+ }
+ require.True(t, scheduler2.stats.isCircuitOpen(stickyAccount.ID))
+
+ _, _, err = svc2.SelectAccountWithScheduler(
+ ctx, &groupID2, "",
+ fmt.Sprintf("sticky_m2_%d", groupID2),
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+
+ snap2 := scheduler2.SnapshotMetrics()
+ require.Greater(t, snap2.StickyReleaseCircuitOpenTotal, int64(0),
+ "circuit-open sticky release metric should be incremented")
+}
+
+// ---------------------------------------------------------------------------
+// Softmax Temperature Sampling Tests
+// ---------------------------------------------------------------------------
+
+// makeSoftmaxCandidates builds N candidates with the given scores.
+func makeSoftmaxCandidates(scores ...float64) []openAIAccountCandidateScore {
+ out := make([]openAIAccountCandidateScore, len(scores))
+ for i, s := range scores {
+ out[i] = openAIAccountCandidateScore{
+ account: &Account{ID: int64(i + 1), Priority: 0},
+ loadInfo: &AccountLoadInfo{AccountID: int64(i + 1)},
+ score: s,
+ }
+ }
+ return out
+}
+
+func TestSoftmax_LowTemperatureApproximatesArgmax(t *testing.T) {
+ // With a very low temperature the highest-scored candidate should win
+ // almost every time.
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5)
+
+ winCount := 0
+ trials := 100
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 0.01, &rng)
+ require.Len(t, result, len(candidates))
+ if result[0].account.ID == 1 { // ID 1 has score 5.0 (the highest)
+ winCount++
+ }
+ }
+
+ require.Greater(t, winCount, 90,
+ "with temperature=0.01 the highest-scored candidate should win >90%% of trials; got %d/%d", winCount, trials)
+}
+
+func TestSoftmax_HighTemperatureApproximatesUniform(t *testing.T) {
+ // With a very high temperature, all candidates should get roughly equal
+ // selection frequency.
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 100.0, &rng)
+ require.Len(t, result, len(candidates))
+ counts[result[0].account.ID]++
+ }
+
+ expected := float64(trials) / float64(len(candidates)) // 250
+ for id, count := range counts {
+ require.InDelta(t, expected, float64(count), float64(trials)*0.10,
+ "candidate ID=%d expected ~%.0f selections, got %d", id, expected, count)
+ }
+}
+
+func TestSoftmax_DefaultTemperature(t *testing.T) {
+ // With the default temperature (0.3), higher-scored candidates should be
+ // picked more often than lower-scored ones.
+ candidates := makeSoftmaxCandidates(5.0, 3.0, 1.0, 0.5)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, defaultSoftmaxTemperature, &rng)
+ counts[result[0].account.ID]++
+ }
+
+ // The candidate with the highest score (ID=1, score=5.0) should be
+ // selected more often than the candidate with the lowest (ID=4, score=0.5).
+ require.Greater(t, counts[int64(1)], counts[int64(4)],
+ "highest-scored candidate should be picked more often; best=%d worst=%d",
+ counts[int64(1)], counts[int64(4)])
+
+ // Also check that the top-scored candidate beats the second-highest.
+ require.Greater(t, counts[int64(1)], counts[int64(2)],
+ "score=5.0 should beat score=3.0; got %d vs %d",
+ counts[int64(1)], counts[int64(2)])
+}
+
+func TestSoftmax_SingleCandidate(t *testing.T) {
+ candidates := makeSoftmaxCandidates(7.5)
+
+ rng := newOpenAISelectionRNG(42)
+ result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng)
+
+ require.Len(t, result, 1)
+ require.Equal(t, int64(1), result[0].account.ID)
+ require.Equal(t, 7.5, result[0].score)
+}
+
+func TestSoftmax_TwoCandidates(t *testing.T) {
+ // Use a moderate score gap (1.0 vs 0.5) with temperature=1.0 so both
+ // candidates have meaningful selection probability.
+ candidates := makeSoftmaxCandidates(1.0, 0.5)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 1.0, &rng)
+ require.Len(t, result, 2)
+ counts[result[0].account.ID]++
+ }
+
+ // Both candidates should be selected at least once (proving no
+ // single-candidate monopoly), and the higher-scored one should dominate.
+ require.Greater(t, counts[int64(1)], 0, "high-scored candidate must be selected at least once")
+ require.Greater(t, counts[int64(2)], 0, "low-scored candidate must be selected at least once")
+ require.Greater(t, counts[int64(1)], counts[int64(2)],
+ "higher-scored candidate should be picked more often; got %d vs %d",
+ counts[int64(1)], counts[int64(2)])
+}
+
+func TestSoftmax_EqualScores(t *testing.T) {
+ // When all scores are equal, selection should be approximately uniform.
+ candidates := makeSoftmaxCandidates(3.0, 3.0, 3.0, 3.0)
+
+ counts := map[int64]int{}
+ trials := 1000
+ for i := 0; i < trials; i++ {
+ rng := newOpenAISelectionRNG(uint64(i + 1))
+ result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng)
+ counts[result[0].account.ID]++
+ }
+
+ expected := float64(trials) / float64(len(candidates)) // 250
+ for id, count := range counts {
+ require.InDelta(t, expected, float64(count), float64(trials)*0.10,
+ "equal scores should yield ~uniform distribution; ID=%d expected ~%.0f got %d",
+ id, expected, count)
+ }
+}
+
+func TestSoftmax_NumericalStability(t *testing.T) {
+ // Large score differences should not cause overflow or NaN.
+ candidates := makeSoftmaxCandidates(100.0, -100.0, 50.0, -50.0)
+
+ rng := newOpenAISelectionRNG(12345)
+ result := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng)
+
+ require.Len(t, result, len(candidates))
+ // Verify all scores are finite in the output (no NaN or Inf propagation).
+ for _, c := range result {
+ require.False(t, math.IsNaN(c.score), "score should not be NaN")
+ require.False(t, math.IsInf(c.score, 0), "score should not be Inf")
+ }
+ // All candidates must appear exactly once.
+ seen := map[int64]bool{}
+ for _, c := range result {
+ require.False(t, seen[c.account.ID], "duplicate account ID %d", c.account.ID)
+ seen[c.account.ID] = true
+ }
+ require.Len(t, seen, len(candidates))
+
+ // With such extreme differences at temperature=0.3, the highest scorer (100.0)
+ // should always win because exp((100 - 100)/0.3) = 1 while
+ // exp((-100 - 100)/0.3) ~= 0 (numerically stable via maxScore subtraction).
+ winCount := 0
+ for i := 0; i < 100; i++ {
+ rng2 := newOpenAISelectionRNG(uint64(i + 1))
+ r := selectSoftmaxOpenAICandidates(candidates, 0.3, &rng2)
+ if r[0].account.ID == 1 { // score 100.0
+ winCount++
+ }
+ }
+ require.Greater(t, winCount, 95,
+ "with extreme score gap, the highest scorer should win nearly always; got %d/100", winCount)
+}
+
+func TestSoftmax_DisabledFallsThrough(t *testing.T) {
+ // When softmax is disabled, the scheduler should fall through to P2C or Top-K.
+ ctx := context.Background()
+ groupID := int64(40001)
+ accounts := make([]Account, 5)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(6001 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+ }
+
+ cache := &stubGatewayCache{}
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ // Softmax explicitly disabled (default).
+ cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = false
+ // P2C also disabled.
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "", "softmax_disabled_test",
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ // TopK should be > 0, confirming Top-K path was taken.
+ require.Greater(t, decision.TopK, 0, "should fall through to Top-K when softmax is disabled")
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestSoftmax_FewCandidatesFallsThrough(t *testing.T) {
+ // When softmax is enabled but there are <= 3 candidates, it should fall
+ // through to the next strategy (P2C or Top-K).
+ ctx := context.Background()
+ groupID := int64(40002)
+ // Only 3 accounts — softmax guard requires >3.
+ accounts := make([]Account, 3)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(7001 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 5,
+ }
+ }
+
+ cache := &stubGatewayCache{}
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ // Softmax enabled but should not activate with only 3 candidates.
+ cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.5
+ // P2C disabled, so it should fall through to Top-K.
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = false
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cache: cache,
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ selection, decision, err := svc.SelectAccountWithScheduler(
+ ctx, &groupID, "", "softmax_few_candidates_test",
+ "gpt-5.1", nil, OpenAIUpstreamTransportAny,
+ )
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer)
+ // TopK should be > 0, confirming Top-K path was taken instead of softmax.
+ require.Greater(t, decision.TopK, 0, "should fall through to Top-K when candidate count <= 3")
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+}
+
+func TestSoftmax_ConfigDefaults(t *testing.T) {
+ // When config values are zero/unset, defaults should be applied.
+
+ // Case 1: nil service — returns empty config.
+ nilScheduler := &defaultOpenAIAccountScheduler{}
+ cfg0 := nilScheduler.softmaxConfigRead()
+ require.False(t, cfg0.enabled)
+ require.Equal(t, 0.0, cfg0.temperature) // no default when service is nil
+
+ // Case 2: zero temperature (unset) — should default to 0.3.
+ svc := &OpenAIGatewayService{
+ cfg: &config.Config{},
+ }
+ svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxEnabled = true
+ svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0 // unset
+ scheduler := &defaultOpenAIAccountScheduler{
+ service: svc,
+ stats: newOpenAIAccountRuntimeStats(),
+ }
+ cfg1 := scheduler.softmaxConfigRead()
+ require.True(t, cfg1.enabled)
+ require.Equal(t, 0.3, cfg1.temperature, "default temperature should be 0.3")
+
+ // Case 3: explicit temperature — should use the configured value.
+ svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = 0.7
+ cfg2 := scheduler.softmaxConfigRead()
+ require.True(t, cfg2.enabled)
+ require.Equal(t, 0.7, cfg2.temperature, "should use explicitly configured temperature")
+
+ // Case 4: negative temperature — should fall back to default 0.3.
+ svc.cfg.Gateway.OpenAIWS.SchedulerSoftmaxTemperature = -1.0
+ cfg3 := scheduler.softmaxConfigRead()
+ require.Equal(t, 0.3, cfg3.temperature, "negative temperature should fall back to default 0.3")
+}
+
+// ---------------------------------------------------------------------------
+// Per-Model TTFT Tests
+// ---------------------------------------------------------------------------
+
+func TestPerModelTTFT_IndependentTracking(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8001)
+
+ // Report different TTFT values for two different models on the same account.
+ stats.report(accountID, true, nil, "gpt-4o", 100)
+ stats.report(accountID, true, nil, "gpt-4o", 120)
+ stats.report(accountID, true, nil, "o3-pro", 500)
+ stats.report(accountID, true, nil, "o3-pro", 600)
+
+ // Snapshot for model gpt-4o.
+ _, ttftGPT4o, hasTTFT := stats.snapshot(accountID, "gpt-4o")
+ require.True(t, hasTTFT, "gpt-4o should have TTFT data")
+
+ // Snapshot for model o3-pro.
+ _, ttftO3Pro, hasO3Pro := stats.snapshot(accountID, "o3-pro")
+ require.True(t, hasO3Pro, "o3-pro should have TTFT data")
+
+ // The two models should have different TTFT values because their
+ // sample inputs are very different (100-120 vs 500-600).
+ require.Greater(t, math.Abs(ttftGPT4o-ttftO3Pro), 50.0,
+ "different models should track independent TTFT values")
+
+ // gpt-4o TTFT should be much lower than o3-pro.
+ require.Less(t, ttftGPT4o, ttftO3Pro,
+ "gpt-4o should have lower TTFT than o3-pro")
+}
+
+func TestPerModelTTFT_FallbackToGlobal(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8002)
+
+ // Report TTFT for a specific model.
+ stats.report(accountID, true, nil, "gpt-4o", 200)
+
+ // Snapshot with an unknown model should fall back to global TTFT.
+ _, ttftUnknown, hasUnknown := stats.snapshot(accountID, "unknown-model")
+ require.True(t, hasUnknown, "should fall back to global TTFT")
+
+ // Global TTFT should have been updated by the gpt-4o report.
+ _, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+ require.True(t, hasGlobal, "global TTFT should exist")
+
+ // The unknown-model fallback should equal the global.
+ require.InDelta(t, ttftGlobal, ttftUnknown, 1e-9,
+ "unknown model should return global TTFT as fallback")
+}
+
+func TestPerModelTTFT_GlobalAlsoUpdated(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8003)
+
+ // No data initially.
+ _, _, hasGlobal := stats.snapshot(accountID)
+ require.False(t, hasGlobal, "no global TTFT initially")
+
+ // Report with model — should also update global.
+ stats.report(accountID, true, nil, "gpt-4o", 300)
+
+ _, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+ require.True(t, hasGlobal, "global TTFT should exist after model report")
+ require.InDelta(t, 300.0, ttftGlobal, 1e-9,
+ "global TTFT should be updated by model report")
+}
+
+func TestPerModelTTFT_TTLCleanup(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ stat.ttft.initNaN()
+
+ // Manually insert a model entry with an old timestamp.
+ d := &dualEWMATTFT{}
+ d.initNaN()
+ d.update(100)
+ stat.modelTTFT.Store("old-model", d)
+ stat.modelTTFTLastUpdate.Store("old-model", time.Now().Add(-time.Hour).UnixNano())
+
+ // Insert a recent model entry.
+ d2 := &dualEWMATTFT{}
+ d2.initNaN()
+ d2.update(200)
+ stat.modelTTFT.Store("new-model", d2)
+ stat.modelTTFTLastUpdate.Store("new-model", time.Now().UnixNano())
+
+ // Cleanup with 30-minute TTL — old-model should be removed.
+ stat.cleanupStaleTTFT(30*time.Minute, 100)
+
+ _, hasOld := stat.modelTTFTValue("old-model")
+ require.False(t, hasOld, "old-model should be cleaned up")
+
+ _, hasNew := stat.modelTTFTValue("new-model")
+ require.True(t, hasNew, "new-model should survive cleanup")
+}
+
+func TestPerModelTTFT_MaxModelLimit(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ stat.ttft.initNaN()
+
+ now := time.Now()
+ // Insert 10 models with sequential timestamps.
+ for i := 0; i < 10; i++ {
+ model := fmt.Sprintf("model-%d", i)
+ d := &dualEWMATTFT{}
+ d.initNaN()
+ d.update(float64(100 + i*10))
+ stat.modelTTFT.Store(model, d)
+ stat.modelTTFTLastUpdate.Store(model, now.Add(time.Duration(i)*time.Second).UnixNano())
+ }
+
+ // Enforce limit of 5 models — the 5 oldest should be evicted.
+ stat.cleanupStaleTTFT(time.Hour, 5)
+
+ // Count remaining models.
+ remaining := 0
+ stat.modelTTFT.Range(func(_, _ any) bool {
+ remaining++
+ return true
+ })
+ require.Equal(t, 5, remaining, "should have exactly 5 models after cleanup")
+
+ // The newest 5 (model-5 through model-9) should survive.
+ for i := 5; i < 10; i++ {
+ model := fmt.Sprintf("model-%d", i)
+ _, has := stat.modelTTFTValue(model)
+ require.True(t, has, "%s should survive", model)
+ }
+ // The oldest 5 (model-0 through model-4) should be evicted.
+ for i := 0; i < 5; i++ {
+ model := fmt.Sprintf("model-%d", i)
+ _, has := stat.modelTTFTValue(model)
+ require.False(t, has, "%s should be evicted", model)
+ }
+}
+
+func TestPerModelTTFT_SnapshotUsesModelData(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8006)
+
+ // Report two models with very different TTFT.
+ for i := 0; i < 5; i++ {
+ stats.report(accountID, true, nil, "fast-model", 50)
+ stats.report(accountID, true, nil, "slow-model", 500)
+ }
+
+ // Snapshot with specific model should return that model's TTFT.
+ _, ttftFast, hasFast := stats.snapshot(accountID, "fast-model")
+ require.True(t, hasFast)
+
+ _, ttftSlow, hasSlow := stats.snapshot(accountID, "slow-model")
+ require.True(t, hasSlow)
+
+ // Fast model should have much lower TTFT.
+ require.Less(t, ttftFast, 100.0, "fast-model TTFT should be close to 50")
+ require.Greater(t, ttftSlow, 400.0, "slow-model TTFT should be close to 500")
+
+ // Global TTFT should be a blend of both.
+ _, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+ require.True(t, hasGlobal)
+ require.Greater(t, ttftGlobal, ttftFast, "global TTFT should be higher than fast-model")
+ require.Less(t, ttftGlobal, ttftSlow, "global TTFT should be lower than slow-model")
+}
+
+func TestPerModelTTFT_ConcurrentAccess(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8007)
+
+ const workers = 8
+ const iterations = 200
+ models := []string{"model-a", "model-b", "model-c", "model-d"}
+
+ var wg sync.WaitGroup
+ wg.Add(workers)
+ for w := 0; w < workers; w++ {
+ w := w
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ model := models[(w+i)%len(models)]
+ ttft := float64(100 + (w*10+i)%200)
+ stats.report(accountID, true, nil, model, ttft)
+
+ // Also read concurrently.
+ stats.snapshot(accountID, model)
+ stats.snapshot(accountID)
+ }
+ }()
+ }
+ wg.Wait()
+
+ // All models should have TTFT data.
+ for _, model := range models {
+ _, ttft, has := stats.snapshot(accountID, model)
+ require.True(t, has, "%s should have TTFT", model)
+ require.Greater(t, ttft, 0.0, "%s TTFT should be positive", model)
+ }
+
+ // Global should also have data.
+ _, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+ require.True(t, hasGlobal)
+ require.Greater(t, ttftGlobal, 0.0)
+}
+
+func TestPerModelTTFT_EmptyModel(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(8008)
+
+ // Report with empty model — should only update global TTFT.
+ ttft := 150
+ stats.report(accountID, true, &ttft, "", 0)
+
+ // Global should have data.
+ _, ttftGlobal, hasGlobal := stats.snapshot(accountID)
+ require.True(t, hasGlobal, "global TTFT should exist from firstTokenMs")
+ require.InDelta(t, 150.0, ttftGlobal, 1e-9)
+
+ // Snapshot with empty model returns global.
+ _, ttftEmpty, hasEmpty := stats.snapshot(accountID, "")
+ require.True(t, hasEmpty)
+ require.InDelta(t, ttftGlobal, ttftEmpty, 1e-9,
+ "empty model snapshot should return global TTFT")
+
+ // No per-model entries should exist.
+ stat := stats.loadOrCreate(accountID)
+ count := 0
+ stat.modelTTFT.Range(func(_, _ any) bool {
+ count++
+ return true
+ })
+ require.Equal(t, 0, count, "no per-model entries should exist for empty model")
+}
+
+// ---------------------------------------------------------------------------
+// Load Trend Prediction Tests
+// ---------------------------------------------------------------------------
+
+func TestLoadTrend_RisingLoad(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ for i := 0; i < 10; i++ {
+ tracker.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second))
+ }
+ slope := tracker.slope()
+ require.Greater(t, slope, 0.0, "rising load should produce positive slope")
+ require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second")
+}
+
+func TestLoadTrend_FallingLoad(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ for i := 0; i < 10; i++ {
+ tracker.recordAt(float64(100-i*10), base+int64(i)*int64(time.Second))
+ }
+ slope := tracker.slope()
+ require.Less(t, slope, 0.0, "falling load should produce negative slope")
+ require.InDelta(t, -10.0, slope, 0.01, "slope should be ~-10 per second")
+}
+
+func TestLoadTrend_StableLoad(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ for i := 0; i < 10; i++ {
+ tracker.recordAt(50.0, base+int64(i)*int64(time.Second))
+ }
+ slope := tracker.slope()
+ require.InDelta(t, 0.0, slope, 1e-9, "constant load should produce zero slope")
+}
+
+func TestLoadTrend_RingBufferFull(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ for i := 0; i < 5; i++ {
+ tracker.recordAt(100.0, base+int64(i)*int64(time.Second))
+ }
+ for i := 0; i < 10; i++ {
+ tracker.recordAt(float64((i+1)*10), base+int64(5+i)*int64(time.Second))
+ }
+ slope := tracker.slope()
+ require.Greater(t, slope, 0.0, "should reflect rising trend from last 10 samples")
+ require.InDelta(t, 10.0, slope, 0.01, "slope should be ~10 per second after ring wraps")
+}
+
+func TestLoadTrend_InsufficientSamples(t *testing.T) {
+ var tracker loadTrendTracker
+ slope := tracker.slope()
+ require.Equal(t, 0.0, slope, "zero samples should return slope 0")
+}
+
+func TestLoadTrend_SingleSample(t *testing.T) {
+ var tracker loadTrendTracker
+ tracker.recordAt(42.0, time.Now().UnixNano())
+ slope := tracker.slope()
+ require.Equal(t, 0.0, slope, "single sample should return slope 0")
+}
+
+func TestLoadTrend_TwoSamples(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ tracker.recordAt(10.0, base)
+ tracker.recordAt(30.0, base+int64(2*time.Second))
+ slope := tracker.slope()
+ require.InDelta(t, 10.0, slope, 0.01, "two-sample slope should be exact delta/time")
+}
+
+func TestLoadTrend_AllSameTimestamp(t *testing.T) {
+ var tracker loadTrendTracker
+ ts := time.Now().UnixNano()
+ for i := 0; i < 5; i++ {
+ tracker.recordAt(float64(i*10), ts)
+ }
+ slope := tracker.slope()
+ require.Equal(t, 0.0, slope, "all-same-timestamp should return slope 0 (degenerate)")
+}
+
+func TestLoadTrend_NegativeSlope(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ for i := 0; i < 5; i++ {
+ tracker.recordAt(50.0-float64(i)*5.0, base+int64(i)*int64(time.Second))
+ }
+ slope := tracker.slope()
+ require.InDelta(t, -5.0, slope, 0.01)
+}
+
+func TestLoadTrend_ScoringIntegration(t *testing.T) {
+ accounts := []Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ }
+
+ loadMap := map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 50},
+ 2: {AccountID: 2, LoadRate: 50},
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{
+ loadMap: loadMap,
+ skipDefaultLoad: true,
+ }),
+ }
+
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+
+ base := time.Now().UnixNano() - int64(10*time.Second)
+ stat1 := stats.loadOrCreate(1)
+ for i := 0; i < 9; i++ {
+ stat1.loadTrend.recordAt(float64((i+1)*10), base+int64(i)*int64(time.Second))
+ }
+ stat2 := stats.loadOrCreate(2)
+ for i := 0; i < 9; i++ {
+ stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second))
+ }
+
+ ctx := context.Background()
+ selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+ RequiredTransport: OpenAIUpstreamTransportAny,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+
+ slope1 := stat1.loadTrend.slope()
+ slope2 := stat2.loadTrend.slope()
+ require.Greater(t, slope1, slope2,
+ "rising-trend account should have higher slope than stable account; slope1=%f slope2=%f", slope1, slope2)
+ require.Greater(t, slope1, 0.0, "rising-trend account slope should be positive")
+ require.InDelta(t, 0.0, slope2, 1.0,
+ "stable-trend account slope should be near zero; got %f", slope2)
+
+ trendAdj1 := 1.0 - clamp01(slope1/5.0)
+ trendAdj2 := 1.0 - clamp01(slope2/5.0)
+ require.Less(t, trendAdj1, trendAdj2,
+ "rising-trend trendAdj should be less than stable trendAdj; adj1=%f adj2=%f", trendAdj1, trendAdj2)
+}
+
+func TestLoadTrend_DisabledByDefault(t *testing.T) {
+ accounts := []Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+
+ base := time.Now().UnixNano()
+ stat1 := stats.loadOrCreate(1)
+ for i := 0; i < 10; i++ {
+ stat1.loadTrend.recordAt(float64(i*10), base+int64(i)*int64(time.Second))
+ }
+ stat2 := stats.loadOrCreate(2)
+ for i := 0; i < 10; i++ {
+ stat2.loadTrend.recordAt(30.0, base+int64(i)*int64(time.Second))
+ }
+
+ ctx := context.Background()
+ selectedCounts := map[int64]int{}
+ const rounds = 100
+ for r := 0; r < rounds; r++ {
+ selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+ RequiredTransport: OpenAIUpstreamTransportAny,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ selectedCounts[selection.Account.ID]++
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ require.Greater(t, selectedCounts[int64(1)], 10,
+ "with trend disabled, account 1 should still get selections; got %d", selectedCounts[int64(1)])
+ require.Greater(t, selectedCounts[int64(2)], 10,
+ "with trend disabled, account 2 should still get selections; got %d", selectedCounts[int64(2)])
+}
+
+func TestLoadTrend_ConcurrentAccess(t *testing.T) {
+ var tracker loadTrendTracker
+ var wg sync.WaitGroup
+ const goroutines = 10
+ const recordsPerGoroutine = 100
+
+ for g := 0; g < goroutines; g++ {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ for i := 0; i < recordsPerGoroutine; i++ {
+ tracker.record(float64(id*100 + i))
+ }
+ }(g)
+ }
+
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < goroutines*recordsPerGoroutine; i++ {
+ _ = tracker.slope()
+ }
+ }()
+
+ wg.Wait()
+
+ slope := tracker.slope()
+ require.False(t, math.IsNaN(slope), "slope should be finite after concurrent access")
+ require.False(t, math.IsInf(slope, 0), "slope should be finite after concurrent access")
+}
+
+func TestLoadTrend_TrendConfigDefaults(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ enabled, maxSlope := svc.openAIWSSchedulerTrendConfig()
+ require.False(t, enabled, "trend should be disabled by default")
+ require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope, "maxSlope should use default")
+}
+
+func TestLoadTrend_TrendConfigCustom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 8.0
+ svc := &OpenAIGatewayService{cfg: cfg}
+ enabled, maxSlope := svc.openAIWSSchedulerTrendConfig()
+ require.True(t, enabled)
+ require.Equal(t, 8.0, maxSlope)
+}
+
+func TestLoadTrend_TrendConfigZeroMaxSlope(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 0
+ svc := &OpenAIGatewayService{cfg: cfg}
+ enabled, maxSlope := svc.openAIWSSchedulerTrendConfig()
+ require.True(t, enabled)
+ require.Equal(t, defaultSchedulerTrendMaxSlope, maxSlope)
+}
+
+func TestLoadTrend_RingBufferCountTracking(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+
+ for i := 0; i < loadTrendRingSize; i++ {
+ tracker.recordAt(float64(i), base+int64(i)*int64(time.Second))
+ }
+ require.Equal(t, loadTrendRingSize, tracker.count, "count should equal ring size after filling")
+
+ tracker.recordAt(99.0, base+int64(loadTrendRingSize)*int64(time.Second))
+ require.Equal(t, loadTrendRingSize, tracker.count, "count should remain capped at ring size")
+}
+
+func TestLoadTrend_GentleRise(t *testing.T) {
+ var tracker loadTrendTracker
+ base := time.Now().UnixNano()
+ for i := 0; i < 10; i++ {
+ tracker.recordAt(50.0+float64(i)*0.1, base+int64(i)*int64(time.Second))
+ }
+ slope := tracker.slope()
+ require.Greater(t, slope, 0.0, "gentle rise should produce positive slope")
+ require.InDelta(t, 0.1, slope, 0.01)
+}
+
+func TestLoadTrend_RecordUpdatesRuntimeStat(t *testing.T) {
+ accounts := []Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{}),
+ }
+
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+
+ ctx := context.Background()
+ for i := 0; i < 5; i++ {
+ selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+ RequiredTransport: OpenAIUpstreamTransportAny,
+ })
+ require.NoError(t, err)
+ if selection != nil && selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ stat := stats.loadOrCreate(1)
+ require.GreaterOrEqual(t, stat.loadTrend.count, 5,
+ "trend tracker should have received samples from scoring loop")
+}
+
+func TestLoadTrend_FallingTrendBoostsScore(t *testing.T) {
+ accounts := []Account{
+ {ID: 1, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ {ID: 2, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive, Schedulable: true, Concurrency: 10},
+ }
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
+ cfg.Gateway.OpenAIWS.SchedulerTrendEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerTrendMaxSlope = 5.0
+
+ loadMap := map[int64]*AccountLoadInfo{
+ 1: {AccountID: 1, LoadRate: 50},
+ 2: {AccountID: 2, LoadRate: 50},
+ }
+
+ svc := &OpenAIGatewayService{
+ accountRepo: stubOpenAIAccountRepo{accounts: accounts},
+ cfg: cfg,
+ concurrencyService: NewConcurrencyService(stubConcurrencyCache{
+ loadMap: loadMap,
+ skipDefaultLoad: true,
+ }),
+ }
+
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+
+ base := time.Now().UnixNano()
+ stat1 := stats.loadOrCreate(1)
+ for i := 0; i < 10; i++ {
+ stat1.loadTrend.recordAt(80.0-float64(i)*5.0, base+int64(i)*int64(time.Second))
+ }
+ stat2 := stats.loadOrCreate(2)
+ for i := 0; i < 10; i++ {
+ stat2.loadTrend.recordAt(50.0, base+int64(i)*int64(time.Second))
+ }
+
+ ctx := context.Background()
+ selectedCounts := map[int64]int{}
+ const rounds = 50
+ for r := 0; r < rounds; r++ {
+ selection, _, _, _, err := scheduler.selectByLoadBalance(ctx, OpenAIAccountScheduleRequest{
+ RequiredTransport: OpenAIUpstreamTransportAny,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, selection)
+ require.NotNil(t, selection.Account)
+ selectedCounts[selection.Account.ID]++
+ if selection.ReleaseFunc != nil {
+ selection.ReleaseFunc()
+ }
+ }
+
+ require.Greater(t, selectedCounts[int64(1)]+selectedCounts[int64(2)], 0,
+ "both accounts should receive selections")
+}
+
+// ---------------------------------------------------------------------------
+// Circuit Breaker Coverage Tests
+// ---------------------------------------------------------------------------
+
+func TestCircuitBreaker_AllowClosed(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ require.True(t, cb.allow(30*time.Second, 2), "CLOSED state should allow requests")
+}
+
+func TestCircuitBreaker_AllowOpenWithinCooldown(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateOpen)
+ cb.lastFailureNano.Store(time.Now().UnixNano())
+ require.False(t, cb.allow(30*time.Second, 2), "OPEN within cooldown should deny")
+}
+
+func TestCircuitBreaker_AllowOpenAfterCooldown(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateOpen)
+ cb.lastFailureNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
+ require.True(t, cb.allow(30*time.Second, 2), "OPEN after cooldown should transition to HALF_OPEN and allow")
+ require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load())
+}
+
+func TestCircuitBreaker_AllowHalfOpenLimited(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ require.True(t, cb.allowHalfOpen(2))
+ require.True(t, cb.allowHalfOpen(2))
+ require.False(t, cb.allowHalfOpen(2), "should deny when in-flight reaches max")
+}
+
+func TestCircuitBreaker_AllowHalfOpenViaAllow(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ require.True(t, cb.allow(30*time.Second, 1))
+ require.False(t, cb.allow(30*time.Second, 1), "HALF_OPEN with max=1 should deny second")
+}
+
+func TestCircuitBreaker_AllowDefaultState(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(99) // unknown state
+ require.True(t, cb.allow(30*time.Second, 2), "unknown state should default to allow")
+}
+
+func TestCircuitBreaker_RecordSuccessClosed(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.consecutiveFails.Store(3)
+ cb.recordSuccess()
+ require.Equal(t, int32(0), cb.consecutiveFails.Load())
+ require.Equal(t, circuitBreakerStateClosed, cb.state.Load())
+}
+
+func TestCircuitBreaker_RecordSuccessHalfOpenToClose(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ cb.halfOpenInFlight.Store(1)
+ cb.halfOpenAdmitted.Store(1)
+ cb.halfOpenSuccess.Store(0)
+ cb.recordSuccess()
+ require.Equal(t, circuitBreakerStateClosed, cb.state.Load(), "all probes succeeded should close")
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+func TestCircuitBreaker_RecordSuccessHalfOpenPartial(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ cb.halfOpenInFlight.Store(3)
+ cb.halfOpenAdmitted.Store(3)
+ cb.halfOpenSuccess.Store(0)
+ cb.recordSuccess()
+ require.Equal(t, circuitBreakerStateHalfOpen, cb.state.Load(), "not all probes succeeded yet")
+}
+
+func TestCircuitBreaker_RecordFailureTripsOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ for i := 0; i < 4; i++ {
+ cb.recordFailure(5)
+ }
+ require.Equal(t, circuitBreakerStateClosed, cb.state.Load())
+ cb.recordFailure(5)
+ require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "5th failure should trip to OPEN")
+}
+
+func TestCircuitBreaker_RecordFailureHalfOpenReverts(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ cb.halfOpenInFlight.Store(2)
+ cb.recordFailure(5)
+ require.Equal(t, circuitBreakerStateOpen, cb.state.Load(), "failure in HALF_OPEN should revert to OPEN")
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+func TestCircuitBreaker_IsHalfOpen(t *testing.T) {
+ var cb *accountCircuitBreaker
+ require.False(t, cb.isHalfOpen(), "nil should return false")
+
+ cb = &accountCircuitBreaker{}
+ require.False(t, cb.isHalfOpen())
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ require.True(t, cb.isHalfOpen())
+}
+
+func TestCircuitBreaker_ReleaseHalfOpenPermit(t *testing.T) {
+ var cb *accountCircuitBreaker
+ cb.releaseHalfOpenPermit() // should not panic
+
+ cb = &accountCircuitBreaker{}
+ cb.releaseHalfOpenPermit() // not in HALF_OPEN, should be no-op
+
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ cb.halfOpenInFlight.Store(2)
+ cb.releaseHalfOpenPermit()
+ require.Equal(t, int32(1), cb.halfOpenInFlight.Load())
+
+ cb.halfOpenInFlight.Store(0)
+ cb.releaseHalfOpenPermit() // already at 0, should be no-op
+ require.Equal(t, int32(0), cb.halfOpenInFlight.Load())
+}
+
+func TestCircuitBreaker_StateString(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ require.Equal(t, "CLOSED", cb.stateString())
+ cb.state.Store(circuitBreakerStateOpen)
+ require.Equal(t, "OPEN", cb.stateString())
+ cb.state.Store(circuitBreakerStateHalfOpen)
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+ cb.state.Store(99)
+ require.Equal(t, "UNKNOWN", cb.stateString())
+}
+
+func TestCircuitBreaker_IsOpen(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ require.False(t, cb.isOpen())
+ cb.state.Store(circuitBreakerStateOpen)
+ require.True(t, cb.isOpen())
+}
+
+func TestCircuitBreaker_FullLifecycle(t *testing.T) {
+ cb := &accountCircuitBreaker{}
+ threshold := 3
+ cooldown := 50 * time.Millisecond
+
+ // CLOSED: allow requests
+ require.True(t, cb.allow(cooldown, 2))
+ require.Equal(t, "CLOSED", cb.stateString())
+
+ // Trip to OPEN
+ for i := 0; i < threshold; i++ {
+ cb.recordFailure(threshold)
+ }
+ require.Equal(t, "OPEN", cb.stateString())
+ require.False(t, cb.allow(cooldown, 2), "should deny in OPEN within cooldown")
+
+ // Wait for cooldown
+ time.Sleep(cooldown + 10*time.Millisecond)
+
+ // Should transition to HALF_OPEN
+ require.True(t, cb.allow(cooldown, 2))
+ require.Equal(t, "HALF_OPEN", cb.stateString())
+
+ // Success should close
+ cb.recordSuccess()
+ require.Equal(t, "CLOSED", cb.stateString())
+}
+
+// ---------------------------------------------------------------------------
+// dualEWMATTFT Coverage Tests
+// ---------------------------------------------------------------------------
+
+func TestDualEWMATTFT_InitNaN(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+ require.True(t, math.IsNaN(d.fastValue()))
+ require.True(t, math.IsNaN(d.slowValue()))
+ _, hasTTFT := d.value()
+ require.False(t, hasTTFT, "NaN-initialized should return hasTTFT=false")
+}
+
+func TestDualEWMATTFT_UpdateFromNaN(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+ d.update(100.0)
+ v, ok := d.value()
+ require.True(t, ok)
+ require.InDelta(t, 100.0, v, 0.01, "first update should set sample directly")
+}
+
+func TestDualEWMATTFT_UpdateMultiple(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+ for i := 0; i < 20; i++ {
+ d.update(200.0)
+ }
+ v, ok := d.value()
+ require.True(t, ok)
+ require.InDelta(t, 200.0, v, 1.0, "after many updates of same value, should converge")
+}
+
+func TestDualEWMATTFT_ValueFastOnly(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+ // Set fast only, slow stays NaN
+ d.fastBits.Store(math.Float64bits(42.0))
+ v, ok := d.value()
+ require.True(t, ok)
+ require.Equal(t, 42.0, v)
+}
+
+func TestDualEWMATTFT_ValueSlowOnly(t *testing.T) {
+ var d dualEWMATTFT
+ d.initNaN()
+ // Set slow only, fast stays NaN
+ d.slowBits.Store(math.Float64bits(55.0))
+ v, ok := d.value()
+ require.True(t, ok)
+ require.Equal(t, 55.0, v)
+}
+
+func TestDualEWMATTFT_ValueSlowGreaterThanFast(t *testing.T) {
+ var d dualEWMATTFT
+ d.fastBits.Store(math.Float64bits(30.0))
+ d.slowBits.Store(math.Float64bits(50.0))
+ v, ok := d.value()
+ require.True(t, ok)
+ require.Equal(t, 50.0, v, "pessimistic value should return max(fast, slow)")
+}
+
+func TestDualEWMATTFT_ValueFastGreaterThanSlow(t *testing.T) {
+ var d dualEWMATTFT
+ d.fastBits.Store(math.Float64bits(80.0))
+ d.slowBits.Store(math.Float64bits(50.0))
+ v, ok := d.value()
+ require.True(t, ok)
+ require.Equal(t, 80.0, v, "pessimistic value should return max(fast, slow)")
+}
+
+// ---------------------------------------------------------------------------
+// Softmax Additional Coverage Tests
+// ---------------------------------------------------------------------------
+
+func TestSoftmax_EmptyCandidates(t *testing.T) {
+ rng := newOpenAISelectionRNG(42)
+ result := selectSoftmaxOpenAICandidates(nil, 0.3, &rng)
+ require.Nil(t, result)
+}
+
+func TestSoftmax_ZeroTemperatureFallsToDefault(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.9},
+ {account: &Account{ID: 2}, score: 0.1},
+ {account: &Account{ID: 3}, score: 0.5},
+ {account: &Account{ID: 4}, score: 0.3},
+ }
+ rng := newOpenAISelectionRNG(42)
+ result := selectSoftmaxOpenAICandidates(candidates, 0, &rng)
+ require.Len(t, result, 4)
+}
+
+func TestSoftmax_NaNScoresUniform(t *testing.T) {
+ // Extreme negative scores that cause exp() to return 0 → uniform fallback
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: -1e308},
+ {account: &Account{ID: 2}, score: -1e308},
+ {account: &Account{ID: 3}, score: -1e308},
+ {account: &Account{ID: 4}, score: -1e308},
+ }
+ rng := newOpenAISelectionRNG(42)
+ result := selectSoftmaxOpenAICandidates(candidates, 0.001, &rng)
+ require.Len(t, result, 4)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot / Stats Coverage Tests
+// ---------------------------------------------------------------------------
+
+func TestSnapshot_NilStats(t *testing.T) {
+ var s *openAIAccountRuntimeStats
+ errorRate, ttft, hasTTFT := s.snapshot(1)
+ require.Equal(t, 0.0, errorRate)
+ require.Equal(t, 0.0, ttft)
+ require.False(t, hasTTFT)
+}
+
+func TestSnapshot_InvalidAccountID(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ errorRate, ttft, hasTTFT := s.snapshot(0)
+ require.Equal(t, 0.0, errorRate)
+ require.Equal(t, 0.0, ttft)
+ require.False(t, hasTTFT)
+}
+
+func TestSnapshot_UnknownAccount(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ errorRate, ttft, hasTTFT := s.snapshot(999)
+ require.Equal(t, 0.0, errorRate)
+ require.Equal(t, 0.0, ttft)
+ require.False(t, hasTTFT)
+}
+
+func TestSnapshot_WithModelFallback(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ stat := s.loadOrCreate(1)
+ // Set global TTFT
+ stat.ttft.update(100.0)
+ // Snapshot with unknown model should fallback to global
+ _, ttft, hasTTFT := s.snapshot(1, "unknown-model")
+ require.True(t, hasTTFT)
+ require.InDelta(t, 100.0, ttft, 0.01)
+}
+
+func TestSnapshot_WithModelSpecific(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ stat := s.loadOrCreate(1)
+ stat.reportModelTTFT("gpt-4", 200.0)
+ stat.ttft.update(50.0) // global is different
+ _, ttft, hasTTFT := s.snapshot(1, "gpt-4")
+ require.True(t, hasTTFT)
+ require.InDelta(t, 200.0, ttft, 0.01, "should use per-model TTFT")
+}
+
+func TestStatsSize_Nil(t *testing.T) {
+ var s *openAIAccountRuntimeStats
+ require.Equal(t, 0, s.size())
+}
+
+func TestStatsSize_Empty(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ require.Equal(t, 0, s.size())
+}
+
+func TestStatsSize_WithAccounts(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ s.loadOrCreate(1)
+ s.loadOrCreate(2)
+ require.Equal(t, 2, s.size())
+}
+
+func TestLoadOrCreate_ConcurrentSameID(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ var wg sync.WaitGroup
+ results := make([]*openAIAccountRuntimeStat, 10)
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ results[idx] = s.loadOrCreate(1)
+ }(i)
+ }
+ wg.Wait()
+ // All should return the same pointer
+ for i := 1; i < 10; i++ {
+ require.Same(t, results[0], results[i], "concurrent loadOrCreate should return same stat")
+ }
+}
+
+// ---------------------------------------------------------------------------
+// modelTTFTValue / reportModelTTFT Coverage Tests
+// ---------------------------------------------------------------------------
+
+func TestModelTTFTValue_EmptyModel(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ v, ok := stat.modelTTFTValue("")
+ require.False(t, ok)
+ require.Equal(t, 0.0, v)
+}
+
+func TestModelTTFTValue_UnknownModel(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ v, ok := stat.modelTTFTValue("nonexistent")
+ require.False(t, ok)
+ require.Equal(t, 0.0, v)
+}
+
+func TestReportModelTTFT_EmptyModel(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ stat.ttft.initNaN()
+ stat.reportModelTTFT("", 100.0)
+ // Should be no-op: global TTFT not updated
+ _, hasTTFT := stat.ttft.value()
+ require.False(t, hasTTFT)
+}
+
+func TestReportModelTTFT_ZeroSample(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ stat.ttft.initNaN()
+ stat.reportModelTTFT("gpt-4", 0)
+ _, hasTTFT := stat.ttft.value()
+ require.False(t, hasTTFT)
+}
+
+func TestReportModelTTFT_NegativeSample(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ stat.ttft.initNaN()
+ stat.reportModelTTFT("gpt-4", -10.0)
+ _, hasTTFT := stat.ttft.value()
+ require.False(t, hasTTFT)
+}
+
+func TestGetOrCreateModelTTFT_ConcurrentSameModel(t *testing.T) {
+ stat := &openAIAccountRuntimeStat{}
+ var wg sync.WaitGroup
+ results := make([]*dualEWMATTFT, 10)
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ results[idx] = stat.getOrCreateModelTTFT("gpt-4")
+ }(i)
+ }
+ wg.Wait()
+ for i := 1; i < 10; i++ {
+ require.Same(t, results[0], results[i])
+ }
+}
+
+// ---------------------------------------------------------------------------
+// schedulerCircuitBreakerConfig Coverage Tests
+// ---------------------------------------------------------------------------
+
+func TestSchedulerCircuitBreakerConfig_Defaults(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig()
+ require.False(t, enabled)
+ require.Equal(t, defaultCircuitBreakerFailThreshold, threshold)
+ require.Equal(t, time.Duration(defaultCircuitBreakerCooldownSec)*time.Second, cooldown)
+ require.Equal(t, defaultCircuitBreakerHalfOpenMax, halfOpenMax)
+}
+
+func TestSchedulerCircuitBreakerConfig_Custom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerFailThreshold = 10
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerCooldownSec = 60
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerHalfOpenMax = 5
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ enabled, threshold, cooldown, halfOpenMax := scheduler.schedulerCircuitBreakerConfig()
+ require.True(t, enabled)
+ require.Equal(t, 10, threshold)
+ require.Equal(t, 60*time.Second, cooldown)
+ require.Equal(t, 5, halfOpenMax)
+}
+
+func TestSchedulerPerModelTTFTConfig_Defaults(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ enabled, maxModels := scheduler.schedulerPerModelTTFTConfig()
+ require.False(t, enabled)
+ require.Equal(t, defaultPerModelTTFTMaxModels, maxModels)
+}
+
+func TestSchedulerPerModelTTFTConfig_Custom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 64
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ enabled, maxModels := scheduler.schedulerPerModelTTFTConfig()
+ require.True(t, enabled)
+ require.Equal(t, 64, maxModels)
+}
+
+func TestReportResult_PerModelTTFTDisabled_NoPerModelTrackerCreated(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = false
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ ttft := 120
+
+ scheduler.ReportResult(7001, true, &ttft, "gpt-5.1", 120)
+
+ stat := scheduler.stats.loadExisting(7001)
+ require.NotNil(t, stat)
+ count := 0
+ stat.modelTTFT.Range(func(_, _ any) bool {
+ count++
+ return true
+ })
+ require.Equal(t, 0, count, "per-model ttft should remain disabled")
+ _, globalTTFT, hasTTFT := scheduler.stats.snapshot(7001)
+ require.True(t, hasTTFT)
+ require.InDelta(t, 120.0, globalTTFT, 0.01)
+}
+
+func TestReportResult_PerModelTTFTMaxModels_UsesConfigLimit(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerPerModelTTFTMaxModels = 2
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ ttft := 100
+
+ models := []string{"gpt-5.1", "gpt-4o", "o3", "o4-mini"}
+ for i := 0; i < 200; i++ {
+ model := models[i%len(models)]
+ scheduler.ReportResult(7002, true, &ttft, model, float64(ttft+i))
+ }
+
+ stat := scheduler.stats.loadExisting(7002)
+ require.NotNil(t, stat)
+ count := 0
+ stat.modelTTFT.Range(func(_, _ any) bool {
+ count++
+ return true
+ })
+ require.LessOrEqual(t, count, 2, "model tracker count should honor scheduler_per_model_ttft_max_models")
+}
+
+// ---------------------------------------------------------------------------
+// P2C Edge Case Coverage
+// ---------------------------------------------------------------------------
+
+func TestSelectP2C_EmptyCandidates(t *testing.T) {
+ result := selectP2COpenAICandidates(nil, OpenAIAccountScheduleRequest{})
+ require.Nil(t, result)
+}
+
+func TestSelectP2C_SingleCandidate(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.5},
+ }
+ result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{})
+ require.Len(t, result, 1)
+ require.Equal(t, int64(1), result[0].account.ID)
+}
+
+func TestSelectP2C_TwoCandidatesPicksBetter(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.2},
+ {account: &Account{ID: 2}, score: 0.8},
+ }
+ result := selectP2COpenAICandidates(candidates, OpenAIAccountScheduleRequest{})
+ require.Len(t, result, 2)
+ require.Equal(t, int64(2), result[0].account.ID, "first should be higher-scored")
+}
+
+// ---------------------------------------------------------------------------
+// TopK Selection & Heap Coverage
+// ---------------------------------------------------------------------------
+
+func TestSelectTopK_Empty(t *testing.T) {
+ result := selectTopKOpenAICandidates(nil, 3)
+ require.Nil(t, result)
+}
+
+func TestSelectTopK_TopKZero(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}},
+ }
+ result := selectTopKOpenAICandidates(candidates, 0)
+ require.Len(t, result, 1, "topK=0 should default to 1")
+}
+
+func TestSelectTopK_TopKExceedsCandidates(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}},
+ }
+ result := selectTopKOpenAICandidates(candidates, 10)
+ require.Len(t, result, 2)
+ require.Equal(t, int64(2), result[0].account.ID, "highest score first")
+}
+
+func TestSelectTopK_ProperFiltering(t *testing.T) {
+ candidates := []openAIAccountCandidateScore{
+ {account: &Account{ID: 1, Priority: 1}, score: 0.1, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 2, Priority: 1}, score: 0.9, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 3, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 4, Priority: 1}, score: 0.3, loadInfo: &AccountLoadInfo{}},
+ {account: &Account{ID: 5, Priority: 1}, score: 0.7, loadInfo: &AccountLoadInfo{}},
+ }
+ result := selectTopKOpenAICandidates(candidates, 3)
+ require.Len(t, result, 3)
+ require.Equal(t, int64(2), result[0].account.ID)
+ require.Equal(t, int64(5), result[1].account.ID)
+ require.Equal(t, int64(3), result[2].account.ID)
+}
+
+func TestIsOpenAIAccountCandidateBetter_AllTiebreakers(t *testing.T) {
+ // Equal scores, different priority
+ a := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ b := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 2}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ require.True(t, isOpenAIAccountCandidateBetter(a, b), "lower priority number = better")
+ require.False(t, isOpenAIAccountCandidateBetter(b, a))
+
+ // Equal scores and priority, different load rate
+ c := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 5}}
+ d := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 60, WaitingCount: 5}}
+ require.True(t, isOpenAIAccountCandidateBetter(c, d), "lower load rate = better")
+
+ // Equal everything except waiting count
+ e := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}}
+ f := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 8}}
+ require.True(t, isOpenAIAccountCandidateBetter(e, f), "lower waiting count = better")
+
+ // Equal everything except ID
+ g := openAIAccountCandidateScore{account: &Account{ID: 1, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ h := openAIAccountCandidateScore{account: &Account{ID: 2, Priority: 1}, score: 0.5, loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 5}}
+ require.True(t, isOpenAIAccountCandidateBetter(g, h), "lower ID = better")
+}
+
+// ---------------------------------------------------------------------------
+// shouldReleaseStickySession Coverage
+// ---------------------------------------------------------------------------
+
+func TestShouldReleaseStickySession_NilScheduler(t *testing.T) {
+ var s *defaultOpenAIAccountScheduler
+ require.False(t, s.shouldReleaseStickySession(1))
+}
+
+func TestShouldReleaseStickySession_Disabled(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = false
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ require.False(t, scheduler.shouldReleaseStickySession(1))
+}
+
+func TestShouldReleaseStickySession_CircuitOpen(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ cfg.Gateway.OpenAIWS.SchedulerCircuitBreakerEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ // Trip the circuit breaker
+ cb := stats.getCircuitBreaker(1)
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ require.True(t, scheduler.shouldReleaseStickySession(1), "should release when circuit is open")
+}
+
+func TestShouldReleaseStickySession_HighErrorRate(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ // Report many failures to push error rate above threshold
+ for i := 0; i < 20; i++ {
+ stats.report(1, false, nil, "", 0)
+ }
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ require.True(t, scheduler.shouldReleaseStickySession(1), "should release when error rate is high")
+}
+
+func TestShouldReleaseStickySession_Healthy(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ stats := newOpenAIAccountRuntimeStats()
+ // Report successes
+ for i := 0; i < 20; i++ {
+ stats.report(1, true, nil, "", 0)
+ }
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ require.False(t, scheduler.shouldReleaseStickySession(1), "should not release when healthy")
+}
+
+// ---------------------------------------------------------------------------
+// stickyReleaseConfigRead Coverage
+// ---------------------------------------------------------------------------
+
+func TestStickyReleaseConfigRead_NilConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ cfg := scheduler.stickyReleaseConfigRead()
+ require.False(t, cfg.enabled)
+ require.Equal(t, 0.0, cfg.errorThreshold, "nil config returns zero-value struct")
+}
+
+func TestStickyReleaseConfigRead_Defaults(t *testing.T) {
+ c := &config.Config{}
+ // StickyReleaseErrorThreshold defaults to 0 → code uses defaultStickyReleaseErrorThreshold
+ svc := &OpenAIGatewayService{cfg: c}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ cfg := scheduler.stickyReleaseConfigRead()
+ require.False(t, cfg.enabled)
+ require.Equal(t, defaultStickyReleaseErrorThreshold, cfg.errorThreshold)
+}
+
+func TestStickyReleaseConfigRead_Custom(t *testing.T) {
+ c := &config.Config{}
+ c.Gateway.OpenAIWS.StickyReleaseEnabled = true
+ c.Gateway.OpenAIWS.StickyReleaseErrorThreshold = 0.5
+ svc := &OpenAIGatewayService{cfg: c}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ cfg := scheduler.stickyReleaseConfigRead()
+ require.True(t, cfg.enabled)
+ require.Equal(t, 0.5, cfg.errorThreshold)
+}
+
+// ---------------------------------------------------------------------------
+// RNG Coverage
+// ---------------------------------------------------------------------------
+
+func TestNewOpenAISelectionRNG_ZeroSeed(t *testing.T) {
+ rng := newOpenAISelectionRNG(0)
+ require.NotEqual(t, uint64(0), rng.state, "zero seed should be replaced with default")
+ v := rng.nextFloat64()
+ require.True(t, v >= 0 && v < 1.0)
+}
+
+// ---------------------------------------------------------------------------
+// isCircuitOpen Coverage
+// ---------------------------------------------------------------------------
+
+func TestIsCircuitOpen_UnknownAccount(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ require.False(t, stats.isCircuitOpen(999), "unknown account should not be circuit-open")
+}
+
+func TestIsCircuitOpen_OpenAccount(t *testing.T) {
+ stats := newOpenAIAccountRuntimeStats()
+ cb := stats.getCircuitBreaker(1)
+ for i := 0; i < defaultCircuitBreakerFailThreshold; i++ {
+ cb.recordFailure(defaultCircuitBreakerFailThreshold)
+ }
+ require.True(t, stats.isCircuitOpen(1))
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSSchedulerP2CEnabled / openAIWSSchedulerWeights Coverage
+// ---------------------------------------------------------------------------
+
+func TestP2CEnabled_NilConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ require.False(t, svc.openAIWSSchedulerP2CEnabled())
+}
+
+func TestP2CEnabled_Enabled(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerP2CEnabled = true
+ svc := &OpenAIGatewayService{cfg: cfg}
+ require.True(t, svc.openAIWSSchedulerP2CEnabled())
+}
+
+func TestSchedulerWeights_NilConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ w := svc.openAIWSSchedulerWeights()
+ require.Equal(t, 1.0, w.Priority)
+ require.Equal(t, 1.0, w.Load)
+ require.Equal(t, 0.7, w.Queue)
+ require.Equal(t, 0.8, w.ErrorRate)
+ require.Equal(t, 0.5, w.TTFT)
+}
+
+func TestSchedulerWeights_Custom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 2.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 3.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1.5
+ cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.8
+ svc := &OpenAIGatewayService{cfg: cfg}
+ w := svc.openAIWSSchedulerWeights()
+ require.Equal(t, 2.0, w.Priority)
+ require.Equal(t, 3.0, w.Load)
+}
+
+// ---------------------------------------------------------------------------
+// Snapshot edge cases
+// ---------------------------------------------------------------------------
+
+func TestSnapshot_WithEmptyModel(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ stat := s.loadOrCreate(1)
+ stat.ttft.update(100.0)
+ _, ttft, hasTTFT := s.snapshot(1, "")
+ require.True(t, hasTTFT)
+ require.InDelta(t, 100.0, ttft, 0.01, "empty model string should fall through to global")
+}
+
+func TestSnapshot_NoModelArg(t *testing.T) {
+ s := newOpenAIAccountRuntimeStats()
+ stat := s.loadOrCreate(1)
+ stat.ttft.update(100.0)
+ _, ttft, hasTTFT := s.snapshot(1)
+ require.True(t, hasTTFT)
+ require.InDelta(t, 100.0, ttft, 0.01)
+}
+
+// ---------------------------------------------------------------------------
+// deriveOpenAISelectionSeed coverage
+// ---------------------------------------------------------------------------
+
+func TestDeriveOpenAISelectionSeed_WithSessionHash(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{SessionHash: "abc123"})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+func TestDeriveOpenAISelectionSeed_WithPreviousResponseID(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{PreviousResponseID: "resp_123"})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+func TestDeriveOpenAISelectionSeed_WithGroupID(t *testing.T) {
+ gid := int64(42)
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{GroupID: &gid})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+func TestDeriveOpenAISelectionSeed_Empty(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{})
+ require.NotEqual(t, uint64(0), seed, "empty request should use time entropy")
+}
+
+func TestDeriveOpenAISelectionSeed_WithModel(t *testing.T) {
+ seed := deriveOpenAISelectionSeed(OpenAIAccountScheduleRequest{RequestedModel: "gpt-4"})
+ require.NotEqual(t, uint64(0), seed)
+}
+
+// ---------------------------------------------------------------------------
+// SnapshotMetrics coverage
+// ---------------------------------------------------------------------------
+
+func TestSnapshotMetrics_NilScheduler(t *testing.T) {
+ var s *defaultOpenAIAccountScheduler
+ metrics := s.SnapshotMetrics()
+ require.Equal(t, int64(0), metrics.SelectTotal)
+}
+
+func TestSnapshotMetrics_Normal(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ stats := newOpenAIAccountRuntimeStats()
+ scheduler := newDefaultOpenAIAccountScheduler(svc, stats).(*defaultOpenAIAccountScheduler)
+ metrics := scheduler.SnapshotMetrics()
+ require.Equal(t, int64(0), metrics.SelectTotal)
+}
+
+// ---------------------------------------------------------------------------
+// Heap Pop coverage
+// ---------------------------------------------------------------------------
+
+func TestCandidateHeap_Pop(t *testing.T) {
+ h := &openAIAccountCandidateHeap{}
+ heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 1}, score: 0.5, loadInfo: &AccountLoadInfo{}})
+ heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 2}, score: 0.9, loadInfo: &AccountLoadInfo{}})
+ heap.Push(h, openAIAccountCandidateScore{account: &Account{ID: 3}, score: 0.3, loadInfo: &AccountLoadInfo{}})
+ require.Equal(t, 3, h.Len())
+
+ // Pop returns the minimum (worst candidate in min-heap)
+ popped := heap.Pop(h).(openAIAccountCandidateScore)
+ require.Equal(t, int64(3), popped.account.ID, "should pop the lowest-scored")
+ require.Equal(t, 2, h.Len())
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSSessionStickyTTL coverage
+// ---------------------------------------------------------------------------
+
+func TestOpenAIWSSessionStickyTTL_DefaultConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ ttl := svc.openAIWSSessionStickyTTL()
+ require.Equal(t, openaiStickySessionTTL, ttl, "nil config should return default TTL")
+}
+
+func TestOpenAIWSSessionStickyTTL_Custom(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800
+ svc := &OpenAIGatewayService{cfg: cfg}
+ ttl := svc.openAIWSSessionStickyTTL()
+ require.Equal(t, 1800*time.Second, ttl)
+}
diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go
index c9cf32462..5ed5ff69a 100644
--- a/backend/internal/service/openai_client_transport.go
+++ b/backend/internal/service/openai_client_transport.go
@@ -64,8 +64,25 @@ func resolveOpenAIWSDecisionByClientTransport(
decision OpenAIWSProtocolDecision,
clientTransport OpenAIClientTransport,
) OpenAIWSProtocolDecision {
- if clientTransport == OpenAIClientTransportHTTP {
+ // WSv2 upstream is only allowed for explicit WebSocket ingress.
+ // Unknown/missing transport is treated as HTTP to avoid accidental WS pool usage.
+ if clientTransport != OpenAIClientTransportWS {
return openAIWSHTTPDecision("client_protocol_http")
}
return decision
}
+
+func shouldWarnOpenAIWSUnknownTransportFallback(
+ decision OpenAIWSProtocolDecision,
+ clientTransport OpenAIClientTransport,
+) bool {
+ if clientTransport != OpenAIClientTransportUnknown {
+ return false
+ }
+ switch decision.Transport {
+ case OpenAIUpstreamTransportResponsesWebsocketV2, OpenAIUpstreamTransportResponsesWebsocket:
+ return true
+ default:
+ return false
+ }
+}
diff --git a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go
index ef90e6145..479534c79 100644
--- a/backend/internal/service/openai_client_transport_test.go
+++ b/backend/internal/service/openai_client_transport_test.go
@@ -103,5 +103,40 @@ func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) {
require.Equal(t, base, wsDecision)
unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown)
- require.Equal(t, base, unknownDecision)
+ require.Equal(t, OpenAIUpstreamTransportHTTPSSE, unknownDecision.Transport)
+ require.Equal(t, "client_protocol_http", unknownDecision.Reason)
+}
+
+func TestShouldWarnOpenAIWSUnknownTransportFallback(t *testing.T) {
+ require.True(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_enabled",
+ },
+ OpenAIClientTransportUnknown,
+ ))
+
+ require.True(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocket,
+ Reason: "ws_v1_enabled",
+ },
+ OpenAIClientTransportUnknown,
+ ))
+
+ require.False(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportHTTPSSE,
+ Reason: "http_only",
+ },
+ OpenAIClientTransportUnknown,
+ ))
+
+ require.False(t, shouldWarnOpenAIWSUnknownTransportFallback(
+ OpenAIWSProtocolDecision{
+ Transport: OpenAIUpstreamTransportResponsesWebsocketV2,
+ Reason: "ws_v2_enabled",
+ },
+ OpenAIClientTransportHTTP,
+ ))
}
diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go
new file mode 100644
index 000000000..b64c1441b
--- /dev/null
+++ b/backend/internal/service/openai_gateway_record_usage_test.go
@@ -0,0 +1,783 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type openAIRecordUsageLogRepoStub struct {
+ UsageLogRepository
+
+ inserted bool
+ err error
+ calls int
+ lastLog *UsageLog
+ nextID int64
+
+ billingEntry *UsageBillingEntry
+ billingEntryErr error
+ upsertCalls int
+ getCalls int
+ markAppliedCalls int
+ markRetryCalls int
+ lastRetryAt time.Time
+ lastRetryErrMessage string
+ txCalls int
+}
+
+func (s *openAIRecordUsageLogRepoStub) Create(ctx context.Context, log *UsageLog) (bool, error) {
+ s.calls++
+ if log != nil {
+ if log.ID == 0 {
+ if s.nextID == 0 {
+ s.nextID = 1000
+ }
+ log.ID = s.nextID
+ s.nextID++
+ }
+ }
+ s.lastLog = log
+ return s.inserted, s.err
+}
+
+func (s *openAIRecordUsageLogRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) {
+ s.getCalls++
+ if s.billingEntryErr != nil {
+ return nil, s.billingEntryErr
+ }
+ if s.billingEntry == nil {
+ return nil, ErrUsageBillingEntryNotFound
+ }
+ if s.billingEntry.UsageLogID != usageLogID {
+ return nil, ErrUsageBillingEntryNotFound
+ }
+ return s.billingEntry, nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) {
+ s.upsertCalls++
+ if s.billingEntryErr != nil {
+ return nil, false, s.billingEntryErr
+ }
+ if s.billingEntry != nil {
+ return s.billingEntry, false, nil
+ }
+ if entry == nil {
+ return nil, false, nil
+ }
+ copyEntry := *entry
+ copyEntry.ID = 9100 + int64(s.upsertCalls)
+ copyEntry.Status = UsageBillingEntryStatusPending
+ s.billingEntry = ©Entry
+ return s.billingEntry, true, nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error {
+ s.markAppliedCalls++
+ if s.billingEntry != nil && s.billingEntry.ID == entryID {
+ s.billingEntry.Applied = true
+ s.billingEntry.Status = UsageBillingEntryStatusApplied
+ }
+ return nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error {
+ s.markRetryCalls++
+ s.lastRetryAt = nextRetryAt
+ s.lastRetryErrMessage = lastError
+ if s.billingEntry != nil && s.billingEntry.ID == entryID {
+ s.billingEntry.Applied = false
+ s.billingEntry.Status = UsageBillingEntryStatusPending
+ msg := lastError
+ s.billingEntry.LastError = &msg
+ s.billingEntry.NextRetryAt = nextRetryAt
+ }
+ return nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) {
+ return nil, nil
+}
+
+func (s *openAIRecordUsageLogRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ s.txCalls++
+ if fn == nil {
+ return nil
+ }
+ return fn(ctx)
+}
+
+type openAIRecordUsageUserRepoStub struct {
+ UserRepository
+
+ deductCalls int
+ deductErr error
+}
+
+func (s *openAIRecordUsageUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
+ s.deductCalls++
+ return s.deductErr
+}
+
+type openAIRecordUsageSubRepoStub struct {
+ UserSubscriptionRepository
+
+ incrementCalls int
+ incrementErr error
+}
+
+func (s *openAIRecordUsageSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
+ s.incrementCalls++
+ return s.incrementErr
+}
+
+type openAIRecordUsageBillingCacheStub struct {
+ BillingCache
+
+ deductCalls int
+ deductErr error
+}
+
+func (s *openAIRecordUsageBillingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
+ s.deductCalls++
+ return s.deductErr
+}
+
+func (s *openAIRecordUsageBillingCacheStub) GetUserBalance(context.Context, int64) (float64, error) {
+ return 0, errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) SetUserBalance(context.Context, int64, float64) error {
+ return errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) InvalidateUserBalance(context.Context, int64) error {
+ return nil
+}
+
+func (s *openAIRecordUsageBillingCacheStub) GetSubscriptionCache(context.Context, int64, int64) (*SubscriptionCacheData, error) {
+ return nil, errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) SetSubscriptionCache(context.Context, int64, int64, *SubscriptionCacheData) error {
+ return errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) UpdateSubscriptionUsage(context.Context, int64, int64, float64) error {
+ return errors.New("not implemented")
+}
+
+func (s *openAIRecordUsageBillingCacheStub) InvalidateSubscriptionCache(context.Context, int64, int64) error {
+ return nil
+}
+
+type openAIRecordUsageAPIKeyQuotaStub struct {
+ calls int
+ err error
+}
+
+func (s *openAIRecordUsageAPIKeyQuotaStub) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
+ s.calls++
+ return s.err
+}
+
+func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo UserRepository, subRepo UserSubscriptionRepository) *OpenAIGatewayService {
+ cfg := &config.Config{
+ Default: config.DefaultConfig{
+ RateMultiplier: 1,
+ },
+ }
+ return &OpenAIGatewayService{
+ usageLogRepo: usageRepo,
+ userRepo: userRepo,
+ userSubRepo: subRepo,
+ cfg: cfg,
+ billingService: NewBillingService(cfg, nil),
+ billingCacheService: &BillingCacheService{},
+ deferredService: &DeferredService{},
+ }
+}
+
+func TestOpenAIGatewayServiceRecordUsage_NoBillingWhenCreateUsageLogFails(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ err: errors.New("write usage log failed"),
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_test_create_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ OutputTokens: 8,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1001,
+ },
+ User: &User{
+ ID: 2001,
+ },
+ Account: &Account{
+ ID: 3001,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "create usage log")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_BillingWhenUsageLogInserted(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_test_inserted",
+ Usage: OpenAIUsage{
+ InputTokens: 20,
+ OutputTokens: 10,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1002,
+ },
+ User: &User{
+ ID: 2002,
+ },
+ Account: &Account{
+ ID: 3002,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_PricingFailureReturnsError(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ svc.billingService = &BillingService{
+ cfg: &config.Config{},
+ fallbackPrices: map[string]*ModelPricing{},
+ }
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_pricing_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 1,
+ OutputTokens: 1,
+ },
+ Model: "model_pricing_not_found_for_test",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1102,
+ },
+ User: &User{
+ ID: 2102,
+ },
+ Account: &Account{
+ ID: 3102,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "calculate cost")
+ require.Equal(t, 0, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DeductBalanceFailureReturnsError(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{
+ deductErr: errors.New("db deduct failed"),
+ }
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_deduct_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 5,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1003,
+ },
+ User: &User{
+ ID: 2003,
+ },
+ Account: &Account{
+ ID: 3003,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "deduct balance")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DeductBalanceCacheFailureReturnsError(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ cache := &openAIRecordUsageBillingCacheStub{
+ deductErr: ErrInsufficientBalance,
+ }
+ svc.billingCacheService = &BillingCacheService{cache: cache}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_cache_deduct_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 5,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1004,
+ },
+ User: &User{
+ ID: 2004,
+ },
+ Account: &Account{
+ ID: 3004,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "deduct balance cache")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, cache.deductCalls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_SubscriptionIncrementFailureReturnsError(t *testing.T) {
+ groupID := int64(12)
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{
+ incrementErr: errors.New("subscription update failed"),
+ }
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ svc.billingCacheService = &BillingCacheService{cache: &openAIRecordUsageBillingCacheStub{}}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_sub_increment_fail",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 5,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1005,
+ GroupID: &groupID,
+ Group: &Group{
+ ID: groupID,
+ SubscriptionType: SubscriptionTypeSubscription,
+ },
+ },
+ User: &User{
+ ID: 2005,
+ },
+ Account: &Account{
+ ID: 3005,
+ },
+ Subscription: &UserSubscription{
+ ID: 4005,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "increment subscription usage")
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 1, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DuplicateUsageLogSkipsBilling(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: false,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_duplicate",
+ Usage: OpenAIUsage{
+ InputTokens: 8,
+ OutputTokens: 4,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1006,
+ },
+ User: &User{
+ ID: 2006,
+ },
+ Account: &Account{
+ ID: 3006,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_SimpleModeSkipsBilling(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ svc.cfg.RunMode = config.RunModeSimple
+ quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_simple_mode",
+ Usage: OpenAIUsage{
+ InputTokens: 5,
+ OutputTokens: 2,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1007,
+ Quota: 100,
+ },
+ User: &User{
+ ID: 2007,
+ },
+ Account: &Account{
+ ID: 3007,
+ },
+ APIKeyService: quotaSvc,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 0, subRepo.incrementCalls)
+ require.Equal(t, 0, quotaSvc.calls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_SubscriptionBillingSuccess(t *testing.T) {
+ groupID := int64(13)
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_sub_success",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 8,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1008,
+ GroupID: &groupID,
+ Group: &Group{
+ ID: groupID,
+ SubscriptionType: SubscriptionTypeSubscription,
+ },
+ },
+ User: &User{
+ ID: 2008,
+ },
+ Account: &Account{
+ ID: 3008,
+ },
+ Subscription: &UserSubscription{
+ ID: 4008,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 0, userRepo.deductCalls)
+ require.Equal(t, 1, subRepo.incrementCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_UpdatesAPIKeyQuotaWhenConfigured(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+ quotaSvc := &openAIRecordUsageAPIKeyQuotaStub{}
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_quota_update",
+ Usage: OpenAIUsage{
+ InputTokens: 1,
+ OutputTokens: 1,
+ CacheReadInputTokens: 3,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1009,
+ Quota: 100,
+ },
+ User: &User{
+ ID: 2009,
+ },
+ Account: &Account{
+ ID: 3009,
+ },
+ APIKeyService: quotaSvc,
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, 0, usageRepo.lastLog.InputTokens, "input_tokens 小于 cache_read_tokens 时应被钳制为 0")
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 1, quotaSvc.calls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_DuplicateWithPendingBillingEntryStillBills(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: false,
+ billingEntry: &UsageBillingEntry{
+ ID: 9201,
+ UsageLogID: 1000,
+ Applied: false,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 1.25,
+ },
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_duplicate_pending_entry",
+ Usage: OpenAIUsage{
+ InputTokens: 8,
+ OutputTokens: 2,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{ID: 1011},
+ User: &User{ID: 2011},
+ Account: &Account{
+ ID: 3011,
+ },
+ })
+
+ require.NoError(t, err)
+ require.Equal(t, 1, usageRepo.calls)
+ require.Equal(t, 1, usageRepo.getCalls)
+ require.Equal(t, 1, usageRepo.markAppliedCalls)
+ require.Equal(t, 1, userRepo.deductCalls)
+}
+
+func TestOpenAIGatewayServiceRecordUsage_BillingFailureMarksRetry(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{
+ deductErr: errors.New("deduct failed"),
+ }
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{
+ Result: &OpenAIForwardResult{
+ RequestID: "resp_mark_retry",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 6,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1012,
+ },
+ User: &User{
+ ID: 2012,
+ },
+ Account: &Account{
+ ID: 3012,
+ },
+ })
+
+ require.Error(t, err)
+ require.ErrorContains(t, err, "deduct balance")
+ require.Equal(t, 1, usageRepo.markRetryCalls)
+ require.NotZero(t, usageRepo.lastRetryAt)
+ require.NotEmpty(t, usageRepo.lastRetryErrMessage)
+ require.Equal(t, 0, usageRepo.markAppliedCalls)
+}
+
+func TestResolveOpenAIUsageRequestID_FallbackDeterministic(t *testing.T) {
+ reasoning := "medium"
+ input := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_fallback_seed",
+ APIKey: &APIKey{ID: 11001},
+ Account: &Account{ID: 21001},
+ Result: &OpenAIForwardResult{
+ RequestID: "",
+ Model: "gpt-5.1",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ OutputTokens: 8,
+ CacheCreationInputTokens: 2,
+ CacheReadInputTokens: 1,
+ },
+ Duration: 2300 * time.Millisecond,
+ ReasoningEffort: &reasoning,
+ Stream: true,
+ OpenAIWSMode: true,
+ },
+ }
+
+ got1 := resolveOpenAIUsageRequestID(input)
+ got2 := resolveOpenAIUsageRequestID(input)
+
+ require.NotEmpty(t, got1)
+ require.Equal(t, got1, got2, "fallback request id should be deterministic")
+ require.Contains(t, got1, "wsf_")
+}
+
+func TestResolveOpenAIUsageRequestID_FallbackChangesWhenUsageChanges(t *testing.T) {
+ base := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_fallback_seed",
+ APIKey: &APIKey{ID: 11002},
+ Account: &Account{ID: 21002},
+ Result: &OpenAIForwardResult{
+ Model: "gpt-5.1",
+ Usage: OpenAIUsage{
+ InputTokens: 10,
+ OutputTokens: 4,
+ },
+ Duration: 2 * time.Second,
+ },
+ }
+ changed := &OpenAIRecordUsageInput{
+ FallbackRequestID: base.FallbackRequestID,
+ APIKey: base.APIKey,
+ Account: base.Account,
+ Result: &OpenAIForwardResult{
+ Model: "gpt-5.1",
+ Usage: OpenAIUsage{
+ InputTokens: 11,
+ OutputTokens: 4,
+ },
+ Duration: 2 * time.Second,
+ },
+ }
+
+ baseID := resolveOpenAIUsageRequestID(base)
+ changedID := resolveOpenAIUsageRequestID(changed)
+
+ require.NotEqual(t, baseID, changedID, "fallback request id should change when usage fingerprint changes")
+}
+
+func TestOpenAIGatewayServiceRecordUsage_UsesFallbackRequestIDWhenMissing(t *testing.T) {
+ usageRepo := &openAIRecordUsageLogRepoStub{
+ inserted: true,
+ }
+ userRepo := &openAIRecordUsageUserRepoStub{}
+ subRepo := &openAIRecordUsageSubRepoStub{}
+ svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo)
+
+ input := &OpenAIRecordUsageInput{
+ FallbackRequestID: "req_from_handler",
+ Result: &OpenAIForwardResult{
+ RequestID: "",
+ Usage: OpenAIUsage{
+ InputTokens: 9,
+ OutputTokens: 3,
+ },
+ Model: "gpt-5.1",
+ Duration: time.Second,
+ },
+ APIKey: &APIKey{
+ ID: 1013,
+ },
+ User: &User{
+ ID: 2013,
+ },
+ Account: &Account{
+ ID: 3013,
+ },
+ }
+
+ expectedRequestID := resolveOpenAIUsageRequestID(input)
+ require.NotEmpty(t, expectedRequestID)
+
+ err := svc.RecordUsage(context.Background(), input)
+ require.NoError(t, err)
+ require.NotNil(t, usageRepo.lastLog)
+ require.Equal(t, expectedRequestID, usageRepo.lastLog.RequestID)
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index f624d92a5..34d4edeec 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"io"
+ "log/slog"
"math/rand"
"net/http"
"sort"
@@ -42,6 +43,8 @@ const (
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
+ // OpenAIRequestMetaKey 缓存 handler 已提取的请求元数据,供 Service 层复用。
+ OpenAIRequestMetaKey = "openai_request_meta"
// OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。
// 与 Codex 客户端保持一致:失败后最多重连 5 次。
openAIWSReconnectRetryLimit = 5
@@ -49,6 +52,13 @@ const (
openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond
openAIWSRetryBackoffMaxDefault = 2 * time.Second
openAIWSRetryJitterRatioDefault = 0.2
+ openAICodexUsageUpdateConcurrency = 16
+)
+
+// SSE 热路径包级常量,避免循环内重复分配
+var (
+ sseDataDone = []byte("[DONE]")
+ sseResponseCompletedMark = []byte(`"response.completed"`)
)
// OpenAI allowed headers whitelist (for non-passthrough).
@@ -212,6 +222,11 @@ type OpenAIForwardResult struct {
OpenAIWSMode bool
Duration time.Duration
FirstTokenMs *int
+ // TerminalEventType records the terminal event that ended the WS turn.
+ TerminalEventType string
+ // PendingFunctionCallIDs 表示该 response 中未完成的 function_call call_id 集合。
+ // 仅在 WS ingress 连续对话场景用于续链自愈,不参与外部 API 返回。
+ PendingFunctionCallIDs []string
}
type OpenAIWSRetryMetricsSnapshot struct {
@@ -221,6 +236,17 @@ type OpenAIWSRetryMetricsSnapshot struct {
NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"`
}
+type OpenAIWSTurnAbortMetricPoint struct {
+ Reason string `json:"reason"`
+ Expected bool `json:"expected"`
+ Total int64 `json:"total"`
+}
+
+type OpenAIWSAbortMetricsSnapshot struct {
+ TurnAbortTotal []OpenAIWSTurnAbortMetricPoint `json:"turn_abort_total"`
+ TurnAbortRecoveredTotal int64 `json:"turn_abort_recovered_total"`
+}
+
type OpenAICompatibilityFallbackMetricsSnapshot struct {
SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"`
SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"`
@@ -243,6 +269,16 @@ type openAIWSRetryMetrics struct {
nonRetryableFastFallback atomic.Int64
}
+type openAIWSTurnAbortMetricKey struct {
+ reason string
+ expected bool
+}
+
+type openAIWSAbortMetrics struct {
+ turnAbortTotal sync.Map // key: openAIWSTurnAbortMetricKey, value: *atomic.Int64
+ turnAbortRecovered atomic.Int64
+}
+
// OpenAIGatewayService handles OpenAI API gateway operations
type OpenAIGatewayService struct {
accountRepo AccountRepository
@@ -263,17 +299,26 @@ type OpenAIGatewayService struct {
toolCorrector *CodexToolCorrector
openaiWSResolver OpenAIWSProtocolResolver
- openaiWSPoolOnce sync.Once
+ openaiWSIngressCtxOnce sync.Once
openaiWSStateStoreOnce sync.Once
openaiSchedulerOnce sync.Once
- openaiWSPool *openAIWSConnPool
+ openaiWSIngressCtxPool *openAIWSIngressContextPool
openaiWSStateStore OpenAIWSStateStore
openaiScheduler OpenAIAccountScheduler
openaiAccountStats *openAIAccountRuntimeStats
openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time
openaiWSRetryMetrics openAIWSRetryMetrics
+ openaiWSAbortMetrics openAIWSAbortMetrics
responseHeaderFilter *responseheaders.CompiledHeaderFilter
+
+ codexUsageUpdateOnce sync.Once
+ codexUsageUpdateSem chan struct{}
+
+ usageBillingCompensation *UsageBillingCompensationService
+
+ // test hook for deterministic tie-break in account selection.
+ accountTieBreakIntnFn func(n int) int
}
// NewOpenAIGatewayService creates a new OpenAIGatewayService
@@ -313,15 +358,25 @@ func NewOpenAIGatewayService(
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
responseHeaderFilter: compileResponseHeaderFilter(cfg),
}
+ svc.usageBillingCompensation = NewUsageBillingCompensationService(usageLogRepo, userRepo, userSubRepo, billingCacheService, cfg)
+ svc.usageBillingCompensation.Start()
svc.logOpenAIWSModeBootstrap()
return svc
}
-// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。
+// CloseOpenAIWSCtxPool 关闭 OpenAI WebSocket ctx_pool 的后台 worker 与连接资源。
// 应在应用优雅关闭时调用。
-func (s *OpenAIGatewayService) CloseOpenAIWSPool() {
- if s != nil && s.openaiWSPool != nil {
- s.openaiWSPool.Close()
+func (s *OpenAIGatewayService) CloseOpenAIWSCtxPool() {
+ if s != nil && s.openaiWSIngressCtxPool != nil {
+ s.openaiWSIngressCtxPool.Close()
+ }
+ if s != nil && s.openaiWSStateStore != nil {
+ if closer, ok := s.openaiWSStateStore.(interface{ Close() }); ok {
+ closer.Close()
+ }
+ }
+ if s != nil && s.usageBillingCompensation != nil {
+ s.usageBillingCompensation.Stop()
}
}
@@ -537,6 +592,34 @@ func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context
return true
}
+func (s *OpenAIGatewayService) writeOpenAIWSV1UnsupportedResponse(c *gin.Context, account *Account) error {
+ const (
+ upstreamMessage = "openai ws v1 is temporarily unsupported; use ws v2"
+ clientMessage = "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2."
+ )
+ setOpsUpstreamError(c, http.StatusBadRequest, upstreamMessage, "")
+ if account != nil {
+ appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
+ Platform: account.Platform,
+ AccountID: account.ID,
+ AccountName: account.Name,
+ UpstreamStatusCode: http.StatusBadRequest,
+ Kind: "ws_error",
+ Message: upstreamMessage,
+ })
+ }
+ if c != nil {
+ c.JSON(http.StatusBadRequest, gin.H{
+ "error": gin.H{
+ "type": "invalid_request_error",
+ "message": clientMessage,
+ },
+ })
+ c.Abort()
+ }
+ return errors.New(upstreamMessage)
+}
+
func (s *OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration {
if attempt <= 0 {
return 0
@@ -610,6 +693,16 @@ func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration {
return 0
}
+func openAIWSRetryContextError(ctx context.Context) error {
+ if ctx == nil {
+ return nil
+ }
+ if err := ctx.Err(); err != nil {
+ return wrapOpenAIWSFallback("retry_context_canceled", err)
+ }
+ return nil
+}
+
func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) {
if s == nil {
return
@@ -646,6 +739,70 @@ func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetri
}
}
+func (s *OpenAIGatewayService) recordOpenAIWSTurnAbort(reason openAIWSIngressTurnAbortReason, expected bool) {
+ if s == nil {
+ return
+ }
+ normalizedReason := strings.TrimSpace(string(reason))
+ if normalizedReason == "" {
+ normalizedReason = string(openAIWSIngressTurnAbortReasonUnknown)
+ }
+ key := openAIWSTurnAbortMetricKey{
+ reason: normalizedReason,
+ expected: expected,
+ }
+ counterAny, _ := s.openaiWSAbortMetrics.turnAbortTotal.LoadOrStore(key, &atomic.Int64{})
+ counter, ok := counterAny.(*atomic.Int64)
+ if !ok || counter == nil {
+ return
+ }
+ counter.Add(1)
+}
+
+func (s *OpenAIGatewayService) recordOpenAIWSTurnAbortRecovered() {
+ if s == nil {
+ return
+ }
+ s.openaiWSAbortMetrics.turnAbortRecovered.Add(1)
+}
+
+func (s *OpenAIGatewayService) SnapshotOpenAIWSAbortMetrics() OpenAIWSAbortMetricsSnapshot {
+ if s == nil {
+ return OpenAIWSAbortMetricsSnapshot{}
+ }
+ points := make([]OpenAIWSTurnAbortMetricPoint, 0, 8)
+ s.openaiWSAbortMetrics.turnAbortTotal.Range(func(key, value any) bool {
+ label, ok := key.(openAIWSTurnAbortMetricKey)
+ if !ok {
+ return true
+ }
+ counter, ok := value.(*atomic.Int64)
+ if !ok || counter == nil {
+ return true
+ }
+ total := counter.Load()
+ if total <= 0 {
+ return true
+ }
+ points = append(points, OpenAIWSTurnAbortMetricPoint{
+ Reason: label.reason,
+ Expected: label.expected,
+ Total: total,
+ })
+ return true
+ })
+ sort.Slice(points, func(i, j int) bool {
+ if points[i].Reason == points[j].Reason {
+ return !points[i].Expected && points[j].Expected
+ }
+ return points[i].Reason < points[j].Reason
+ })
+ return OpenAIWSAbortMetricsSnapshot{
+ TurnAbortTotal: points,
+ TurnAbortRecoveredTotal: s.openaiWSAbortMetrics.turnAbortRecovered.Load(),
+ }
+}
+
func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot {
legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats()
isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats()
@@ -1041,6 +1198,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID
// Returns nil if no available account.
func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedModel string, excludedIDs map[int64]struct{}) *Account {
var selected *Account
+ tieCount := 0
for i := range accounts {
acc := &accounts[i]
@@ -1067,17 +1225,41 @@ func (s *OpenAIGatewayService) selectBestAccount(accounts []Account, requestedMo
// Select highest priority and least recently used
if selected == nil {
selected = acc
+ tieCount = 1
continue
}
if s.isBetterAccount(acc, selected) {
selected = acc
+ tieCount = 1
+ continue
+ }
+ if !s.isBetterAccount(selected, acc) {
+ // 完全同分档时进行 reservoir tie-break,避免长期集中命中首个账号。
+ tieCount++
+ if s.accountTieBreakIntn(tieCount) == 0 {
+ selected = acc
+ }
}
}
return selected
}
+func (s *OpenAIGatewayService) accountTieBreakIntn(n int) int {
+ if n <= 1 {
+ return 0
+ }
+ if s != nil && s.accountTieBreakIntnFn != nil {
+ v := s.accountTieBreakIntnFn(n)
+ if v >= 0 && v < n {
+ return v
+ }
+ return 0
+ }
+ return rand.Intn(n)
+}
+
// isBetterAccount 判断 candidate 是否比 current 更优。
// 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先,其次是最久未使用的。
//
@@ -1440,12 +1622,21 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
originalBody := body
- reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
+ reqModel, reqStream, promptCacheKey := extractOpenAIRequestMeta(c, body)
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
clientTransport := GetOpenAIClientTransport(c)
+ if shouldWarnOpenAIWSUnknownTransportFallback(wsDecision, clientTransport) {
+ logOpenAIWSModeInfo(
+ "client_transport_unknown_fallback_http account_id=%d account_type=%s resolved_transport=%s resolved_reason=%s",
+ account.ID,
+ account.Type,
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ normalizeOpenAIWSLogValue(wsDecision.Reason),
+ )
+ }
// 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。
wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport)
if c != nil {
@@ -1465,15 +1656,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。
if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket {
- if c != nil {
- c.JSON(http.StatusBadRequest, gin.H{
- "error": gin.H{
- "type": "invalid_request_error",
- "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.",
- },
- })
- }
- return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2")
+ return nil, s.writeOpenAIWSV1UnsupportedResponse(c, account)
}
passthroughEnabled := account.IsOpenAIPassthroughEnabled()
if passthroughEnabled {
@@ -1762,6 +1945,10 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
retryStartedAt := time.Now()
wsRetryLoop:
for attempt := 1; attempt <= maxAttempts; attempt++ {
+ if cancelErr := openAIWSRetryContextError(ctx); cancelErr != nil {
+ wsErr = cancelErr
+ break
+ }
wsAttempts = attempt
wsResult, wsErr = s.forwardOpenAIWSV2(
ctx,
@@ -2202,6 +2389,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
if err != nil {
return nil, err
}
+ req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI))
// 透传客户端请求头(安全白名单)。
allowTimeoutHeaders := s.isOpenAIPassthroughTimeoutHeadersAllowed()
@@ -2582,6 +2770,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
if err != nil {
return nil, err
}
+ req = req.WithContext(WithHTTPUpstreamProfile(req.Context(), HTTPUpstreamProfileOpenAI))
// Set authentication header
req.Header.Set("authorization", "Bearer "+token)
@@ -2934,13 +3123,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
line = s.replaceModelInSSELine(line, mappedModel, originalModel)
}
- dataBytes := []byte(data)
-
// Correct Codex tool calls if needed (apply_patch -> edit, etc.)
- if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
- dataBytes = correctedData
- data = string(correctedData)
- line = "data: " + data
+ // 仅在 toolCorrector 存在时才转换为 []byte,避免热路径无谓分配
+ if s.toolCorrector != nil {
+ dataBytes := []byte(data)
+ if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected {
+ data = string(correctedData)
+ line = "data: " + data
+ }
}
// 写入客户端(客户端断开后继续 drain 上游)
@@ -2969,7 +3159,8 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
- s.parseSSEUsageBytes(dataBytes, usage)
+ // 使用 string 版本解析 usage,避免 string→[]byte 转换
+ s.parseSSEUsageString(data, usage)
return
}
@@ -3143,24 +3334,50 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
- s.parseSSEUsageBytes([]byte(data), usage)
+ s.parseSSEUsageString(data, usage)
+}
+
+// parseSSEUsageString 使用 gjson.Get(string 版本)解析 usage,避免 string→[]byte 转换
+func (s *OpenAIGatewayService) parseSSEUsageString(data string, usage *OpenAIUsage) {
+ if usage == nil || len(data) == 0 || data == "[DONE]" {
+ return
+ }
+ if len(data) < 80 || !strings.Contains(data, `"response.completed"`) {
+ return
+ }
+ if gjson.Get(data, "type").String() != "response.completed" {
+ return
+ }
+ usageFields := gjson.GetMany(data,
+ "response.usage.input_tokens",
+ "response.usage.output_tokens",
+ "response.usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(usageFields[0].Int())
+ usage.OutputTokens = int(usageFields[1].Int())
+ usage.CacheReadInputTokens = int(usageFields[2].Int())
}
func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) {
- if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) {
+ if usage == nil || len(data) == 0 || bytes.Equal(data, sseDataDone) {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
- if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
+ if len(data) < 80 || !bytes.Contains(data, sseResponseCompletedMark) {
return
}
if gjson.GetBytes(data, "type").String() != "response.completed" {
return
}
-
- usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int())
- usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int())
- usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int())
+ // 使用 GetManyBytes 一次提取 3 个 usage 字段
+ usageFields := gjson.GetManyBytes(data,
+ "response.usage.input_tokens",
+ "response.usage.output_tokens",
+ "response.usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(usageFields[0].Int())
+ usage.OutputTokens = int(usageFields[1].Int())
+ usage.CacheReadInputTokens = int(usageFields[2].Int())
}
func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) {
@@ -3365,14 +3582,184 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
- Result *OpenAIForwardResult
- APIKey *APIKey
- User *User
- Account *Account
- Subscription *UserSubscription
- UserAgent string // 请求的 User-Agent
- IPAddress string // 请求的客户端 IP 地址
- APIKeyService APIKeyQuotaUpdater
+ Result *OpenAIForwardResult
+ APIKey *APIKey
+ User *User
+ Account *Account
+ Subscription *UserSubscription
+ UserAgent string // 请求的 User-Agent
+ IPAddress string // 请求的客户端 IP 地址
+ FallbackRequestID string // 当上游 request_id 缺失时,用于生成稳定幂等 request_id
+ APIKeyService APIKeyQuotaUpdater
+}
+
+func (s *OpenAIGatewayService) usageBillingEntryStore() UsageBillingEntryStore {
+ store, ok := s.usageLogRepo.(UsageBillingEntryStore)
+ if !ok {
+ return nil
+ }
+ return store
+}
+
+func (s *OpenAIGatewayService) usageBillingTxRunner() UsageBillingTxRunner {
+ runner, ok := s.usageLogRepo.(UsageBillingTxRunner)
+ if !ok {
+ return nil
+ }
+ return runner
+}
+
+func (s *OpenAIGatewayService) runUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ runner := s.usageBillingTxRunner()
+ if runner == nil {
+ return fn(ctx)
+ }
+ return runner.WithUsageBillingTx(ctx, fn)
+}
+
+func (s *OpenAIGatewayService) prepareUsageBillingEntry(
+ ctx context.Context,
+ usageLog *UsageLog,
+ inserted bool,
+ billingType int8,
+ deltaUSD float64,
+) (*UsageBillingEntry, bool, error) {
+ if deltaUSD <= 0 {
+ return nil, false, nil
+ }
+
+ store := s.usageBillingEntryStore()
+ if store == nil {
+ if inserted {
+ return nil, true, nil
+ }
+ return nil, false, nil
+ }
+
+ if !inserted {
+ entry, err := store.GetUsageBillingEntryByUsageLogID(ctx, usageLog.ID)
+ if err != nil {
+ if errors.Is(err, ErrUsageBillingEntryNotFound) {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[BillingReconcile] missing billing entry for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s",
+ usageLog.ID,
+ usageLog.RequestID,
+ )
+ return nil, false, nil
+ }
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[BillingReconcile] load billing entry failed for duplicate usage log, skip immediate billing: usage_log=%d request_id=%s err=%v",
+ usageLog.ID,
+ usageLog.RequestID,
+ err,
+ )
+ return nil, false, nil
+ }
+ return entry, !entry.Applied, nil
+ }
+
+ entry, _, err := store.UpsertUsageBillingEntry(ctx, &UsageBillingEntry{
+ UsageLogID: usageLog.ID,
+ UserID: usageLog.UserID,
+ APIKeyID: usageLog.APIKeyID,
+ SubscriptionID: usageLog.SubscriptionID,
+ BillingType: billingType,
+ DeltaUSD: deltaUSD,
+ Status: UsageBillingEntryStatusPending,
+ })
+ if err != nil {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[BillingReconcile] upsert billing entry failed, fallback to inline billing: usage_log=%d request_id=%s err=%v",
+ usageLog.ID,
+ usageLog.RequestID,
+ err,
+ )
+ return nil, true, nil
+ }
+
+ return entry, !entry.Applied, nil
+}
+
+func (s *OpenAIGatewayService) markUsageBillingRetry(ctx context.Context, entry *UsageBillingEntry, cause error) {
+ if entry == nil || cause == nil {
+ return
+ }
+ store := s.usageBillingEntryStore()
+ if store == nil {
+ return
+ }
+ errMsg := strings.TrimSpace(cause.Error())
+ if len(errMsg) > 500 {
+ errMsg = errMsg[:500]
+ }
+ nextRetryAt := time.Now().Add(usageBillingRetryBackoff(entry.AttemptCount + 1))
+ if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil {
+ logger.LegacyPrintf("service.openai_gateway", "[BillingReconcile] mark retry failed: entry=%d err=%v", entry.ID, err)
+ }
+}
+
+func resolveOpenAIUsageRequestID(input *OpenAIRecordUsageInput) string {
+ if input == nil || input.Result == nil {
+ return ""
+ }
+ if requestID := strings.TrimSpace(input.Result.RequestID); requestID != "" {
+ return requestID
+ }
+ return buildOpenAIUsageFallbackRequestID(input)
+}
+
+func buildOpenAIUsageFallbackRequestID(input *OpenAIRecordUsageInput) string {
+ if input == nil || input.Result == nil {
+ return ""
+ }
+ result := input.Result
+ usage := result.Usage
+
+ seed := strings.Builder{}
+ seed.Grow(192)
+ seed.WriteString(strings.TrimSpace(input.FallbackRequestID))
+ seed.WriteByte('|')
+ if input.APIKey != nil {
+ seed.WriteString(strconv.FormatInt(input.APIKey.ID, 10))
+ }
+ seed.WriteByte('|')
+ if input.Account != nil {
+ seed.WriteString(strconv.FormatInt(input.Account.ID, 10))
+ }
+ seed.WriteByte('|')
+ seed.WriteString(strings.TrimSpace(result.Model))
+ seed.WriteByte('|')
+ seed.WriteString(strings.TrimSpace(result.TerminalEventType))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.FormatBool(result.Stream))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.FormatBool(result.OpenAIWSMode))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.Itoa(usage.InputTokens))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.Itoa(usage.OutputTokens))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.Itoa(usage.CacheCreationInputTokens))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.Itoa(usage.CacheReadInputTokens))
+ seed.WriteByte('|')
+ seed.WriteString(strconv.FormatInt(result.Duration.Milliseconds(), 10))
+ seed.WriteByte('|')
+ firstTokenMs := -1
+ if result.FirstTokenMs != nil {
+ firstTokenMs = *result.FirstTokenMs
+ }
+ seed.WriteString(strconv.Itoa(firstTokenMs))
+ seed.WriteByte('|')
+ if result.ReasoningEffort != nil {
+ seed.WriteString(strings.TrimSpace(*result.ReasoningEffort))
+ }
+
+ sum := sha256.Sum256([]byte(seed.String()))
+ return "wsf_" + hex.EncodeToString(sum[:16])
}
// RecordUsage records usage and deducts balance
@@ -3406,7 +3793,25 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
cost, err := s.billingService.CalculateCost(result.Model, tokens, multiplier)
if err != nil {
- cost = &CostBreakdown{ActualCost: 0}
+ if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[PricingWarn] calculate cost failed in simple mode, fallback to zero cost: model=%s request_id=%s err=%v",
+ result.Model,
+ result.RequestID,
+ err,
+ )
+ cost = &CostBreakdown{}
+ } else {
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[PricingAlert] calculate cost failed, reject usage record: model=%s request_id=%s err=%v",
+ result.Model,
+ result.RequestID,
+ err,
+ )
+ return fmt.Errorf("calculate cost: %w", err)
+ }
}
// Determine billing type
@@ -3419,11 +3824,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
// Create usage log
durationMs := int(result.Duration.Milliseconds())
accountRateMultiplier := account.BillingRateMultiplier()
+ requestID := resolveOpenAIUsageRequestID(input)
usageLog := &UsageLog{
UserID: user.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
- RequestID: result.RequestID,
+ RequestID: requestID,
Model: result.Model,
ReasoningEffort: result.ReasoningEffort,
InputTokens: actualInputTokens,
@@ -3464,24 +3870,67 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
}
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
+ if err != nil {
+ return fmt.Errorf("create usage log: %w", err)
+ }
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
- shouldBill := inserted || err != nil
-
- // Deduct based on billing type
+ billAmount := cost.ActualCost
if isSubscriptionBilling {
- if shouldBill && cost.TotalCost > 0 {
- _ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
- s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
+ billAmount = cost.TotalCost
+ }
+ billingEntry, shouldBill, err := s.prepareUsageBillingEntry(ctx, usageLog, inserted, billingType, billAmount)
+ if err != nil {
+ return fmt.Errorf("prepare usage billing entry: %w", err)
+ }
+
+ if shouldBill {
+ cacheDeducted := false
+ if !isSubscriptionBilling && billAmount > 0 && s.billingCacheService != nil {
+ // 同步扣减缓存,避免并发场景下仅靠“先查后扣”产生透支窗口。
+ if err := s.billingCacheService.DeductBalanceCache(ctx, user.ID, billAmount); err != nil {
+ s.markUsageBillingRetry(ctx, billingEntry, err)
+ return fmt.Errorf("deduct balance cache: %w", err)
+ }
+ cacheDeducted = true
}
- } else {
- if shouldBill && cost.ActualCost > 0 {
- _ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
- s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
+
+ applyErr := s.runUsageBillingTx(ctx, func(txCtx context.Context) error {
+ if isSubscriptionBilling {
+ if err := s.userSubRepo.IncrementUsage(txCtx, subscription.ID, cost.TotalCost); err != nil {
+ return fmt.Errorf("increment subscription usage: %w", err)
+ }
+ } else if billAmount > 0 {
+ if err := s.userRepo.DeductBalance(txCtx, user.ID, billAmount); err != nil {
+ return fmt.Errorf("deduct balance: %w", err)
+ }
+ }
+ if billingEntry == nil {
+ return nil
+ }
+ store := s.usageBillingEntryStore()
+ if store == nil {
+ return nil
+ }
+ if err := store.MarkUsageBillingEntryApplied(txCtx, billingEntry.ID); err != nil {
+ return fmt.Errorf("mark usage billing entry applied: %w", err)
+ }
+ return nil
+ })
+ if applyErr != nil {
+ if !isSubscriptionBilling && cacheDeducted && s.billingCacheService != nil {
+ _ = s.billingCacheService.InvalidateUserBalance(context.Background(), user.ID)
+ }
+ s.markUsageBillingRetry(ctx, billingEntry, applyErr)
+ return applyErr
+ }
+
+ if isSubscriptionBilling && s.billingCacheService != nil && apiKey.GroupID != nil && cost.TotalCost > 0 {
+ s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
}
@@ -3668,15 +4117,48 @@ func (s *OpenAIGatewayService) updateCodexUsageSnapshot(ctx context.Context, acc
if len(updates) == 0 {
return
}
+ if !s.tryAcquireCodexUsageUpdateSlot() {
+ slog.Warn("openai_gateway.codex_usage_update_dropped",
+ "account_id", accountID,
+ "reason", "concurrency_limit_reached",
+ )
+ return
+ }
// Update account's Extra field asynchronously
go func() {
+ defer s.releaseCodexUsageUpdateSlot()
updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = s.accountRepo.UpdateExtra(updateCtx, accountID, updates)
}()
}
+func (s *OpenAIGatewayService) tryAcquireCodexUsageUpdateSlot() bool {
+ if s == nil {
+ return false
+ }
+ s.codexUsageUpdateOnce.Do(func() {
+ s.codexUsageUpdateSem = make(chan struct{}, openAICodexUsageUpdateConcurrency)
+ })
+ select {
+ case s.codexUsageUpdateSem <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (s *OpenAIGatewayService) releaseCodexUsageUpdateSlot() {
+ if s == nil || s.codexUsageUpdateSem == nil {
+ return
+ }
+ select {
+ case <-s.codexUsageUpdateSem:
+ default:
+ }
+}
+
func getOpenAIReasoningEffortFromReqBody(reqBody map[string]any) (value string, present bool) {
if reqBody == nil {
return "", false
@@ -3723,6 +4205,27 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
+// OpenAIRequestMeta 缓存已提取的请求元数据,避免重复解析
+type OpenAIRequestMeta struct {
+ Model string
+ Stream bool
+ PromptCacheKey string
+}
+
+// extractOpenAIRequestMeta 优先从 context 读取已缓存的 meta(只读),回退到 body 解析。
+// Handler 层已完成所有字段提取(含 prompt_cache_key),此处不再修改 meta,避免并发竞态。
+func extractOpenAIRequestMeta(c *gin.Context, body []byte) (model string, stream bool, promptCacheKey string) {
+ if c != nil {
+ if cached, ok := c.Get(OpenAIRequestMetaKey); ok {
+ if meta, ok := cached.(*OpenAIRequestMeta); ok && meta != nil {
+ return meta.Model, meta.Stream, meta.PromptCacheKey
+ }
+ }
+ }
+ // 回退到原始解析(WebSocket 等其他入口)
+ return extractOpenAIRequestMetaFromBody(body)
+}
+
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
if len(body) == 0 {
return "", false, ""
diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go
index f73c06c5e..a82e6e2a3 100644
--- a/backend/internal/service/openai_gateway_service_hotpath_test.go
+++ b/backend/internal/service/openai_gateway_service_hotpath_test.go
@@ -139,3 +139,126 @@ func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) {
require.True(t, ok)
require.Equal(t, got, cachedMap)
}
+
+// --- extractOpenAIRequestMeta context 缓存测试 ---
+
+func TestExtractOpenAIRequestMeta_CacheHit(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ // 预设缓存(Handler 层已提取所有字段,包括 PromptCacheKey)
+ c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{
+ Model: "gpt-5",
+ Stream: true,
+ PromptCacheKey: "key-1",
+ })
+
+ body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"key-other"}`)
+ model, stream, promptKey := extractOpenAIRequestMeta(c, body)
+
+ // 应返回缓存值而非 body 中的值
+ require.Equal(t, "gpt-5", model)
+ require.True(t, stream)
+ require.Equal(t, "key-1", promptKey)
+}
+
+func TestExtractOpenAIRequestMeta_CacheHit_PromptCacheKeyFromHandler(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ // Handler 层已提取 PromptCacheKey,meta 设置后只读不写
+ meta := &OpenAIRequestMeta{Model: "gpt-5", Stream: false, PromptCacheKey: "pk-abc"}
+ c.Set(OpenAIRequestMetaKey, meta)
+
+ body := []byte(`{"model":"gpt-4","prompt_cache_key":"pk-other"}`)
+
+ // 应返回缓存中的值(Handler 层提取),而非 body 中的值
+ _, _, promptKey1 := extractOpenAIRequestMeta(c, body)
+ require.Equal(t, "pk-abc", promptKey1)
+
+ // 多次调用结果一致
+ _, _, promptKey2 := extractOpenAIRequestMeta(c, body)
+ require.Equal(t, "pk-abc", promptKey2)
+}
+
+func TestExtractOpenAIRequestMeta_CacheHit_EmptyBody(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ c.Set(OpenAIRequestMetaKey, &OpenAIRequestMeta{
+ Model: "gpt-5",
+ Stream: true,
+ })
+
+ // body 为空时不应 panic,prompt_cache_key 应为空
+ model, stream, promptKey := extractOpenAIRequestMeta(c, nil)
+ require.Equal(t, "gpt-5", model)
+ require.True(t, stream)
+ require.Equal(t, "", promptKey)
+}
+
+func TestExtractOpenAIRequestMeta_FallbackToBody(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ // 不设缓存,应回退到 body 解析
+ body := []byte(`{"model":"gpt-4o","stream":true,"prompt_cache_key":"ses-2"}`)
+ model, stream, promptKey := extractOpenAIRequestMeta(c, body)
+
+ require.Equal(t, "gpt-4o", model)
+ require.True(t, stream)
+ require.Equal(t, "ses-2", promptKey)
+}
+
+func TestExtractOpenAIRequestMeta_NilContext(t *testing.T) {
+ body := []byte(`{"model":"gpt-4","stream":false,"prompt_cache_key":"k"}`)
+ model, stream, promptKey := extractOpenAIRequestMeta(nil, body)
+
+ require.Equal(t, "gpt-4", model)
+ require.False(t, stream)
+ require.Equal(t, "k", promptKey)
+}
+
+func TestExtractOpenAIRequestMeta_InvalidCacheType(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ // 缓存类型错误,应回退到 body 解析
+ c.Set(OpenAIRequestMetaKey, "invalid-type")
+
+ body := []byte(`{"model":"gpt-4o","stream":true}`)
+ model, stream, _ := extractOpenAIRequestMeta(c, body)
+
+ require.Equal(t, "gpt-4o", model)
+ require.True(t, stream)
+}
+
+func TestExtractOpenAIRequestMeta_NilCacheValue(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+
+ c.Set(OpenAIRequestMetaKey, (*OpenAIRequestMeta)(nil))
+
+ body := []byte(`{"model":"gpt-5","stream":false}`)
+ model, stream, _ := extractOpenAIRequestMeta(c, body)
+
+ require.Equal(t, "gpt-5", model)
+ require.False(t, stream)
+}
+
+func TestOpenAIRequestMeta_Fields(t *testing.T) {
+ meta := &OpenAIRequestMeta{
+ Model: "gpt-5",
+ Stream: true,
+ PromptCacheKey: "pk",
+ }
+ require.Equal(t, "gpt-5", meta.Model)
+ require.True(t, meta.Stream)
+ require.Equal(t, "pk", meta.PromptCacheKey)
+}
diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go
index 89443b694..67053933a 100644
--- a/backend/internal/service/openai_gateway_service_test.go
+++ b/backend/internal/service/openai_gateway_service_test.go
@@ -10,6 +10,8 @@ import (
"net/http"
"net/http/httptest"
"strings"
+ "sync"
+ "sync/atomic"
"testing"
"time"
@@ -57,6 +59,33 @@ func (r stubOpenAIAccountRepo) ListSchedulableByPlatform(ctx context.Context, pl
return result, nil
}
+type codexUsageUpdateAccountRepoStub struct {
+ stubOpenAIAccountRepo
+
+ calls atomic.Int32
+ entered chan struct{}
+ release chan struct{}
+ enterSig sync.Once
+}
+
+func (r *codexUsageUpdateAccountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
+ r.calls.Add(1)
+ r.enterSig.Do(func() {
+ if r.entered != nil {
+ close(r.entered)
+ }
+ })
+ if r.release == nil {
+ return nil
+ }
+ select {
+ case <-r.release:
+ return nil
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
type stubConcurrencyCache struct {
ConcurrencyCache
loadBatchErr error
@@ -351,6 +380,37 @@ func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulableWhenNoConcurre
}
}
+func TestOpenAIGatewayService_UpdateCodexUsageSnapshot_BoundedAsyncUpdates(t *testing.T) {
+ repo := &codexUsageUpdateAccountRepoStub{
+ entered: make(chan struct{}),
+ release: make(chan struct{}),
+ }
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ codexUsageUpdateSem: make(chan struct{}, 1),
+ }
+ svc.codexUsageUpdateOnce.Do(func() {})
+
+ snapshot := &OpenAICodexUsageSnapshot{
+ UpdatedAt: time.Now().Format(time.RFC3339),
+ }
+
+ svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot)
+
+ select {
+ case <-repo.entered:
+ case <-time.After(time.Second):
+ t.Fatal("first codex usage snapshot update did not start")
+ }
+
+ // first async update is still holding the single slot
+ svc.updateCodexUsageSnapshot(context.Background(), 1, snapshot)
+ time.Sleep(80 * time.Millisecond)
+ require.Equal(t, int32(1), repo.calls.Load(), "slot full 时应拒绝新异步写入,避免 goroutine 无界增长")
+
+ close(repo.release)
+}
+
func TestOpenAISelectAccountForModelWithExclusions_StickyUnschedulableClearsSession(t *testing.T) {
sessionHash := "session-1"
repo := stubOpenAIAccountRepo{
@@ -828,6 +888,33 @@ func TestOpenAISelectAccountForModelWithExclusions_LeastRecentlyUsed(t *testing.
}
}
+func TestOpenAISelectAccountForModelWithExclusions_EqualPriorityTieBreakRandomized(t *testing.T) {
+ repo := stubOpenAIAccountRepo{
+ accounts: []Account{
+ {ID: 1, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1},
+ {ID: 2, Platform: PlatformOpenAI, Status: StatusActive, Schedulable: true, Priority: 1},
+ },
+ }
+ cache := &stubGatewayCache{}
+
+ calls := 0
+ svc := &OpenAIGatewayService{
+ accountRepo: repo,
+ cache: cache,
+ accountTieBreakIntnFn: func(n int) int {
+ calls++
+ require.Equal(t, 2, n)
+ return 0
+ },
+ }
+
+ acc, err := svc.SelectAccountForModelWithExclusions(context.Background(), nil, "", "gpt-4", nil)
+ require.NoError(t, err)
+ require.NotNil(t, acc)
+ require.Equal(t, int64(2), acc.ID, "tie-break should allow later equal-priority candidate to be selected")
+ require.Equal(t, 1, calls, "tie-break should be invoked exactly once for two equal candidates")
+}
+
func TestOpenAISelectAccountWithLoadAwareness_PreferNeverUsed(t *testing.T) {
groupID := int64(1)
lastUsed := time.Now().Add(-1 * time.Hour)
@@ -1185,6 +1272,53 @@ func TestOpenAIInvalidBaseURLWhenAllowlistDisabled(t *testing.T) {
}
}
+func TestOpenAIBuildUpstreamRequestSetsHTTPUpstreamProfile(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Security: config.SecurityConfig{
+ URLAllowlist: config.URLAllowlistConfig{Enabled: false},
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`))
+
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ }
+
+ req, err := svc.buildUpstreamRequest(context.Background(), c, account, []byte(`{}`), "token", false, "", false)
+ require.NoError(t, err)
+ require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context()))
+}
+
+func TestOpenAIBuildUpstreamPassthroughRequestSetsHTTPUpstreamProfile(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ cfg := &config.Config{
+ Security: config.SecurityConfig{
+ URLAllowlist: config.URLAllowlistConfig{Enabled: false},
+ },
+ }
+ svc := &OpenAIGatewayService{cfg: cfg}
+
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`))
+ c.Request.Header.Set("Content-Type", "application/json")
+
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ }
+
+ req, err := svc.buildUpstreamRequestOpenAIPassthrough(context.Background(), c, account, []byte(`{}`), "token")
+ require.NoError(t, err)
+ require.Equal(t, HTTPUpstreamProfileOpenAI, HTTPUpstreamProfileFromContext(req.Context()))
+}
+
func TestOpenAIValidateUpstreamBaseURLDisabledRequiresHTTPS(t *testing.T) {
cfg := &config.Config{
Security: config.SecurityConfig{
diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go
index 72f4bbb09..07cb54721 100644
--- a/backend/internal/service/openai_oauth_service.go
+++ b/backend/internal/service/openai_oauth_service.go
@@ -7,6 +7,7 @@ import (
"io"
"log/slog"
"net/http"
+ "net/url"
"regexp"
"sort"
"strconv"
@@ -14,7 +15,6 @@ import (
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
- "github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
)
@@ -273,13 +273,7 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi
req.Header.Set("Referer", "https://sora.chatgpt.com/")
req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)")
- client, err := httpclient.GetClient(httpclient.Options{
- ProxyURL: proxyURL,
- Timeout: 120 * time.Second,
- })
- if err != nil {
- return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_CLIENT_FAILED", "create http client failed: %v", err)
- }
+ client := newOpenAIOAuthHTTPClient(proxyURL)
resp, err := client.Do(req)
if err != nil {
return nil, infraerrors.Newf(http.StatusBadGateway, "SORA_SESSION_REQUEST_FAILED", "request failed: %v", err)
@@ -536,6 +530,19 @@ func (s *OpenAIOAuthService) resolveProxyURL(ctx context.Context, proxyID *int64
return proxy.URL(), nil
}
+func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client {
+ transport := &http.Transport{}
+ if strings.TrimSpace(proxyURL) != "" {
+ if parsed, err := url.Parse(proxyURL); err == nil && parsed.Host != "" {
+ transport.Proxy = http.ProxyURL(parsed)
+ }
+ }
+ return &http.Client{
+ Timeout: 120 * time.Second,
+ Transport: transport,
+ }
+}
+
func normalizeOpenAIOAuthPlatform(platform string) string {
switch strings.ToLower(strings.TrimSpace(platform)) {
case PlatformSora:
diff --git a/backend/internal/service/openai_sse_zero_alloc_test.go b/backend/internal/service/openai_sse_zero_alloc_test.go
new file mode 100644
index 000000000..fe853d59a
--- /dev/null
+++ b/backend/internal/service/openai_sse_zero_alloc_test.go
@@ -0,0 +1,276 @@
+package service
+
+import (
+ "bytes"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// --- 包级常量验证 ---
+
+func TestSSEPackageLevelConstants(t *testing.T) {
+ require.Equal(t, []byte("[DONE]"), sseDataDone)
+ require.Equal(t, []byte(`"response.completed"`), sseResponseCompletedMark)
+}
+
+func TestSSEDataDone_UsedInBytesEqual(t *testing.T) {
+ require.True(t, bytes.Equal([]byte("[DONE]"), sseDataDone))
+ require.False(t, bytes.Equal([]byte("[done]"), sseDataDone))
+ require.False(t, bytes.Equal([]byte(""), sseDataDone))
+}
+
+func TestSSEResponseCompletedMark_UsedInBytesContains(t *testing.T) {
+ data := []byte(`{"type":"response.completed","response":{"usage":{}}}`)
+ require.True(t, bytes.Contains(data, sseResponseCompletedMark))
+
+ unrelated := []byte(`{"type":"response.in_progress"}`)
+ require.False(t, bytes.Contains(unrelated, sseResponseCompletedMark))
+}
+
+// --- parseSSEUsageString 测试 ---
+
+func TestParseSSEUsageString_CompletedEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":100,"output_tokens":50,"input_tokens_details":{"cached_tokens":20}}}}`
+ svc.parseSSEUsageString(data, usage)
+
+ require.Equal(t, 100, usage.InputTokens)
+ require.Equal(t, 50, usage.OutputTokens)
+ require.Equal(t, 20, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsageString_NonCompletedEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 99, OutputTokens: 88}
+
+ data := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}}}`
+ svc.parseSSEUsageString(data, usage)
+
+ // 非 completed 事件不应修改 usage
+ require.Equal(t, 99, usage.InputTokens)
+ require.Equal(t, 88, usage.OutputTokens)
+}
+
+func TestParseSSEUsageString_DoneEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 10}
+
+ svc.parseSSEUsageString("[DONE]", usage)
+ require.Equal(t, 10, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageString_EmptyString(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 5}
+
+ svc.parseSSEUsageString("", usage)
+ require.Equal(t, 5, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageString_NilUsage(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ // 不应 panic
+ require.NotPanics(t, func() {
+ svc.parseSSEUsageString(`{"type":"response.completed"}`, nil)
+ })
+}
+
+func TestParseSSEUsageString_ShortData(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 7}
+
+ // 短于 80 字节的数据直接跳过
+ svc.parseSSEUsageString(`{"type":"response.completed"}`, usage)
+ require.Equal(t, 7, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageString_ContainsCompletedButWrongType(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 42}
+
+ // 包含 "response.completed" 子串但 type 字段不匹配
+ data := `{"type":"response.in_progress","description":"not response.completed at all","padding":"aaaaaaaaaaaaaaaaaaaaaaaaaaaa"}`
+ svc.parseSSEUsageString(data, usage)
+ require.Equal(t, 42, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageString_ZeroUsageValues(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 99}
+
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":0,"output_tokens":0,"input_tokens_details":{"cached_tokens":0}}}}`
+ svc.parseSSEUsageString(data, usage)
+
+ require.Equal(t, 0, usage.InputTokens)
+ require.Equal(t, 0, usage.OutputTokens)
+ require.Equal(t, 0, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsageString_MissingUsageFields(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ // response.usage 存在但缺少某些子字段
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":10},"padding":"aaaaaaaaaaaaaaaaaaa"}}`
+ svc.parseSSEUsageString(data, usage)
+
+ require.Equal(t, 10, usage.InputTokens)
+ require.Equal(t, 0, usage.OutputTokens)
+ require.Equal(t, 0, usage.CacheReadInputTokens)
+}
+
+// --- parseSSEUsageBytes 与包级常量集成测试 ---
+
+func TestParseSSEUsageBytes_CompletedEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":200,"output_tokens":80,"input_tokens_details":{"cached_tokens":30}}}}`)
+ svc.parseSSEUsageBytes(data, usage)
+
+ require.Equal(t, 200, usage.InputTokens)
+ require.Equal(t, 80, usage.OutputTokens)
+ require.Equal(t, 30, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsageBytes_DoneEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 10}
+
+ svc.parseSSEUsageBytes([]byte("[DONE]"), usage)
+ require.Equal(t, 10, usage.InputTokens) // 不应修改
+}
+
+func TestParseSSEUsageBytes_EmptyData(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 5}
+
+ svc.parseSSEUsageBytes(nil, usage)
+ require.Equal(t, 5, usage.InputTokens)
+
+ svc.parseSSEUsageBytes([]byte{}, usage)
+ require.Equal(t, 5, usage.InputTokens)
+}
+
+func TestParseSSEUsageBytes_NilUsage(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ require.NotPanics(t, func() {
+ svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), nil)
+ })
+}
+
+func TestParseSSEUsageBytes_ShortData(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 7}
+
+ svc.parseSSEUsageBytes([]byte(`{"type":"response.completed"}`), usage)
+ require.Equal(t, 7, usage.InputTokens)
+}
+
+func TestParseSSEUsageBytes_NonCompletedEvent(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 99}
+
+ data := []byte(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxxx"}`)
+ svc.parseSSEUsageBytes(data, usage)
+
+ require.Equal(t, 99, usage.InputTokens)
+}
+
+func TestParseSSEUsageBytes_GetManyBytesExtraction(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ // 验证 GetManyBytes 一次提取 3 个字段的正确性
+ data := []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":111,"output_tokens":222,"input_tokens_details":{"cached_tokens":333}}}}`)
+ svc.parseSSEUsageBytes(data, usage)
+
+ require.Equal(t, 111, usage.InputTokens)
+ require.Equal(t, 222, usage.OutputTokens)
+ require.Equal(t, 333, usage.CacheReadInputTokens)
+}
+
+// --- parseSSEUsage wrapper 测试 ---
+
+func TestParseSSEUsage_DelegatesToString(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{}
+
+ // 验证 parseSSEUsage 最终正确提取 usage
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":55,"output_tokens":66,"input_tokens_details":{"cached_tokens":77}}}}`
+ svc.parseSSEUsage(data, usage)
+
+ require.Equal(t, 55, usage.InputTokens)
+ require.Equal(t, 66, usage.OutputTokens)
+ require.Equal(t, 77, usage.CacheReadInputTokens)
+}
+
+func TestParseSSEUsage_DoneNotParsed(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 123}
+
+ svc.parseSSEUsage("[DONE]", usage)
+ require.Equal(t, 123, usage.InputTokens)
+}
+
+func TestParseSSEUsage_EmptyNotParsed(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ usage := &OpenAIUsage{InputTokens: 456}
+
+ svc.parseSSEUsage("", usage)
+ require.Equal(t, 456, usage.InputTokens)
+}
+
+// --- string 和 bytes 一致性测试 ---
+
+func TestParseSSEUsage_StringAndBytesConsistency(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ completedData := `{"type":"response.completed","response":{"usage":{"input_tokens":300,"output_tokens":150,"input_tokens_details":{"cached_tokens":50}}}}`
+
+ usageStr := &OpenAIUsage{}
+ svc.parseSSEUsageString(completedData, usageStr)
+
+ usageBytes := &OpenAIUsage{}
+ svc.parseSSEUsageBytes([]byte(completedData), usageBytes)
+
+ require.Equal(t, usageStr.InputTokens, usageBytes.InputTokens)
+ require.Equal(t, usageStr.OutputTokens, usageBytes.OutputTokens)
+ require.Equal(t, usageStr.CacheReadInputTokens, usageBytes.CacheReadInputTokens)
+}
+
+func TestParseSSEUsage_StringAndBytesConsistency_NonCompleted(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ inProgressData := `{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":3}}},"pad":"xxx"}`
+
+ usageStr := &OpenAIUsage{InputTokens: 10}
+ svc.parseSSEUsageString(inProgressData, usageStr)
+
+ usageBytes := &OpenAIUsage{InputTokens: 10}
+ svc.parseSSEUsageBytes([]byte(inProgressData), usageBytes)
+
+ // 两者都不应修改
+ require.Equal(t, 10, usageStr.InputTokens)
+ require.Equal(t, 10, usageBytes.InputTokens)
+}
+
+func TestParseSSEUsage_StringAndBytesConsistency_LargeTokenCounts(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+
+ data := `{"type":"response.completed","response":{"usage":{"input_tokens":1000000,"output_tokens":500000,"input_tokens_details":{"cached_tokens":200000}}}}`
+
+ usageStr := &OpenAIUsage{}
+ svc.parseSSEUsageString(data, usageStr)
+
+ usageBytes := &OpenAIUsage{}
+ svc.parseSSEUsageBytes([]byte(data), usageBytes)
+
+ require.Equal(t, 1000000, usageStr.InputTokens)
+ require.Equal(t, usageStr, usageBytes)
+}
diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go
index e897debc2..0f576b664 100644
--- a/backend/internal/service/openai_sticky_compat.go
+++ b/backend/internal/service/openai_sticky_compat.go
@@ -35,12 +35,31 @@ func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash
return "", ""
}
- currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
- sum := sha256.Sum256([]byte(normalized))
- legacyHash = hex.EncodeToString(sum[:])
+ currentHash = deriveOpenAISessionHash(normalized)
+ legacyHash = deriveOpenAILegacySessionHash(normalized)
return currentHash, legacyHash
}
+// deriveOpenAISessionHash returns the fast xxhash-based session hash.
+func deriveOpenAISessionHash(sessionID string) string {
+ normalized := strings.TrimSpace(sessionID)
+ if normalized == "" {
+ return ""
+ }
+ return fmt.Sprintf("%016x", xxhash.Sum64String(normalized))
+}
+
+// deriveOpenAILegacySessionHash returns the SHA-256 legacy hash.
+// Only call this when legacy fallback or dual-write is enabled.
+func deriveOpenAILegacySessionHash(sessionID string) string {
+ normalized := strings.TrimSpace(sessionID)
+ if normalized == "" {
+ return ""
+ }
+ sum := sha256.Sum256([]byte(normalized))
+ return hex.EncodeToString(sum[:])
+}
+
func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context {
if ctx == nil {
return nil
diff --git a/backend/internal/service/openai_ws_client_preempt_test.go b/backend/internal/service/openai_ws_client_preempt_test.go
new file mode 100644
index 000000000..3df8d12ef
--- /dev/null
+++ b/backend/internal/service/openai_ws_client_preempt_test.go
@@ -0,0 +1,1076 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "strconv"
+ "testing"
+
+ coderws "github.com/coder/websocket"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// errOpenAIWSClientPreempted 哨兵错误基础测试
+// ---------------------------------------------------------------------------
+
+func TestErrOpenAIWSClientPreempted_NotNil(t *testing.T) {
+ t.Parallel()
+ require.NotNil(t, errOpenAIWSClientPreempted)
+ require.Contains(t, errOpenAIWSClientPreempted.Error(), "client preempted")
+}
+
+func TestErrOpenAIWSClientPreempted_ErrorsIs(t *testing.T) {
+ t.Parallel()
+
+ // 直接匹配
+ require.True(t, errors.Is(errOpenAIWSClientPreempted, errOpenAIWSClientPreempted))
+
+ // 包裹后仍可匹配
+ wrapped := fmt.Errorf("outer: %w", errOpenAIWSClientPreempted)
+ require.True(t, errors.Is(wrapped, errOpenAIWSClientPreempted))
+
+ // 不同错误不匹配
+ require.False(t, errors.Is(errors.New("other"), errOpenAIWSClientPreempted))
+}
+
+func TestErrOpenAIWSClientPreempted_WrapInTurnError(t *testing.T) {
+ t.Parallel()
+
+ // 用 wrapOpenAIWSIngressTurnErrorWithPartial 包裹后 errors.Is 仍能识别
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ require.Error(t, turnErr)
+ require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted))
+}
+
+func TestErrOpenAIWSClientPreempted_WrapInTurnError_WithPartialResult(t *testing.T) {
+ t.Parallel()
+
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_preempt_partial",
+ Usage: OpenAIUsage{
+ InputTokens: 100,
+ OutputTokens: 50,
+ },
+ }
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ true,
+ partial,
+ )
+ require.Error(t, turnErr)
+ require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted))
+
+ // 验证 partial result 可提取
+ got, ok := OpenAIWSIngressTurnPartialResult(turnErr)
+ require.True(t, ok)
+ require.NotNil(t, got)
+ require.Equal(t, partial.RequestID, got.RequestID)
+ require.Equal(t, partial.Usage.InputTokens, got.Usage.InputTokens)
+}
+
+// ---------------------------------------------------------------------------
+// classifyOpenAIWSIngressTurnAbortReason 对 client_preempted 的识别测试
+// ---------------------------------------------------------------------------
+
+func TestClassifyAbortReason_ClientPreempted_Direct(t *testing.T) {
+ t.Parallel()
+
+ // 直接哨兵错误
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError(t *testing.T) {
+ t.Parallel()
+
+ // 包裹在 turnError 中
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_WrappedInTurnError_WroteDownstream(t *testing.T) {
+ t.Parallel()
+
+ // 包裹在 turnError 中,wroteDownstream=true
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ true,
+ &OpenAIForwardResult{RequestID: "resp_partial"},
+ )
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_DoubleWrapped(t *testing.T) {
+ t.Parallel()
+
+ // 多层 fmt.Errorf 包裹
+ inner := fmt.Errorf("relay failed: %w", errOpenAIWSClientPreempted)
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(inner)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+}
+
+func TestClassifyAbortReason_ClientPreempted_NotConfusedWithOther(t *testing.T) {
+ t.Parallel()
+
+ // 确保其他错误不会被误分类为 client_preempted
+ others := []error{
+ errors.New("client preempted"), // 文本相同但不是同一哨兵
+ context.Canceled, // context 取消
+ io.EOF, // 客户端断连
+ errors.New("random error"), // 随机错误
+ }
+
+ for _, err := range others {
+ reason, _ := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.NotEqual(t, openAIWSIngressTurnAbortReasonClientPreempted, reason,
+ "error %q should not classify as client_preempted", err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSIngressTurnAbortDispositionForReason 对 ClientPreempted 的处置测试
+// ---------------------------------------------------------------------------
+
+func TestDisposition_ClientPreempted_IsContinueTurn(t *testing.T) {
+ t.Parallel()
+
+ disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+}
+
+func TestDisposition_ClientPreempted_SameAsPreviousResponse(t *testing.T) {
+ t.Parallel()
+
+ // client_preempted 与 previous_response_not_found 应有相同的处置
+ prevDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonPreviousResponse)
+ preemptDisp := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted)
+ require.Equal(t, prevDisp, preemptDisp)
+}
+
+func TestDisposition_AllContinueTurnReasons(t *testing.T) {
+ t.Parallel()
+
+ // 验证所有应归为 ContinueTurn 的 reason 列表完整且正确
+ continueTurnReasons := []openAIWSIngressTurnAbortReason{
+ openAIWSIngressTurnAbortReasonPreviousResponse,
+ openAIWSIngressTurnAbortReasonToolOutput,
+ openAIWSIngressTurnAbortReasonUpstreamError,
+ openAIWSIngressTurnAbortReasonClientPreempted,
+ }
+
+ for _, reason := range continueTurnReasons {
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition,
+ "reason %q should be ContinueTurn", reason)
+ }
+}
+
+func TestDisposition_ClientPreempted_NotCloseGracefully(t *testing.T) {
+ t.Parallel()
+
+ disposition := openAIWSIngressTurnAbortDispositionForReason(openAIWSIngressTurnAbortReasonClientPreempted)
+ require.NotEqual(t, openAIWSIngressTurnAbortDispositionCloseGracefully, disposition)
+ require.NotEqual(t, openAIWSIngressTurnAbortDispositionFailRequest, disposition)
+}
+
+// ---------------------------------------------------------------------------
+// 端到端 classify → disposition 链路测试
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ClassifyToDisposition_EndToEnd(t *testing.T) {
+ t.Parallel()
+
+ // 模拟 sendAndRelay 返回 client_preempted 错误的完整链路
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+
+ // 1. classify
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+
+ // 2. disposition
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+ // 3. wroteDownstream
+ require.False(t, openAIWSIngressTurnWroteDownstream(turnErr))
+}
+
+func TestClientPreempted_ClassifyToDisposition_WroteDownstream(t *testing.T) {
+ t.Parallel()
+
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_half",
+ Usage: OpenAIUsage{
+ InputTokens: 200,
+ OutputTokens: 100,
+ },
+ }
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ true,
+ partial,
+ )
+
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+ require.True(t, openAIWSIngressTurnWroteDownstream(turnErr))
+
+ got, ok := OpenAIWSIngressTurnPartialResult(turnErr)
+ require.True(t, ok)
+ require.Equal(t, "resp_half", got.RequestID)
+}
+
+// ---------------------------------------------------------------------------
+// ContinueTurn 分支对 client_preempted 的特殊行为验证
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ShouldNotSendErrorEvent(t *testing.T) {
+ t.Parallel()
+
+ // 核心语义:client_preempted 时客户端已发出新请求,不需要旧 turn 的 error 事件。
+ // 验证 abortReason 为 client_preempted 时不应产生 error 通知。
+ abortReason := openAIWSIngressTurnAbortReasonClientPreempted
+
+ // 模拟 ContinueTurn 分支的判断逻辑
+ shouldSendError := abortReason != openAIWSIngressTurnAbortReasonClientPreempted
+ require.False(t, shouldSendError, "client_preempted 不应发送 error 事件")
+}
+
+func TestClientPreempted_ShouldNotClearLastResponseID(t *testing.T) {
+ t.Parallel()
+
+ // 核心语义:被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链。
+ // 验证 abortReason 为 client_preempted 时不应调用 clearSessionLastResponseID。
+ abortReason := openAIWSIngressTurnAbortReasonClientPreempted
+
+ shouldClearLastResponseID := abortReason != openAIWSIngressTurnAbortReasonClientPreempted
+ require.False(t, shouldClearLastResponseID,
+ "client_preempted 不应清除 lastResponseID")
+}
+
+func TestNonPreempted_ContinueTurn_ShouldSendErrorAndClearID(t *testing.T) {
+ t.Parallel()
+
+ // 对照测试:非 client_preempted 的 ContinueTurn reason 应正常发送 error 并清除 ID
+ otherReasons := []openAIWSIngressTurnAbortReason{
+ openAIWSIngressTurnAbortReasonPreviousResponse,
+ openAIWSIngressTurnAbortReasonToolOutput,
+ openAIWSIngressTurnAbortReasonUpstreamError,
+ }
+
+ for _, reason := range otherReasons {
+ shouldSendError := reason != openAIWSIngressTurnAbortReasonClientPreempted
+ shouldClearID := reason != openAIWSIngressTurnAbortReasonClientPreempted
+ require.True(t, shouldSendError,
+ "reason %q (non-preempted) should send error event", reason)
+ require.True(t, shouldClearID,
+ "reason %q (non-preempted) should clear lastResponseID", reason)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// ContinueTurn abort 路径中 client_preempted 的 error 事件格式验证
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ErrorEventNotGenerated(t *testing.T) {
+ t.Parallel()
+
+ // 在实际的 ContinueTurn 分支中,client_preempted 分支根本不会构造 error 事件。
+ // 此测试验证如果误走错误路径(防御性),error 事件格式仍然正确。
+ abortReason := openAIWSIngressTurnAbortReasonClientPreempted
+ abortMessage := "turn failed: " + string(abortReason)
+
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+ string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+
+ var parsed map[string]any
+ err := json.Unmarshal(errorEvent, &parsed)
+ require.NoError(t, err, "hypothetical error event should be valid JSON")
+
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "client_preempted", errorObj["code"])
+ require.Contains(t, errorObj["message"], "client_preempted")
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSIngressTurnAbortReason 常量值验证
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_ReasonStringValue(t *testing.T) {
+ t.Parallel()
+
+ require.Equal(t, openAIWSIngressTurnAbortReason("client_preempted"),
+ openAIWSIngressTurnAbortReasonClientPreempted)
+}
+
+// ---------------------------------------------------------------------------
+// classifyOpenAIWSIngressTurnAbortReason 完整 table-driven 测试(含 client_preempted)
+// ---------------------------------------------------------------------------
+
+func TestClassifyAbortReason_AllReasons_IncludeClientPreempted(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ wantReason openAIWSIngressTurnAbortReason
+ wantExpected bool
+ }{
+ {
+ name: "client_preempted_sentinel",
+ err: errOpenAIWSClientPreempted,
+ wantReason: openAIWSIngressTurnAbortReasonClientPreempted,
+ wantExpected: true,
+ },
+ {
+ name: "client_preempted_wrapped_in_turn_error",
+ err: wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonClientPreempted,
+ wantExpected: true,
+ },
+ {
+ name: "client_preempted_wrapped_in_turn_error_wrote_downstream",
+ err: wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ true,
+ &OpenAIForwardResult{RequestID: "resp_x"},
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonClientPreempted,
+ wantExpected: true,
+ },
+ {
+ name: "client_preempted_double_wrapped",
+ err: fmt.Errorf("relay: %w", errOpenAIWSClientPreempted),
+ wantReason: openAIWSIngressTurnAbortReasonClientPreempted,
+ wantExpected: true,
+ },
+ {
+ name: "previous_response_not_confused_with_preempt",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("not found"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonPreviousResponse,
+ wantExpected: true,
+ },
+ {
+ name: "tool_output_not_confused_with_preempt",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("tool output not found"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonToolOutput,
+ wantExpected: true,
+ },
+ {
+ name: "context_canceled_not_preempted",
+ err: context.Canceled,
+ wantReason: openAIWSIngressTurnAbortReasonContextCanceled,
+ wantExpected: true,
+ },
+ {
+ name: "eof_not_preempted",
+ err: io.EOF,
+ wantReason: openAIWSIngressTurnAbortReasonClientClosed,
+ wantExpected: true,
+ },
+ {
+ name: "ws_normal_closure_not_preempted",
+ err: coderws.CloseError{Code: coderws.StatusNormalClosure},
+ wantReason: openAIWSIngressTurnAbortReasonClientClosed,
+ wantExpected: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err)
+ require.Equal(t, tt.wantReason, reason)
+ require.Equal(t, tt.wantExpected, expected)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// classify 优先级测试:client_preempted 在 context.Canceled 之前
+// ---------------------------------------------------------------------------
+
+func TestClassifyAbortReason_ClientPreempted_PriorityOverContextCanceled(t *testing.T) {
+ t.Parallel()
+
+ // errOpenAIWSClientPreempted 不会同时匹配 context.Canceled,
+ // 但若将来有包裹 context.Canceled 的情况,client_preempted 检测应在前。
+ reason, _ := classifyOpenAIWSIngressTurnAbortReason(errOpenAIWSClientPreempted)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason,
+ "client_preempted 检测应优先于 context.Canceled")
+}
+
+// ---------------------------------------------------------------------------
+// openAIWSIngressTurnAbortDispositionForReason table-driven 测试(含 client_preempted)
+// ---------------------------------------------------------------------------
+
+func TestDisposition_AllReasons_IncludeClientPreempted(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ reason openAIWSIngressTurnAbortReason
+ wantDisp openAIWSIngressTurnAbortDisposition
+ }{
+ {openAIWSIngressTurnAbortReasonPreviousResponse, openAIWSIngressTurnAbortDispositionContinueTurn},
+ {openAIWSIngressTurnAbortReasonToolOutput, openAIWSIngressTurnAbortDispositionContinueTurn},
+ {openAIWSIngressTurnAbortReasonUpstreamError, openAIWSIngressTurnAbortDispositionContinueTurn},
+ {openAIWSIngressTurnAbortReasonClientPreempted, openAIWSIngressTurnAbortDispositionContinueTurn},
+ {openAIWSIngressTurnAbortReasonContextCanceled, openAIWSIngressTurnAbortDispositionCloseGracefully},
+ {openAIWSIngressTurnAbortReasonClientClosed, openAIWSIngressTurnAbortDispositionCloseGracefully},
+ {openAIWSIngressTurnAbortReasonUnknown, openAIWSIngressTurnAbortDispositionFailRequest},
+ {openAIWSIngressTurnAbortReasonContextDeadline, openAIWSIngressTurnAbortDispositionFailRequest},
+ {openAIWSIngressTurnAbortReasonWriteUpstream, openAIWSIngressTurnAbortDispositionFailRequest},
+ {openAIWSIngressTurnAbortReasonReadUpstream, openAIWSIngressTurnAbortDispositionFailRequest},
+ {openAIWSIngressTurnAbortReasonWriteClient, openAIWSIngressTurnAbortDispositionFailRequest},
+ {openAIWSIngressTurnAbortReasonContinuationUnavailable, openAIWSIngressTurnAbortDispositionFailRequest},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(string(tt.reason), func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.wantDisp, openAIWSIngressTurnAbortDispositionForReason(tt.reason))
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// isOpenAIWSIngressTurnRetryable 与 client_preempted 的交互
+// ---------------------------------------------------------------------------
+
+func TestIsRetryable_ClientPreempted_NotRetryable(t *testing.T) {
+ t.Parallel()
+
+ // client_preempted 有专门的恢复路径(ContinueTurn),不走通用重试
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ require.False(t, isOpenAIWSIngressTurnRetryable(turnErr),
+ "client_preempted 不应被标记为 retryable")
+}
+
+func TestIsRetryable_ClientPreempted_WroteDownstream(t *testing.T) {
+ t.Parallel()
+
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ true,
+ nil,
+ )
+ require.False(t, isOpenAIWSIngressTurnRetryable(turnErr),
+ "client_preempted wroteDownstream=true 不应被标记为 retryable")
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay 中 clientMsgCh / clientReadErrCh 行为的单元级测试
+// ---------------------------------------------------------------------------
+
+func TestClientMsgCh_BufferedOne(t *testing.T) {
+ t.Parallel()
+
+ // 验证 clientMsgCh(buffered 1) 的语义:goroutine 在 sendAndRelay 返回到
+ // advanceToNextClientTurn 的间隙不阻塞
+ ch := make(chan []byte, 1)
+
+ // 非阻塞写入
+ select {
+ case ch <- []byte(`{"type":"response.create"}`):
+ // ok
+ default:
+ t.Fatal("buffered(1) channel should not block on first write")
+ }
+
+ // 第二次写入应阻塞
+ select {
+ case ch <- []byte(`{"type":"response.create"}`):
+ t.Fatal("buffered(1) channel should block on second write")
+ default:
+ // expected
+ }
+}
+
+func TestClientReadErrCh_BufferedOne(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan error, 1)
+
+ // 非阻塞写入
+ select {
+ case ch <- io.EOF:
+ default:
+ t.Fatal("buffered(1) channel should not block on first write")
+ }
+
+ // 第二次写入应阻塞
+ select {
+ case ch <- io.EOF:
+ t.Fatal("buffered(1) channel should block on second write")
+ default:
+ // expected
+ }
+}
+
+func TestClientMsgCh_CloseSignalsClosed(t *testing.T) {
+ t.Parallel()
+
+ ch := make(chan []byte, 1)
+ close(ch)
+
+ msg, ok := <-ch
+ require.False(t, ok, "closed channel should return ok=false")
+ require.Nil(t, msg)
+}
+
+// ---------------------------------------------------------------------------
+// 客户端抢占暂存(nextClientPreemptedPayload)行为测试
+// ---------------------------------------------------------------------------
+
+func TestPreemptedPayload_ConsumedOnce(t *testing.T) {
+ t.Parallel()
+
+ // 模拟 advanceToNextClientTurn 中预存消息的消费行为
+ var nextPreempted []byte
+ nextPreempted = []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+
+ // 第一次消费
+ require.NotNil(t, nextPreempted)
+ msg := nextPreempted
+ nextPreempted = nil
+
+ require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(msg))
+ require.Nil(t, nextPreempted, "消费后应置空")
+}
+
+func TestPreemptedPayload_NilFallsBackToChannel(t *testing.T) {
+ t.Parallel()
+
+ // 模拟 advanceToNextClientTurn 中无预存消息时走 channel
+ var nextPreempted []byte
+ clientMsgCh := make(chan []byte, 1)
+ clientMsgCh <- []byte(`{"type":"response.create","model":"gpt-5.1"}`)
+
+ var nextClientMessage []byte
+ if nextPreempted != nil {
+ nextClientMessage = nextPreempted
+ nextPreempted = nil
+ } else {
+ select {
+ case msg, ok := <-clientMsgCh:
+ require.True(t, ok)
+ nextClientMessage = msg
+ }
+ }
+
+ require.Equal(t, `{"type":"response.create","model":"gpt-5.1"}`, string(nextClientMessage))
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select 路径:pumpEventCh 关闭 → goto pumpClosed
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_PumpClosed_GoToPumpClosed(t *testing.T) {
+ t.Parallel()
+
+ // 模拟 pumpEventCh 关闭时的行为
+ pumpEventCh := make(chan openAIWSUpstreamPumpEvent)
+ close(pumpEventCh)
+
+ evt, ok := <-pumpEventCh
+ require.False(t, ok, "closed pumpEventCh should return ok=false")
+ require.Nil(t, evt.message)
+ require.Nil(t, evt.err)
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select 路径:clientMsgCh 收到消息 → client preempt
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_ClientPreempt_ReturnsCorrectError(t *testing.T) {
+ t.Parallel()
+
+ // 模拟 select 中收到客户端抢占消息后生成的 turnError
+ preemptPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+
+ // 模拟 sendAndRelay 返回的错误
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil, // buildPartialResult 在没有 usage 时返回 nil
+ )
+
+ // 验证错误分类
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+
+ // 验证处置
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+ // 验证预存消息可供 advanceToNextClientTurn 使用
+ require.NotEmpty(t, preemptPayload)
+}
+
+func TestSelectLoop_ClientPreempt_WithPartialUsage(t *testing.T) {
+ t.Parallel()
+
+ // 模拟上游已发送部分 token 后被客户端抢占
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_interrupted",
+ Usage: OpenAIUsage{
+ InputTokens: 500,
+ OutputTokens: 200,
+ },
+ }
+
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ true, // 已写过下游
+ partial,
+ )
+
+ require.True(t, errors.Is(turnErr, errOpenAIWSClientPreempted))
+ require.True(t, openAIWSIngressTurnWroteDownstream(turnErr))
+
+ got, ok := OpenAIWSIngressTurnPartialResult(turnErr)
+ require.True(t, ok)
+ require.Equal(t, "resp_interrupted", got.RequestID)
+ require.Equal(t, 500, got.Usage.InputTokens)
+ require.Equal(t, 200, got.Usage.OutputTokens)
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select 路径:clientMsgCh 关闭 → nil channel
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_ClientMsgChClosed_NilChannelPreventsReselect(t *testing.T) {
+ t.Parallel()
+
+ // clientMsgCh 关闭后被设为 nil,后续 select 不应再选中它
+ var clientMsgCh chan []byte
+ clientMsgCh = make(chan []byte, 1)
+ close(clientMsgCh)
+
+ // 第一次读取:closed
+ _, ok := <-clientMsgCh
+ require.False(t, ok)
+
+ // 设为 nil
+ clientMsgCh = nil
+
+ // nil channel 上的 select 永远不会被选中(不会 panic)
+ select {
+ case <-clientMsgCh:
+ t.Fatal("nil channel should never be selected")
+ default:
+ // expected: nil channel 不参与 select
+ }
+}
+
+// ---------------------------------------------------------------------------
+// sendAndRelay select 路径:clientReadErrCh 客户端断连
+// ---------------------------------------------------------------------------
+
+func TestSelectLoop_ClientReadErr_DisconnectSetsDrain(t *testing.T) {
+ t.Parallel()
+
+ // 模拟客户端断连读取错误的分类行为
+ disconnectErrors := []error{
+ io.EOF,
+ coderws.CloseError{Code: coderws.StatusNormalClosure},
+ coderws.CloseError{Code: coderws.StatusGoingAway},
+ }
+
+ for _, readErr := range disconnectErrors {
+ require.True(t, isOpenAIWSClientDisconnectError(readErr),
+ "error %v should be classified as client disconnect", readErr)
+ }
+}
+
+func TestSelectLoop_ClientReadErr_NonDisconnect(t *testing.T) {
+ t.Parallel()
+
+ // 非断连错误不应触发 drain
+ nonDisconnectErrors := []error{
+ errors.New("tls handshake timeout"),
+ coderws.CloseError{Code: coderws.StatusPolicyViolation},
+ }
+
+ for _, readErr := range nonDisconnectErrors {
+ require.False(t, isOpenAIWSClientDisconnectError(readErr),
+ "error %v should not be classified as client disconnect", readErr)
+ }
+}
+
+func TestSelectLoop_ClientReadErr_NilChannelsAfterError(t *testing.T) {
+ t.Parallel()
+
+ // 模拟收到 clientReadErrCh 后将两个 channel 置 nil
+ clientMsgCh := make(chan []byte, 1)
+ clientReadErrCh := make(chan error, 1)
+
+ clientReadErrCh <- io.EOF
+
+ // 消费错误
+ readErr := <-clientReadErrCh
+ require.Error(t, readErr)
+
+ // 模拟置空(实际代码中 select case 后的操作)
+ var nilMsgCh chan []byte
+ var nilErrCh chan error
+ nilMsgCh = nil
+ nilErrCh = nil
+
+ // 验证 nil channel 行为
+ _ = clientMsgCh // unused in this test
+
+ select {
+ case <-nilMsgCh:
+ t.Fatal("nil channel should never be selected")
+ case <-nilErrCh:
+ t.Fatal("nil channel should never be selected")
+ default:
+ // expected
+ }
+}
+
+func TestAdvanceConsumePendingClientReadErr(t *testing.T) {
+ t.Parallel()
+
+ require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(nil))
+
+ var pendingErr error
+ require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr))
+
+ sourceErr := errors.New("custom read error")
+ pendingErr = sourceErr
+
+ gotErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingErr)
+ require.Error(t, gotErr)
+ require.ErrorIs(t, gotErr, sourceErr)
+ require.Nil(t, pendingErr, "pending error should be consumed once")
+ require.NoError(t, openAIWSAdvanceConsumePendingClientReadErr(&pendingErr))
+}
+
+func TestAdvanceClientReadUnavailable(t *testing.T) {
+ t.Parallel()
+
+ var nilMsgCh chan []byte
+ var nilErrCh chan error
+ require.True(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, nilErrCh))
+
+ msgCh := make(chan []byte, 1)
+ require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, nilErrCh))
+
+ errCh := make(chan error, 1)
+ require.False(t, openAIWSAdvanceClientReadUnavailable(nilMsgCh, errCh))
+ require.False(t, openAIWSAdvanceClientReadUnavailable(msgCh, errCh))
+}
+
+// ---------------------------------------------------------------------------
+// advanceToNextClientTurn channel 读取路径测试
+// ---------------------------------------------------------------------------
+
+func TestAdvance_ClientMsgCh_ClosedReturnsExit(t *testing.T) {
+ t.Parallel()
+
+ // clientMsgCh 关闭意味着客户端读取 goroutine 已退出,应返回 exit=true
+ ch := make(chan []byte, 1)
+ close(ch)
+
+ _, ok := <-ch
+ require.False(t, ok, "should signal goroutine exit")
+}
+
+func TestAdvance_ClientReadErrCh_DisconnectReturnsExit(t *testing.T) {
+ t.Parallel()
+
+ // 断连错误应返回 exit=true
+ ch := make(chan error, 1)
+ ch <- io.EOF
+
+ readErr := <-ch
+ require.True(t, isOpenAIWSClientDisconnectError(readErr))
+}
+
+func TestAdvance_ClientReadErrCh_NonDisconnectReturnsError(t *testing.T) {
+ t.Parallel()
+
+ // 非断连错误应返回 error
+ ch := make(chan error, 1)
+ errCustom := errors.New("custom read error")
+ ch <- errCustom
+
+ readErr := <-ch
+ require.False(t, isOpenAIWSClientDisconnectError(readErr))
+ require.Equal(t, errCustom, readErr)
+}
+
+// ---------------------------------------------------------------------------
+// 持久客户端读取 goroutine 行为测试
+// ---------------------------------------------------------------------------
+
+func TestPersistentReader_NormalMessage(t *testing.T) {
+ t.Parallel()
+
+ // 模拟正常消息的推送和消费
+ clientMsgCh := make(chan []byte, 1)
+
+ // 模拟 goroutine 写入
+ go func() {
+ clientMsgCh <- []byte(`{"type":"response.create"}`)
+ }()
+
+ msg := <-clientMsgCh
+ require.Equal(t, `{"type":"response.create"}`, string(msg))
+}
+
+func TestPersistentReader_ErrorSendsToErrCh(t *testing.T) {
+ t.Parallel()
+
+ clientReadErrCh := make(chan error, 1)
+
+ // 模拟 goroutine 发送错误
+ go func() {
+ clientReadErrCh <- io.EOF
+ }()
+
+ readErr := <-clientReadErrCh
+ require.Equal(t, io.EOF, readErr)
+}
+
+func TestPersistentReader_ContextCancel(t *testing.T) {
+ t.Parallel()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ clientMsgCh := make(chan []byte, 1)
+
+ // 填满 buffer
+ clientMsgCh <- []byte("first")
+
+ // 模拟 goroutine 尝试写入已满的 channel
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ select {
+ case clientMsgCh <- []byte("second"):
+ // 不应到达
+ case <-ctx.Done():
+ // 正确退出
+ return
+ }
+ }()
+
+ // 取消 context
+ cancel()
+ <-done
+}
+
+func TestPersistentReader_ClosesMsgChOnExit(t *testing.T) {
+ t.Parallel()
+
+ clientMsgCh := make(chan []byte, 1)
+
+ // 模拟 goroutine 退出时关闭 channel
+ go func() {
+ defer close(clientMsgCh)
+ // 模拟读取错误后退出
+ }()
+
+ // 等待 channel 关闭
+ _, ok := <-clientMsgCh
+ require.False(t, ok, "channel should be closed when goroutine exits")
+}
+
+// ---------------------------------------------------------------------------
+// client_preempted 与其他 abort reason 的正交性验证
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_OrthogonalWithPreviousResponseNotFound(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ prevErr := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("not found"),
+ false,
+ )
+
+ // client_preempted 不会被误判为 previous_response_not_found
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(preemptErr))
+ // previous_response_not_found 不会被误判为 client_preempted
+ require.False(t, errors.Is(prevErr, errOpenAIWSClientPreempted))
+}
+
+func TestClientPreempted_OrthogonalWithToolOutputNotFound(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ toolErr := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("tool output not found"),
+ false,
+ )
+
+ require.False(t, isOpenAIWSIngressToolOutputNotFound(preemptErr))
+ require.False(t, errors.Is(toolErr, errOpenAIWSClientPreempted))
+}
+
+func TestClientPreempted_OrthogonalWithUpstreamError(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ upstreamErr := wrapOpenAIWSIngressTurnError(
+ "upstream_error_event",
+ errors.New("upstream error"),
+ false,
+ )
+
+ require.False(t, isOpenAIWSIngressUpstreamErrorEvent(preemptErr))
+ require.False(t, errors.Is(upstreamErr, errOpenAIWSClientPreempted))
+}
+
+func TestClientPreempted_OrthogonalWithContinuationUnavailable(t *testing.T) {
+ t.Parallel()
+
+ preemptErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+ require.False(t, isOpenAIWSContinuationUnavailableCloseError(preemptErr))
+}
+
+func TestClientPreempted_NotClientDisconnect(t *testing.T) {
+ t.Parallel()
+
+ require.False(t, isOpenAIWSClientDisconnectError(errOpenAIWSClientPreempted),
+ "client_preempted should not be classified as client disconnect")
+}
+
+// ---------------------------------------------------------------------------
+// recordOpenAIWSTurnAbort 指标兼容性测试
+// ---------------------------------------------------------------------------
+
+func TestClientPreempted_RecordAbortArgs(t *testing.T) {
+ t.Parallel()
+
+ // 验证 classify 返回的 (reason, expected) 值与 recordOpenAIWSTurnAbort 兼容
+ turnErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted",
+ errOpenAIWSClientPreempted,
+ false,
+ nil,
+ )
+
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ require.Equal(t, openAIWSIngressTurnAbortReasonClientPreempted, reason)
+ require.True(t, expected)
+
+ // expected=true 表示这是预期行为,不应触发告警
+ assert.True(t, expected, "client_preempted 应标记为 expected,不触发告警")
+}
+
+// ---------------------------------------------------------------------------
+// shouldFlushOpenAIWSBufferedEventsOnError 与 client_preempted 场景
+// ---------------------------------------------------------------------------
+
+func TestShouldFlushBufferedEvents_ClientPreempted(t *testing.T) {
+ t.Parallel()
+
+ // client_preempted 场景下 clientDisconnected=false(客户端仍在),
+ // 是否 flush 取决于 reqStream 和 wroteDownstream
+ tests := []struct {
+ name string
+ reqStream bool
+ wroteDownstream bool
+ wantFlush bool
+ }{
+ {"stream_wrote", true, true, true},
+ {"stream_not_wrote", true, false, false},
+ {"not_stream_wrote", false, true, false},
+ {"not_stream_not_wrote", false, false, false},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := shouldFlushOpenAIWSBufferedEventsOnError(tt.reqStream, tt.wroteDownstream, false)
+ require.Equal(t, tt.wantFlush, got)
+ })
+ }
+}
diff --git a/backend/internal/service/openai_ws_common.go b/backend/internal/service/openai_ws_common.go
new file mode 100644
index 000000000..fa6403911
--- /dev/null
+++ b/backend/internal/service/openai_ws_common.go
@@ -0,0 +1,54 @@
+package service
+
+import (
+ "errors"
+ "fmt"
+ "net/http"
+ "time"
+)
+
+var (
+ errOpenAIWSConnClosed = errors.New("openai ws connection closed")
+ errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
+ errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
+)
+
+const (
+ openAIWSConnHealthCheckTO = 2 * time.Second
+)
+
+type openAIWSDialError struct {
+ StatusCode int
+ ResponseHeaders http.Header
+ Err error
+}
+
+func (e *openAIWSDialError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.StatusCode > 0 {
+ return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
+ }
+ return fmt.Sprintf("openai ws dial failed: %v", e.Err)
+}
+
+func (e *openAIWSDialError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+func cloneHeader(h http.Header) http.Header {
+ if h == nil {
+ return nil
+ }
+ cloned := make(http.Header, len(h))
+ for k, values := range h {
+ copied := make([]string, 0, len(values))
+ copied = append(copied, values...)
+ cloned[k] = copied
+ }
+ return cloned
+}
diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go
index ce06f6a21..fdc8efa55 100644
--- a/backend/internal/service/openai_ws_fallback_test.go
+++ b/backend/internal/service/openai_ws_fallback_test.go
@@ -3,12 +3,15 @@ package service
import (
"context"
"errors"
+ "io"
"net/http"
+ "net/http/httptest"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -108,6 +111,154 @@ func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
require.Equal(t, "previous_response_not_found", reason)
require.True(t, recoverable)
+
+ // tool_output_not_found: 用户按 ESC 取消 function_call 后重新发送消息
+ reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"type":"invalid_request_error","message":"No tool output found for function call call_zXKPiNecBmIAoKeW9o2pNMvo.","param":"input"}}`))
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw("", "invalid_request_error", "No tool output found for function call call_abc123.")
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw(
+ "",
+ "invalid_request_error",
+ "No tool call found for function call output with call_id call_abc123.",
+ )
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ // reasoning orphaned items should reuse tool_output_not_found recovery path.
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw(
+ "",
+ "invalid_request_error",
+ "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.",
+ )
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+
+ reason, recoverable = classifyOpenAIWSErrorEventFromRaw(
+ "",
+ "invalid_request_error",
+ "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.",
+ )
+ require.Equal(t, openAIWSIngressStageToolOutputNotFound, reason)
+ require.True(t, recoverable)
+}
+
+func TestClassifyOpenAIWSErrorEventFromRaw_AllBranches(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ codeRaw string
+ errTypeRaw string
+ msgRaw string
+ wantReason string
+ wantRecover bool
+ }{
+ {
+ name: "code_upgrade_required",
+ codeRaw: "upgrade_required",
+ wantReason: "upgrade_required",
+ wantRecover: true,
+ },
+ {
+ name: "code_ws_unsupported",
+ codeRaw: "websocket_not_supported",
+ wantReason: "ws_unsupported",
+ wantRecover: true,
+ },
+ {
+ name: "code_ws_connection_limit",
+ codeRaw: "websocket_connection_limit_reached",
+ wantReason: "ws_connection_limit_reached",
+ wantRecover: true,
+ },
+ {
+ name: "msg_upgrade_required",
+ msgRaw: "status 426 upgrade required",
+ wantReason: "upgrade_required",
+ wantRecover: true,
+ },
+ {
+ name: "err_type_upgrade",
+ errTypeRaw: "gateway_upgrade_error",
+ wantReason: "upgrade_required",
+ wantRecover: true,
+ },
+ {
+ name: "msg_ws_unsupported",
+ msgRaw: "websocket is unsupported in this region",
+ wantReason: "ws_unsupported",
+ wantRecover: true,
+ },
+ {
+ name: "msg_ws_connection_limit",
+ msgRaw: "websocket connection limit exceeded",
+ wantReason: "ws_connection_limit_reached",
+ wantRecover: true,
+ },
+ {
+ name: "msg_previous_response_not_found_variant",
+ msgRaw: "previous response is not found",
+ wantReason: "previous_response_not_found",
+ wantRecover: true,
+ },
+ {
+ name: "msg_no_tool_output",
+ msgRaw: "No tool output found for function call call_abc.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "msg_no_tool_call_for_function_call_output",
+ msgRaw: "No tool call found for function call output with call_id call_abc.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "msg_reasoning_missing_following",
+ msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required following item.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "msg_reasoning_missing_preceding",
+ msgRaw: "Item 'rs_xxx' of type 'reasoning' was provided without its required preceding item.",
+ wantReason: openAIWSIngressStageToolOutputNotFound,
+ wantRecover: true,
+ },
+ {
+ name: "server_error_by_type",
+ errTypeRaw: "server_error",
+ wantReason: "upstream_error_event",
+ wantRecover: true,
+ },
+ {
+ name: "server_error_by_code",
+ codeRaw: "server_error",
+ wantReason: "upstream_error_event",
+ wantRecover: true,
+ },
+ {
+ name: "unknown_event_error",
+ codeRaw: "other",
+ errTypeRaw: "other",
+ msgRaw: "other",
+ wantReason: "event_error",
+ wantRecover: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ reason, recoverable := classifyOpenAIWSErrorEventFromRaw(tt.codeRaw, tt.errTypeRaw, tt.msgRaw)
+ require.Equal(t, tt.wantReason, reason)
+ require.Equal(t, tt.wantRecover, recoverable)
+ })
+ }
}
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
@@ -197,12 +348,39 @@ func TestOpenAIWSRetryTotalBudget(t *testing.T) {
require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
}
+func TestOpenAIWSRetryContextError(t *testing.T) {
+ require.NoError(t, openAIWSRetryContextError(nil))
+ require.NoError(t, openAIWSRetryContextError(context.Background()))
+
+ canceledCtx, cancel := context.WithCancel(context.Background())
+ cancel()
+ err := openAIWSRetryContextError(canceledCtx)
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+
+ var fallbackErr *openAIWSFallbackError
+ require.ErrorAs(t, err, &fallbackErr)
+ require.Equal(t, "retry_context_canceled", fallbackErr.Reason)
+}
+
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
+ require.Equal(t, "service_restart", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusServiceRestart}))
+ require.Equal(t, "try_again_later", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusTryAgainLater}))
require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
}
+func TestClassifyOpenAIWSIngressReadErrorClass(t *testing.T) {
+ require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(nil))
+ require.Equal(t, "context_canceled", classifyOpenAIWSIngressReadErrorClass(context.Canceled))
+ require.Equal(t, "deadline_exceeded", classifyOpenAIWSIngressReadErrorClass(context.DeadlineExceeded))
+ require.Equal(t, "service_restart", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusServiceRestart}))
+ require.Equal(t, "try_again_later", classifyOpenAIWSIngressReadErrorClass(coderws.CloseError{Code: coderws.StatusTryAgainLater}))
+ require.Equal(t, "upstream_closed", classifyOpenAIWSIngressReadErrorClass(io.EOF))
+ require.Equal(t, "unknown", classifyOpenAIWSIngressReadErrorClass(errors.New("tls handshake timeout")))
+}
+
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
svc := &OpenAIGatewayService{cfg: &config.Config{}}
svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
@@ -239,6 +417,117 @@ func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
}
+func TestWriteOpenAIWSV1UnsupportedResponse_TracksOps(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+ rec := httptest.NewRecorder()
+ c, _ := gin.CreateTestContext(rec)
+ c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
+
+ svc := &OpenAIGatewayService{}
+ account := &Account{
+ ID: 42,
+ Name: "acc-ws-v1",
+ Platform: PlatformOpenAI,
+ }
+
+ err := svc.writeOpenAIWSV1UnsupportedResponse(c, account)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "openai ws v1 is temporarily unsupported")
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "invalid_request_error")
+ require.Contains(t, rec.Body.String(), "temporarily unsupported")
+
+ rawStatus, ok := c.Get(OpsUpstreamStatusCodeKey)
+ require.True(t, ok)
+ require.Equal(t, http.StatusBadRequest, rawStatus)
+
+ rawMsg, ok := c.Get(OpsUpstreamErrorMessageKey)
+ require.True(t, ok)
+ require.Equal(t, "openai ws v1 is temporarily unsupported; use ws v2", rawMsg)
+
+ rawEvents, ok := c.Get(OpsUpstreamErrorsKey)
+ require.True(t, ok)
+ events, ok := rawEvents.([]*OpsUpstreamErrorEvent)
+ require.True(t, ok)
+ require.Len(t, events, 1)
+ require.Equal(t, account.ID, events[0].AccountID)
+ require.Equal(t, account.Platform, events[0].Platform)
+ require.Equal(t, http.StatusBadRequest, events[0].UpstreamStatusCode)
+ require.Equal(t, "ws_error", events[0].Kind)
+}
+
+func TestIsOpenAIWSStreamWriteDisconnectError(t *testing.T) {
+ require.False(t, isOpenAIWSStreamWriteDisconnectError(nil, nil))
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("writer failed"), ctx))
+
+ require.True(t, isOpenAIWSStreamWriteDisconnectError(errors.New("broken pipe"), context.Background()))
+ require.True(t, isOpenAIWSStreamWriteDisconnectError(io.EOF, context.Background()))
+
+ require.False(t, isOpenAIWSStreamWriteDisconnectError(errors.New("template execute failed"), context.Background()))
+}
+
+func TestShouldFlushOpenAIWSBufferedEventsOnError(t *testing.T) {
+ require.True(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, false))
+ require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, false, false))
+ require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(true, true, true))
+ require.False(t, shouldFlushOpenAIWSBufferedEventsOnError(false, true, false))
+}
+
+func TestCloneOpenAIWSJSONRawString(t *testing.T) {
+ require.Nil(t, cloneOpenAIWSJSONRawString(""))
+ require.Nil(t, cloneOpenAIWSJSONRawString(" "))
+
+ raw := `{"id":"resp_1","type":"response"}`
+ cloned := cloneOpenAIWSJSONRawString(raw)
+ require.Equal(t, raw, string(cloned))
+ require.Equal(t, len(raw), len(cloned))
+}
+
+func TestOpenAIWSAbortMetricsSnapshot(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true)
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonUpstreamError, true)
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonWriteUpstream, false)
+ svc.recordOpenAIWSTurnAbortRecovered()
+
+ snapshot := svc.SnapshotOpenAIWSAbortMetrics()
+ require.Equal(t, int64(1), snapshot.TurnAbortRecoveredTotal)
+
+ getTotal := func(reason string, expected bool) int64 {
+ for _, point := range snapshot.TurnAbortTotal {
+ if point.Reason == reason && point.Expected == expected {
+ return point.Total
+ }
+ }
+ return 0
+ }
+ require.Equal(t, int64(2), getTotal(string(openAIWSIngressTurnAbortReasonUpstreamError), true))
+ require.Equal(t, int64(1), getTotal(string(openAIWSIngressTurnAbortReasonWriteUpstream), false))
+}
+
+func TestOpenAIWSPerformanceMetricsSnapshot_ContainsAbortMetrics(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ svc.recordOpenAIWSTurnAbort(openAIWSIngressTurnAbortReasonClientClosed, true)
+ svc.recordOpenAIWSTurnAbortRecovered()
+
+ snapshot := svc.SnapshotOpenAIWSPerformanceMetrics()
+ require.Equal(t, int64(1), snapshot.Abort.TurnAbortRecoveredTotal)
+
+ found := false
+ for _, point := range snapshot.Abort.TurnAbortTotal {
+ if point.Reason == string(openAIWSIngressTurnAbortReasonClientClosed) && point.Expected {
+ require.Equal(t, int64(1), point.Total)
+ found = true
+ break
+ }
+ }
+ require.True(t, found)
+}
+
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
svc := &OpenAIGatewayService{cfg: &config.Config{}}
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 74ba472f4..08ce979f1 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -6,12 +6,12 @@ import (
"encoding/json"
"errors"
"fmt"
- "io"
"math/rand"
- "net"
"net/http"
"net/url"
+ "runtime/debug"
"sort"
+ "strconv"
"strings"
"time"
@@ -26,8 +26,10 @@ import (
)
const (
- openAIWSBetaV1Value = "responses_websockets=2026-02-04"
- openAIWSBetaV2Value = "responses_websockets=2026-02-06"
+ openAIWSBetaV1Value = "responses_websockets=2026-02-04"
+ openAIWSBetaV2Value = "responses_websockets=2026-02-06"
+ openAIWSConnIDPrefixLegacy = "oa_ws_"
+ openAIWSConnIDPrefixCtx = "ctxws_"
openAIWSTurnStateHeader = "x-codex-turn-state"
openAIWSTurnMetadataHeader = "x-codex-turn-metadata"
@@ -55,879 +57,97 @@ const (
openAIWSStoreDisabledConnModeOff = "off"
openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found"
+ openAIWSIngressStageToolOutputNotFound = "tool_output_not_found"
openAIWSMaxPrevResponseIDDeletePasses = 8
+ openAIWSIngressReplayInputMaxBytes = 512 * 1024
+ openAIWSContinuationUnavailableReason = "upstream continuation connection is unavailable; please restart the conversation"
+ openAIWSAutoAbortedToolOutputValue = `{"error":"tool call aborted by gateway"}`
+ openAIWSClientReadIdleTimeoutDefault = 30 * time.Minute
+ openAIWSIngressClientDisconnectDrainTimeout = 5 * time.Second
+ openAIWSUpstreamPumpInfoMinAlive = 100 * time.Millisecond
)
-var openAIWSLogValueReplacer = strings.NewReplacer(
- "error", "err",
- "fallback", "fb",
- "warning", "warnx",
- "failed", "fail",
-)
-
-var openAIWSIngressPreflightPingIdle = 20 * time.Second
-
-// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。
-type openAIWSFallbackError struct {
- Reason string
- Err error
-}
-
-func (e *openAIWSFallbackError) Error() string {
- if e == nil {
- return ""
- }
- if e.Err == nil {
- return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
- }
- return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
-}
-
-func (e *openAIWSFallbackError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.Err
-}
-
-func wrapOpenAIWSFallback(reason string, err error) error {
- return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
-}
-
-// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。
-type OpenAIWSClientCloseError struct {
- statusCode coderws.StatusCode
- reason string
- err error
-}
-
-type openAIWSIngressTurnError struct {
- stage string
- cause error
- wroteDownstream bool
-}
-
-func (e *openAIWSIngressTurnError) Error() string {
- if e == nil {
- return ""
- }
- if e.cause == nil {
- return strings.TrimSpace(e.stage)
- }
- return e.cause.Error()
-}
-
-func (e *openAIWSIngressTurnError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.cause
-}
-
-func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
- if cause == nil {
- return nil
- }
- return &openAIWSIngressTurnError{
- stage: strings.TrimSpace(stage),
- cause: cause,
- wroteDownstream: wroteDownstream,
- }
-}
-
-func isOpenAIWSIngressTurnRetryable(err error) bool {
- var turnErr *openAIWSIngressTurnError
- if !errors.As(err, &turnErr) || turnErr == nil {
- return false
- }
- if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
- return false
- }
- if turnErr.wroteDownstream {
- return false
- }
- switch turnErr.stage {
- case "write_upstream", "read_upstream":
- return true
- default:
- return false
- }
-}
-
-func openAIWSIngressTurnRetryReason(err error) string {
- var turnErr *openAIWSIngressTurnError
- if !errors.As(err, &turnErr) || turnErr == nil {
- return "unknown"
- }
- if turnErr.stage == "" {
- return "unknown"
- }
- return turnErr.stage
-}
-
-func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
- var turnErr *openAIWSIngressTurnError
- if !errors.As(err, &turnErr) || turnErr == nil {
- return false
- }
- if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
- return false
- }
- return !turnErr.wroteDownstream
-}
-
-// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
-func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
- return &OpenAIWSClientCloseError{
- statusCode: statusCode,
- reason: strings.TrimSpace(reason),
- err: err,
- }
-}
-
-func (e *OpenAIWSClientCloseError) Error() string {
- if e == nil {
- return ""
- }
- if e.err == nil {
- return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
- }
- return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
-}
-
-func (e *OpenAIWSClientCloseError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.err
-}
-
-func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
- if e == nil {
- return coderws.StatusInternalError
- }
- return e.statusCode
-}
-
-func (e *OpenAIWSClientCloseError) Reason() string {
- if e == nil {
- return ""
- }
- return strings.TrimSpace(e.reason)
-}
-
-// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
-type OpenAIWSIngressHooks struct {
- BeforeTurn func(turn int) error
- AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
-}
-
-func normalizeOpenAIWSLogValue(value string) string {
- trimmed := strings.TrimSpace(value)
- if trimmed == "" {
- return "-"
- }
- return openAIWSLogValueReplacer.Replace(trimmed)
-}
-
-func truncateOpenAIWSLogValue(value string, maxLen int) string {
- normalized := normalizeOpenAIWSLogValue(value)
- if normalized == "-" || maxLen <= 0 {
- return normalized
- }
- if len(normalized) <= maxLen {
- return normalized
- }
- return normalized[:maxLen] + "..."
-}
-
-func openAIWSHeaderValueForLog(headers http.Header, key string) string {
- if headers == nil {
- return "-"
- }
- return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
-}
-
-func hasOpenAIWSHeader(headers http.Header, key string) bool {
- if headers == nil {
- return false
- }
- return strings.TrimSpace(headers.Get(key)) != ""
-}
-
-type openAIWSSessionHeaderResolution struct {
- SessionID string
- ConversationID string
- SessionSource string
- ConversationSource string
-}
-
-func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
- resolution := openAIWSSessionHeaderResolution{
- SessionSource: "none",
- ConversationSource: "none",
- }
- if c != nil && c.Request != nil {
- if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
- resolution.SessionID = sessionID
- resolution.SessionSource = "header_session_id"
- }
- if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
- resolution.ConversationID = conversationID
- resolution.ConversationSource = "header_conversation_id"
- if resolution.SessionID == "" {
- resolution.SessionID = conversationID
- resolution.SessionSource = "header_conversation_id"
- }
- }
- }
-
- cacheKey := strings.TrimSpace(promptCacheKey)
- if cacheKey != "" {
- if resolution.SessionID == "" {
- resolution.SessionID = cacheKey
- resolution.SessionSource = "prompt_cache_key"
- }
- }
- return resolution
-}
-
-func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
- if idx <= openAIWSEventLogHeadLimit {
- return true
- }
- if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
- return true
- }
- if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
- return true
- }
- return false
-}
-
-func shouldLogOpenAIWSBufferedEvent(idx int) bool {
- if idx <= openAIWSBufferLogHeadLimit {
- return true
- }
- if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
- return true
- }
- return false
-}
-
-func openAIWSEventMayContainModel(eventType string) bool {
- switch eventType {
- case "response.created",
- "response.in_progress",
- "response.completed",
- "response.done",
- "response.failed",
- "response.incomplete",
- "response.cancelled",
- "response.canceled":
- return true
- default:
- trimmed := strings.TrimSpace(eventType)
- if trimmed == eventType {
- return false
- }
- switch trimmed {
- case "response.created",
- "response.in_progress",
- "response.completed",
- "response.done",
- "response.failed",
- "response.incomplete",
- "response.cancelled",
- "response.canceled":
- return true
- default:
- return false
- }
- }
-}
-
-func openAIWSEventMayContainToolCalls(eventType string) bool {
- eventType = strings.TrimSpace(eventType)
- if eventType == "" {
- return false
- }
- if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") {
- return true
- }
- switch eventType {
- case "response.output_item.added", "response.output_item.done", "response.completed", "response.done":
- return true
- default:
- return false
- }
-}
-
-func openAIWSEventShouldParseUsage(eventType string) bool {
- return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed"
-}
-
-func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
- if len(message) == 0 {
- return "", "", gjson.Result{}
- }
- values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
- eventType = strings.TrimSpace(values[0].String())
- if id := strings.TrimSpace(values[1].String()); id != "" {
- responseID = id
- } else {
- responseID = strings.TrimSpace(values[2].String())
- }
- return eventType, responseID, values[3]
-}
-
-func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
- if len(message) == 0 {
- return false
- }
- return bytes.Contains(message, []byte(`"tool_calls"`)) ||
- bytes.Contains(message, []byte(`"tool_call"`)) ||
- bytes.Contains(message, []byte(`"function_call"`))
-}
-
-func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
- if usage == nil || len(message) == 0 {
- return
- }
- values := gjson.GetManyBytes(
- message,
- "response.usage.input_tokens",
- "response.usage.output_tokens",
- "response.usage.input_tokens_details.cached_tokens",
- )
- usage.InputTokens = int(values[0].Int())
- usage.OutputTokens = int(values[1].Int())
- usage.CacheReadInputTokens = int(values[2].Int())
-}
-
-func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
- if len(message) == 0 {
- return "", "", ""
- }
- values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
- return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
-}
-
-func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
- code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
- errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
- errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
- return code, errType, errMessage
-}
-
-func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
- if len(message) == 0 {
- return "-", "-", "-"
- }
- return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
-}
-
-func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
- if len(payload) == 0 {
- return "-"
- }
- type keySize struct {
- Key string
- Size int
- }
- sizes := make([]keySize, 0, len(payload))
- for key, value := range payload {
- size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
- sizes = append(sizes, keySize{Key: key, Size: size})
- }
- sort.Slice(sizes, func(i, j int) bool {
- if sizes[i].Size == sizes[j].Size {
- return sizes[i].Key < sizes[j].Key
- }
- return sizes[i].Size > sizes[j].Size
- })
-
- if topN <= 0 || topN > len(sizes) {
- topN = len(sizes)
- }
- parts := make([]string, 0, topN)
- for idx := 0; idx < topN; idx++ {
- item := sizes[idx]
- parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
- }
- return strings.Join(parts, ",")
-}
-
-func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
- if depth <= 0 {
- return -1
- }
- switch v := value.(type) {
- case nil:
- return 0
- case string:
- return len(v)
- case []byte:
- return len(v)
- case bool:
- return 1
- case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
- return 8
- case float32, float64:
- return 8
- case map[string]any:
- if len(v) == 0 {
- return 2
- }
- total := 2
- count := 0
- for key, item := range v {
- count++
- if count > openAIWSPayloadSizeEstimateMaxItems {
- return -1
- }
- itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
- if itemSize < 0 {
- return -1
- }
- total += len(key) + itemSize + 3
- if total > openAIWSPayloadSizeEstimateMaxBytes {
- return -1
- }
- }
- return total
- case []any:
- if len(v) == 0 {
- return 2
- }
- total := 2
- limit := len(v)
- if limit > openAIWSPayloadSizeEstimateMaxItems {
- return -1
- }
- for i := 0; i < limit; i++ {
- itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
- if itemSize < 0 {
- return -1
- }
- total += itemSize + 1
- if total > openAIWSPayloadSizeEstimateMaxBytes {
- return -1
- }
- }
- return total
- default:
- raw, err := json.Marshal(v)
- if err != nil {
- return -1
- }
- if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
- return -1
- }
- return len(raw)
- }
-}
-
-func openAIWSPayloadString(payload map[string]any, key string) string {
- if len(payload) == 0 {
- return ""
- }
- raw, ok := payload[key]
- if !ok {
- return ""
- }
- switch v := raw.(type) {
- case nil:
- return ""
- case string:
- return strings.TrimSpace(v)
- case []byte:
- return strings.TrimSpace(string(v))
- default:
- return ""
- }
-}
-
-func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
- if len(payload) == 0 || strings.TrimSpace(key) == "" {
- return ""
- }
- return strings.TrimSpace(gjson.GetBytes(payload, key).String())
-}
-
-func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
- if len(payload) == 0 || strings.TrimSpace(key) == "" {
- return defaultValue
- }
- value := gjson.GetBytes(payload, key)
- if !value.Exists() {
- return defaultValue
- }
- if value.Type != gjson.True && value.Type != gjson.False {
- return defaultValue
- }
- return value.Bool()
-}
-
-func openAIWSSessionHashesFromID(sessionID string) (string, string) {
- return deriveOpenAISessionHashes(sessionID)
-}
-
-func extractOpenAIWSImageURL(value any) string {
- switch v := value.(type) {
- case string:
- return strings.TrimSpace(v)
- case map[string]any:
- if raw, ok := v["url"].(string); ok {
- return strings.TrimSpace(raw)
- }
- }
- return ""
-}
-
-func summarizeOpenAIWSInput(input any) string {
- items, ok := input.([]any)
- if !ok || len(items) == 0 {
- return "-"
- }
-
- itemCount := len(items)
- textChars := 0
- imageDataURLs := 0
- imageDataURLChars := 0
- imageRemoteURLs := 0
-
- handleContentItem := func(contentItem map[string]any) {
- contentType, _ := contentItem["type"].(string)
- switch strings.TrimSpace(contentType) {
- case "input_text", "output_text", "text":
- if text, ok := contentItem["text"].(string); ok {
- textChars += len(text)
- }
- case "input_image":
- imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
- if imageURL == "" {
- return
- }
- if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
- imageDataURLs++
- imageDataURLChars += len(imageURL)
- return
- }
- imageRemoteURLs++
- }
- }
-
- handleInputItem := func(inputItem map[string]any) {
- if content, ok := inputItem["content"].([]any); ok {
- for _, rawContent := range content {
- contentItem, ok := rawContent.(map[string]any)
- if !ok {
- continue
- }
- handleContentItem(contentItem)
- }
- return
- }
-
- itemType, _ := inputItem["type"].(string)
- switch strings.TrimSpace(itemType) {
- case "input_text", "output_text", "text":
- if text, ok := inputItem["text"].(string); ok {
- textChars += len(text)
- }
- case "input_image":
- imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
- if imageURL == "" {
- return
- }
- if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
- imageDataURLs++
- imageDataURLChars += len(imageURL)
- return
- }
- imageRemoteURLs++
- }
- }
-
- for _, rawItem := range items {
- inputItem, ok := rawItem.(map[string]any)
- if !ok {
- continue
- }
- handleInputItem(inputItem)
- }
-
- return fmt.Sprintf(
- "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
- itemCount,
- textChars,
- imageDataURLs,
- imageDataURLChars,
- imageRemoteURLs,
- )
-}
-
-func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
- if len(payload) == 0 || strings.TrimSpace(key) == "" {
- return
- }
- if _, exists := payload[key]; !exists {
- return
- }
- delete(payload, key)
- *removed = append(*removed, key)
-}
-
-// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段,
-// 避免重试成功却改变原始请求语义。
-// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。
-func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
- if len(payload) == 0 {
- return "empty", nil
- }
- if attempt <= 1 {
- return "full", nil
- }
-
- removed := make([]string, 0, 2)
- if attempt >= 2 {
- dropOpenAIWSPayloadKey(payload, "include", &removed)
- }
-
- if len(removed) == 0 {
- return "full", nil
- }
- sort.Strings(removed)
- return "trim_optional_fields", removed
-}
-
-func logOpenAIWSModeInfo(format string, args ...any) {
- logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
-}
-
-func isOpenAIWSModeDebugEnabled() bool {
- return logger.L().Core().Enabled(zap.DebugLevel)
-}
-
-func logOpenAIWSModeDebug(format string, args ...any) {
- if !isOpenAIWSModeDebugEnabled() {
- return
- }
- logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
-}
-
-func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
- if err == nil {
- return
- }
- logger.L().Warn(
- "openai.ws_bind_response_account_failed",
- zap.Int64("group_id", groupID),
- zap.Int64("account_id", accountID),
- zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
- zap.Error(err),
- )
-}
-
-func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
- if err == nil {
- return "-", "-"
- }
- statusCode := coderws.CloseStatus(err)
- if statusCode == -1 {
- return "-", "-"
- }
- closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
- closeReason := "-"
- var closeErr coderws.CloseError
- if errors.As(err, &closeErr) {
- reasonText := strings.TrimSpace(closeErr.Reason)
- if reasonText != "" {
- closeReason = normalizeOpenAIWSLogValue(reasonText)
- }
- }
- return normalizeOpenAIWSLogValue(closeStatus), closeReason
-}
+type openAIWSIngressTurnAbortReason string
-func unwrapOpenAIWSDialBaseError(err error) error {
- if err == nil {
- return nil
- }
- var dialErr *openAIWSDialError
- if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
- return dialErr.Err
- }
- return err
-}
+const (
+ openAIWSIngressTurnAbortReasonUnknown openAIWSIngressTurnAbortReason = "unknown"
+
+ openAIWSIngressTurnAbortReasonClientClosed openAIWSIngressTurnAbortReason = "client_closed"
+ openAIWSIngressTurnAbortReasonContextCanceled openAIWSIngressTurnAbortReason = "ctx_canceled"
+ openAIWSIngressTurnAbortReasonContextDeadline openAIWSIngressTurnAbortReason = "ctx_deadline_exceeded"
+ openAIWSIngressTurnAbortReasonPreviousResponse openAIWSIngressTurnAbortReason = openAIWSIngressStagePreviousResponseNotFound
+ openAIWSIngressTurnAbortReasonToolOutput openAIWSIngressTurnAbortReason = openAIWSIngressStageToolOutputNotFound
+ openAIWSIngressTurnAbortReasonUpstreamError openAIWSIngressTurnAbortReason = "upstream_error_event"
+ openAIWSIngressTurnAbortReasonWriteUpstream openAIWSIngressTurnAbortReason = "write_upstream"
+ openAIWSIngressTurnAbortReasonReadUpstream openAIWSIngressTurnAbortReason = "read_upstream"
+ openAIWSIngressTurnAbortReasonWriteClient openAIWSIngressTurnAbortReason = "write_client"
+ openAIWSIngressTurnAbortReasonContinuationUnavailable openAIWSIngressTurnAbortReason = "continuation_unavailable"
+ openAIWSIngressTurnAbortReasonClientPreempted openAIWSIngressTurnAbortReason = "client_preempted"
+ openAIWSIngressTurnAbortReasonUpstreamRestart openAIWSIngressTurnAbortReason = "upstream_restart"
+)
-func openAIWSDialRespHeaderForLog(err error, key string) string {
- var dialErr *openAIWSDialError
- if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
- return "-"
- }
- return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
-}
+type openAIWSIngressTurnAbortDisposition string
-func classifyOpenAIWSDialError(err error) string {
- if err == nil {
- return "-"
- }
- baseErr := unwrapOpenAIWSDialBaseError(err)
- if baseErr == nil {
- return "-"
- }
- if errors.Is(baseErr, context.DeadlineExceeded) {
- return "ctx_deadline_exceeded"
- }
- if errors.Is(baseErr, context.Canceled) {
- return "ctx_canceled"
- }
- var netErr net.Error
- if errors.As(baseErr, &netErr) && netErr.Timeout() {
- return "net_timeout"
- }
- if status := coderws.CloseStatus(baseErr); status != -1 {
- return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
- }
- message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
- switch {
- case strings.Contains(message, "handshake not finished"):
- return "handshake_not_finished"
- case strings.Contains(message, "bad handshake"):
- return "bad_handshake"
- case strings.Contains(message, "connection refused"):
- return "connection_refused"
- case strings.Contains(message, "no such host"):
- return "dns_not_found"
- case strings.Contains(message, "tls"):
- return "tls_error"
- case strings.Contains(message, "i/o timeout"):
- return "io_timeout"
- case strings.Contains(message, "context deadline exceeded"):
- return "ctx_deadline_exceeded"
- default:
- return "dial_error"
- }
-}
+const (
+ openAIWSIngressTurnAbortDispositionFailRequest openAIWSIngressTurnAbortDisposition = "fail_request"
+ openAIWSIngressTurnAbortDispositionContinueTurn openAIWSIngressTurnAbortDisposition = "continue_turn"
+ openAIWSIngressTurnAbortDispositionCloseGracefully openAIWSIngressTurnAbortDisposition = "close_gracefully"
+)
-func summarizeOpenAIWSDialError(err error) (
- statusCode int,
- dialClass string,
- closeStatus string,
- closeReason string,
- respServer string,
- respVia string,
- respCFRay string,
- respRequestID string,
-) {
- dialClass = "-"
- closeStatus = "-"
- closeReason = "-"
- respServer = "-"
- respVia = "-"
- respCFRay = "-"
- respRequestID = "-"
- if err == nil {
- return
- }
- var dialErr *openAIWSDialError
- if errors.As(err, &dialErr) && dialErr != nil {
- statusCode = dialErr.StatusCode
- respServer = openAIWSDialRespHeaderForLog(err, "server")
- respVia = openAIWSDialRespHeaderForLog(err, "via")
- respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
- respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
- }
- dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
- closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
- return
+// openAIWSUpstreamPumpEvent 是上游事件读取泵传递给主 goroutine 的消息载体。
+type openAIWSUpstreamPumpEvent struct {
+ message []byte
+ err error
}
-func isOpenAIWSClientDisconnectError(err error) bool {
- if err == nil {
- return false
- }
- if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
- return true
- }
- switch coderws.CloseStatus(err) {
- case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
- return true
- }
- message := strings.ToLower(strings.TrimSpace(err.Error()))
- if message == "" {
- return false
- }
- return strings.Contains(message, "failed to read frame header: eof") ||
- strings.Contains(message, "unexpected eof") ||
- strings.Contains(message, "use of closed network connection") ||
- strings.Contains(message, "connection reset by peer") ||
- strings.Contains(message, "broken pipe")
-}
+const (
+ // openAIWSUpstreamPumpBufferSize 是上游事件读取泵的缓冲 channel 大小。
+ // 缓冲允许上游读取和客户端写入并发执行,吸收客户端写入延迟波动。
+ openAIWSUpstreamPumpBufferSize = 16
+)
-func classifyOpenAIWSReadFallbackReason(err error) string {
- if err == nil {
- return "read_event"
- }
- switch coderws.CloseStatus(err) {
- case coderws.StatusPolicyViolation:
- return "policy_violation"
- case coderws.StatusMessageTooBig:
- return "message_too_big"
- default:
- return "read_event"
- }
-}
+var openAIWSLogValueReplacer = strings.NewReplacer(
+ "error", "err",
+ "fallback", "fb",
+ "warning", "warnx",
+ "failed", "fail",
+)
-func sortedKeys(m map[string]any) []string {
- if len(m) == 0 {
- return nil
- }
- keys := make([]string, 0, len(m))
- for k := range m {
- keys = append(keys, k)
- }
- sort.Strings(keys)
- return keys
-}
+var openAIWSIngressPreflightPingIdle = 20 * time.Second
-func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool {
+func (s *OpenAIGatewayService) getOpenAIWSIngressContextPool() *openAIWSIngressContextPool {
if s == nil {
return nil
}
- s.openaiWSPoolOnce.Do(func() {
- if s.openaiWSPool == nil {
- s.openaiWSPool = newOpenAIWSConnPool(s.cfg)
+ s.openaiWSIngressCtxOnce.Do(func() {
+ if s.openaiWSIngressCtxPool == nil {
+ pool := newOpenAIWSIngressContextPool(s.cfg)
+ // Ensure the scheduler (and its runtime stats) are initialized
+ // before wiring load-aware signals into the context pool.
+ _ = s.getOpenAIAccountScheduler()
+ pool.schedulerStats = s.openaiAccountStats
+ s.openaiWSIngressCtxPool = pool
}
})
- return s.openaiWSPool
-}
-
-func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot {
- pool := s.getOpenAIWSConnPool()
- if pool == nil {
- return OpenAIWSPoolMetricsSnapshot{}
- }
- return pool.SnapshotMetrics()
+ return s.openaiWSIngressCtxPool
}
type OpenAIWSPerformanceMetricsSnapshot struct {
- Pool OpenAIWSPoolMetricsSnapshot `json:"pool"`
Retry OpenAIWSRetryMetricsSnapshot `json:"retry"`
+ Abort OpenAIWSAbortMetricsSnapshot `json:"abort"`
Transport OpenAIWSTransportMetricsSnapshot `json:"transport"`
}
func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot {
- pool := s.getOpenAIWSConnPool()
+ ingressPool := s.getOpenAIWSIngressContextPool()
snapshot := OpenAIWSPerformanceMetricsSnapshot{
Retry: s.SnapshotOpenAIWSRetryMetrics(),
+ Abort: s.SnapshotOpenAIWSAbortMetrics(),
}
- if pool == nil {
+ if ingressPool == nil {
return snapshot
}
- snapshot.Pool = pool.SnapshotMetrics()
- snapshot.Transport = pool.SnapshotTransportMetrics()
+ snapshot.Transport = ingressPool.SnapshotTransportMetrics()
return snapshot
}
@@ -967,6 +187,13 @@ func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration {
return 15 * time.Minute
}
+func (s *OpenAIGatewayService) openAIWSClientReadIdleTimeout() time.Duration {
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds > 0 {
+ return time.Duration(s.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds) * time.Second
+ }
+ return openAIWSClientReadIdleTimeoutDefault
+}
+
func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration {
if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 {
return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second
@@ -1150,6 +377,10 @@ func (s *OpenAIGatewayService) buildOpenAIWSHeaders(
if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) {
headers.Set("user-agent", codexCLIUserAgent)
}
+ if account != nil && account.Type == AccountTypeOAuth && openai.IsCodexCLIRequest(headers.Get("user-agent")) {
+ // 保持 OAuth 握手头的一致性:Codex 风格 UA 必须搭配 codex_cli_rs originator。
+ headers.Set("originator", "codex_cli_rs")
+ }
return headers, sessionResolution
}
@@ -1179,471 +410,108 @@ func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) {
if len(payload) == 0 {
return
}
- metadata := strings.TrimSpace(turnMetadata)
- if metadata == "" {
- return
- }
-
- switch existing := payload["client_metadata"].(type) {
- case map[string]any:
- existing[openAIWSTurnMetadataHeader] = metadata
- payload["client_metadata"] = existing
- case map[string]string:
- next := make(map[string]any, len(existing)+1)
- for k, v := range existing {
- next[k] = v
- }
- next[openAIWSTurnMetadataHeader] = metadata
- payload["client_metadata"] = next
- default:
- payload["client_metadata"] = map[string]any{
- openAIWSTurnMetadataHeader: metadata,
- }
- }
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
- if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
- return true
- }
- if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
- return true
- }
- return false
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
- if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
- return true
- }
- if len(reqBody) == 0 {
- return false
- }
- rawStore, ok := reqBody["store"]
- if !ok {
- return false
- }
- storeEnabled, ok := rawStore.(bool)
- if !ok {
- return false
- }
- return !storeEnabled
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
- if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
- return true
- }
- if len(reqBody) == 0 {
- return false
- }
- storeValue := gjson.GetBytes(reqBody, "store")
- if !storeValue.Exists() {
- return false
- }
- if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
- return false
- }
- return !storeValue.Bool()
-}
-
-func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
- if s == nil || s.cfg == nil {
- return openAIWSStoreDisabledConnModeStrict
- }
- mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
- switch mode {
- case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
- return mode
- case "":
- // 兼容旧配置:仅配置了布尔开关时按旧语义推导。
- if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
- return openAIWSStoreDisabledConnModeStrict
- }
- return openAIWSStoreDisabledConnModeOff
- default:
- return openAIWSStoreDisabledConnModeStrict
- }
-}
-
-func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
- switch mode {
- case openAIWSStoreDisabledConnModeOff:
- return false
- case openAIWSStoreDisabledConnModeAdaptive:
- reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
- switch reason {
- case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
- return true
- default:
- return false
- }
- default:
- return true
- }
-}
-
-func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
- return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
-}
-
-func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
- payload []byte,
- deleteFn func([]byte, string) ([]byte, error),
-) ([]byte, bool, error) {
- if len(payload) == 0 {
- return payload, false, nil
- }
- if !gjson.GetBytes(payload, "previous_response_id").Exists() {
- return payload, false, nil
- }
- if deleteFn == nil {
- deleteFn = sjson.DeleteBytes
- }
-
- updated := payload
- for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
- gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
- next, err := deleteFn(updated, "previous_response_id")
- if err != nil {
- return payload, false, err
- }
- updated = next
- }
- return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
-}
-
-func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
- normalizedPrevID := strings.TrimSpace(previousResponseID)
- if len(payload) == 0 || normalizedPrevID == "" {
- return payload, nil
- }
- updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
- if err == nil {
- return updated, nil
- }
-
- var reqBody map[string]any
- if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
- return nil, err
- }
- reqBody["previous_response_id"] = normalizedPrevID
- rebuilt, marshalErr := json.Marshal(reqBody)
- if marshalErr != nil {
- return nil, marshalErr
- }
- return rebuilt, nil
-}
-
-func shouldInferIngressFunctionCallOutputPreviousResponseID(
- storeDisabled bool,
- turn int,
- hasFunctionCallOutput bool,
- currentPreviousResponseID string,
- expectedPreviousResponseID string,
-) bool {
- if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
- return false
- }
- if strings.TrimSpace(currentPreviousResponseID) != "" {
- return false
- }
- return strings.TrimSpace(expectedPreviousResponseID) != ""
-}
-
-func alignStoreDisabledPreviousResponseID(
- payload []byte,
- expectedPreviousResponseID string,
-) ([]byte, bool, error) {
- if len(payload) == 0 {
- return payload, false, nil
- }
- expected := strings.TrimSpace(expectedPreviousResponseID)
- if expected == "" {
- return payload, false, nil
- }
- current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
- if current == "" || current == expected {
- return payload, false, nil
- }
-
- withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
- if dropErr != nil {
- return payload, false, dropErr
- }
- if !removed {
- return payload, false, nil
- }
- updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
- if setErr != nil {
- return payload, false, setErr
- }
- return updated, true, nil
-}
-
-func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
- if len(payload) == 0 {
- return nil
- }
- cloned := make([]byte, len(payload))
- copy(cloned, payload)
- return cloned
-}
-
-func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
- if items == nil {
- return nil
- }
- cloned := make([]json.RawMessage, 0, len(items))
- for idx := range items {
- cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
- }
- return cloned
-}
-
-func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
- trimmed := bytes.TrimSpace(raw)
- if len(trimmed) == 0 {
- return nil, errors.New("json is empty")
- }
- var decoded any
- if err := json.Unmarshal(trimmed, &decoded); err != nil {
- return nil, err
- }
- return json.Marshal(decoded)
-}
-
-func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
- normalized, err := normalizeOpenAIWSJSONForCompare(raw)
- if err != nil {
- return bytes.TrimSpace(raw)
- }
- return normalized
-}
-
-func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
- if len(payload) == 0 {
- return nil, errors.New("payload is empty")
- }
- var decoded map[string]any
- if err := json.Unmarshal(payload, &decoded); err != nil {
- return nil, err
- }
- delete(decoded, "input")
- delete(decoded, "previous_response_id")
- return json.Marshal(decoded)
-}
-
-func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
- if len(payload) == 0 {
- return nil, false, nil
- }
- inputValue := gjson.GetBytes(payload, "input")
- if !inputValue.Exists() {
- return nil, false, nil
- }
- if inputValue.Type == gjson.JSON {
- raw := strings.TrimSpace(inputValue.Raw)
- if strings.HasPrefix(raw, "[") {
- var items []json.RawMessage
- if err := json.Unmarshal([]byte(raw), &items); err != nil {
- return nil, true, err
- }
- return items, true, nil
- }
- return []json.RawMessage{json.RawMessage(raw)}, true, nil
- }
- if inputValue.Type == gjson.String {
- encoded, _ := json.Marshal(inputValue.String())
- return []json.RawMessage{encoded}, true, nil
- }
- return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
-}
-
-func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
- previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
- if prevErr != nil {
- return false, prevErr
- }
- currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
- if currentErr != nil {
- return false, currentErr
- }
- if !previousExists && !currentExists {
- return true, nil
- }
- if !previousExists {
- return len(currentItems) == 0, nil
- }
- if !currentExists {
- return len(previousItems) == 0, nil
- }
- if len(currentItems) < len(previousItems) {
- return false, nil
+ metadata := strings.TrimSpace(turnMetadata)
+ if metadata == "" {
+ return
}
- for idx := range previousItems {
- previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
- currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
- if !bytes.Equal(previousNormalized, currentNormalized) {
- return false, nil
+ switch existing := payload["client_metadata"].(type) {
+ case map[string]any:
+ existing[openAIWSTurnMetadataHeader] = metadata
+ payload["client_metadata"] = existing
+ case map[string]string:
+ next := make(map[string]any, len(existing)+1)
+ for k, v := range existing {
+ next[k] = v
+ }
+ next[openAIWSTurnMetadataHeader] = metadata
+ payload["client_metadata"] = next
+ default:
+ payload["client_metadata"] = map[string]any{
+ openAIWSTurnMetadataHeader: metadata,
}
}
- return true, nil
}
-func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
- if len(prefix) == 0 {
+func (s *OpenAIGatewayService) isOpenAIWSStoreRecoveryAllowed(account *Account) bool {
+ if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() {
return true
}
- if len(items) < len(prefix) {
- return false
- }
- for idx := range prefix {
- previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
- currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
- if !bytes.Equal(previousNormalized, currentNormalized) {
- return false
- }
+ if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery {
+ return true
}
- return true
+ return false
}
-func buildOpenAIWSReplayInputSequence(
- previousFullInput []json.RawMessage,
- previousFullInputExists bool,
- currentPayload []byte,
- hasPreviousResponseID bool,
-) ([]json.RawMessage, bool, error) {
- currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
- if currentErr != nil {
- return nil, false, currentErr
- }
- if !hasPreviousResponseID {
- return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
- }
- if !previousFullInputExists {
- return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
- }
- if !currentExists || len(currentItems) == 0 {
- return cloneOpenAIWSRawMessages(previousFullInput), true, nil
- }
- if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
- return cloneOpenAIWSRawMessages(currentItems), true, nil
+func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool {
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ return true
}
- merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
- merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
- merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
- return merged, true, nil
-}
-
-func setOpenAIWSPayloadInputSequence(
- payload []byte,
- fullInput []json.RawMessage,
- fullInputExists bool,
-) ([]byte, error) {
- if !fullInputExists {
- return payload, nil
+ if len(reqBody) == 0 {
+ return false
}
- // Preserve [] vs null semantics when input exists but is empty.
- inputForMarshal := fullInput
- if inputForMarshal == nil {
- inputForMarshal = []json.RawMessage{}
+ rawStore, ok := reqBody["store"]
+ if !ok {
+ return false
}
- inputRaw, marshalErr := json.Marshal(inputForMarshal)
- if marshalErr != nil {
- return nil, marshalErr
+ storeEnabled, ok := rawStore.(bool)
+ if !ok {
+ return false
}
- return sjson.SetRawBytes(payload, "input", inputRaw)
+ return !storeEnabled
}
-func shouldKeepIngressPreviousResponseID(
- previousPayload []byte,
- currentPayload []byte,
- lastTurnResponseID string,
- hasFunctionCallOutput bool,
-) (bool, string, error) {
- if hasFunctionCallOutput {
- return true, "has_function_call_output", nil
- }
- currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
- if currentPreviousResponseID == "" {
- return false, "missing_previous_response_id", nil
- }
- expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
- if expectedPreviousResponseID == "" {
- return false, "missing_last_turn_response_id", nil
- }
- if currentPreviousResponseID != expectedPreviousResponseID {
- return false, "previous_response_id_mismatch", nil
- }
- if len(previousPayload) == 0 {
- return false, "missing_previous_turn_payload", nil
+func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool {
+ if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) {
+ return true
}
-
- previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
- if previousComparableErr != nil {
- return false, "non_input_compare_error", previousComparableErr
+ if len(reqBody) == 0 {
+ return false
}
- currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
- if currentComparableErr != nil {
- return false, "non_input_compare_error", currentComparableErr
+ storeValue := gjson.GetBytes(reqBody, "store")
+ if !storeValue.Exists() {
+ return false
}
- if !bytes.Equal(previousComparable, currentComparable) {
- return false, "non_input_changed", nil
+ if storeValue.Type != gjson.True && storeValue.Type != gjson.False {
+ return false
}
- return true, "strict_incremental_ok", nil
-}
-
-type openAIWSIngressPreviousTurnStrictState struct {
- nonInputComparable []byte
+ return !storeValue.Bool()
}
-func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
- if len(payload) == 0 {
- return nil, nil
+func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string {
+ if s == nil || s.cfg == nil {
+ return openAIWSStoreDisabledConnModeStrict
}
- nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
- if nonInputErr != nil {
- return nil, nonInputErr
+ mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode))
+ switch mode {
+ case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff:
+ return mode
+ case "":
+ // 兼容旧配置:仅配置了布尔开关时按旧语义推导。
+ if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn {
+ return openAIWSStoreDisabledConnModeStrict
+ }
+ return openAIWSStoreDisabledConnModeOff
+ default:
+ return openAIWSStoreDisabledConnModeStrict
}
- return &openAIWSIngressPreviousTurnStrictState{
- nonInputComparable: nonInputComparable,
- }, nil
}
-func shouldKeepIngressPreviousResponseIDWithStrictState(
- previousState *openAIWSIngressPreviousTurnStrictState,
- currentPayload []byte,
- lastTurnResponseID string,
- hasFunctionCallOutput bool,
-) (bool, string, error) {
- if hasFunctionCallOutput {
- return true, "has_function_call_output", nil
- }
- currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
- if currentPreviousResponseID == "" {
- return false, "missing_previous_response_id", nil
- }
- expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
- if expectedPreviousResponseID == "" {
- return false, "missing_last_turn_response_id", nil
- }
- if currentPreviousResponseID != expectedPreviousResponseID {
- return false, "previous_response_id_mismatch", nil
- }
- if previousState == nil {
- return false, "missing_previous_turn_payload", nil
- }
-
- currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
- if currentComparableErr != nil {
- return false, "non_input_compare_error", currentComparableErr
- }
- if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
- return false, "non_input_changed", nil
+func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool {
+ switch mode {
+ case openAIWSStoreDisabledConnModeOff:
+ return false
+ case openAIWSStoreDisabledConnModeAdaptive:
+ reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
+ switch reason {
+ case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
+ return true
+ default:
+ return false
+ }
+ default:
+ return true
}
- return true, "strict_incremental_ok", nil
}
func (s *OpenAIGatewayService) forwardOpenAIWSV2(
@@ -1660,7 +528,20 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
startTime time.Time,
attempt int,
lastFailureReason string,
-) (*OpenAIForwardResult, error) {
+) (result *OpenAIForwardResult, err error) {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ logger.LegacyPrintf(
+ "service.openai_ws_forwarder",
+ "[OpenAIWS] recovered panic in forwardOpenAIWSV2: panic=%v stack=%s",
+ recovered,
+ string(debug.Stack()),
+ )
+ err = fmt.Errorf("openai ws panic recovered: %v", recovered)
+ result = nil
+ }
+ }()
+
if s == nil || account == nil {
return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil"))
}
@@ -1738,7 +619,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
stateStore := s.getOpenAIWSStateStore()
groupID := getOpenAIGroupIDFromContext(c)
- sessionHash := s.GenerateSessionHash(c, nil)
+ sessionHash := s.GenerateSessionHashWithFallback(c, nil, openAIWSIngressFallbackSessionSeedFromContext(c))
if sessionHash == "" {
var legacySessionHash string
sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey)
@@ -1751,9 +632,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
preferredConnID := ""
if stateStore != nil && previousResponseID != "" {
- if connID, ok := stateStore.GetResponseConn(previousResponseID); ok {
- preferredConnID = connID
- }
+ preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, previousResponseID)
}
storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account)
if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" {
@@ -1800,18 +679,35 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout())
defer acquireCancel()
- lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{
- Account: account,
- WSURL: wsURL,
- Headers: wsHeaders,
- PreferredConnID: preferredConnID,
- ForceNewConn: forceNewConn,
+ ingressCtxPool := s.getOpenAIWSIngressContextPool()
+ if ingressCtxPool == nil {
+ return nil, wrapOpenAIWSFallback("ctx_pool_unavailable", errors.New("openai ws ingress context pool is nil"))
+ }
+ sessionHashForCtx := strings.TrimSpace(sessionHash)
+ if sessionHashForCtx == "" {
+ sessionHashForCtx = fmt.Sprintf("httpws:%d:%d", account.ID, startTime.UnixNano())
+ }
+ if forceNewConn {
+ sessionHashForCtx = fmt.Sprintf("%s:retry:%d", sessionHashForCtx, attempt)
+ }
+ ownerID := fmt.Sprintf("httpws_%d_%d", account.ID, attempt)
+ lease, err := ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: groupID,
+ SessionHash: sessionHashForCtx,
+ OwnerID: ownerID,
+ WSURL: wsURL,
+ Headers: cloneHeader(wsHeaders),
ProxyURL: func() string {
if account.ProxyID != nil && account.Proxy != nil {
return account.Proxy.URL()
}
return ""
}(),
+ Turn: 1,
+ HasPreviousResponseID: previousResponseID != "",
+ StrictAffinity: previousResponseID != "",
+ StoreDisabled: storeDisabled,
})
if err != nil {
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err)
@@ -1971,6 +867,11 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
}
clientDisconnected := false
+ var downstreamWriteErr error
+ var requestCtx context.Context
+ if c != nil && c.Request != nil {
+ requestCtx = c.Request.Context()
+ }
flushBatchSize := s.openAIWSEventFlushBatchSize()
flushInterval := s.openAIWSEventFlushInterval()
pendingFlushEvents := 0
@@ -1988,23 +889,41 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
pendingFlushEvents = 0
lastFlushAt = time.Now()
}
+ var sseFrameBuf []byte
emitStreamMessage := func(message []byte, forceFlush bool) {
- if clientDisconnected {
+ if clientDisconnected || downstreamWriteErr != nil {
return
}
- frame := make([]byte, 0, len(message)+8)
- frame = append(frame, "data: "...)
- frame = append(frame, message...)
- frame = append(frame, '\n', '\n')
- _, wErr := c.Writer.Write(frame)
+ sseFrameBuf = sseFrameBuf[:0]
+ sseFrameBuf = append(sseFrameBuf, "data: "...)
+ sseFrameBuf = append(sseFrameBuf, message...)
+ sseFrameBuf = append(sseFrameBuf, '\n', '\n')
+ _, wErr := c.Writer.Write(sseFrameBuf)
if wErr == nil {
wroteDownstream = true
pendingFlushEvents++
flushStreamWriter(forceFlush)
return
}
- clientDisconnected = true
- logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID)
+ if isOpenAIWSStreamWriteDisconnectError(wErr, requestCtx) {
+ clientDisconnected = true
+ logger.LegacyPrintf(
+ "service.openai_gateway",
+ "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d conn_id=%s",
+ account.ID,
+ connID,
+ )
+ return
+ }
+ downstreamWriteErr = wErr
+ setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(wErr.Error()), "")
+ logOpenAIWSModeInfo(
+ "stream_write_fail account_id=%d conn_id=%s wrote_downstream=%v cause=%s",
+ account.ID,
+ connID,
+ wroteDownstream,
+ truncateOpenAIWSLogValue(wErr.Error(), openAIWSLogValueMaxLen),
+ )
}
flushBufferedStreamEvents := func(reason string) {
if len(bufferedStreamEvents) == 0 {
@@ -2013,6 +932,9 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
flushed := len(bufferedStreamEvents)
for _, buffered := range bufferedStreamEvents {
emitStreamMessage(buffered, false)
+ if downstreamWriteErr != nil {
+ break
+ }
}
bufferedStreamEvents = bufferedStreamEvents[:0]
flushStreamWriter(true)
@@ -2170,8 +1092,16 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw)
setOpsUpstreamError(c, statusCode, errMsg, "")
if reqStream && !clientDisconnected {
- flushBufferedStreamEvents("error_event")
+ if shouldFlushOpenAIWSBufferedEventsOnError(reqStream, wroteDownstream, clientDisconnected) {
+ flushBufferedStreamEvents("error_event")
+ } else {
+ bufferedStreamEvents = bufferedStreamEvents[:0]
+ }
emitStreamMessage(message, true)
+ if downstreamWriteErr != nil {
+ lease.MarkBroken()
+ return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr)
+ }
}
if !reqStream {
c.JSON(statusCode, gin.H{
@@ -2207,10 +1137,14 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
} else {
flushBufferedStreamEvents(eventType)
emitStreamMessage(message, isTerminalEvent)
+ if downstreamWriteErr != nil {
+ lease.MarkBroken()
+ return nil, fmt.Errorf("openai ws stream write: %w", downstreamWriteErr)
+ }
}
} else {
if responseField.Exists() && responseField.Type == gjson.JSON {
- finalResponse = []byte(responseField.Raw)
+ finalResponse = cloneOpenAIWSJSONRawString(responseField.Raw)
}
}
@@ -2253,7 +1187,14 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if responseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
- stateStore.BindResponseConn(responseID, lease.ConnID(), ttl)
+ if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok {
+ stateStore.BindResponseConn(responseID, connID, ttl)
+ }
+ if sessionHash != "" && shouldPersistOpenAIWSLastResponseID(lastEventType) {
+ stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL())
+ } else if sessionHash != "" {
+ stateStore.DeleteSessionLastResponseID(groupID, sessionHash)
+ }
}
if stateStore != nil && storeDisabled && sessionHash != "" {
stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL())
@@ -2282,14 +1223,15 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
)
return &OpenAIForwardResult{
- RequestID: responseID,
- Usage: *usage,
- Model: originalModel,
- ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
- Stream: reqStream,
- OpenAIWSMode: true,
- Duration: time.Since(startTime),
- FirstTokenMs: firstTokenMs,
+ RequestID: responseID,
+ Usage: *usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ Duration: time.Since(startTime),
+ FirstTokenMs: firstTokenMs,
+ TerminalEventType: strings.TrimSpace(lastEventType),
}, nil
}
@@ -2303,7 +1245,24 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
token string,
firstClientMessage []byte,
hooks *OpenAIWSIngressHooks,
-) error {
+) (err error) {
+ defer func() {
+ if recovered := recover(); recovered != nil {
+ const panicCloseReason = "internal websocket proxy panic"
+ logger.LegacyPrintf(
+ "service.openai_ws_forwarder",
+ "[OpenAIWS] recovered panic in ProxyResponsesWebSocketFromClient: panic=%v stack=%s",
+ recovered,
+ string(debug.Stack()),
+ )
+ err = NewOpenAIWSClientCloseError(
+ coderws.StatusInternalError,
+ panicCloseReason,
+ fmt.Errorf("panic recovered: %v", recovered),
+ )
+ }
+ }()
+
if s == nil {
return errors.New("service is nil")
}
@@ -2322,22 +1281,43 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account)
modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
- ingressMode := OpenAIWSIngressModeShared
- if modeRouterV2Enabled {
- ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
- if ingressMode == OpenAIWSIngressModeOff {
- return NewOpenAIWSClientCloseError(
- coderws.StatusPolicyViolation,
- "websocket mode is disabled for this account",
- nil,
- )
- }
+ if !modeRouterV2Enabled {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode requires mode_router_v2 with ctx_pool",
+ nil,
+ )
+ }
+ ingressMode := account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault)
+ logOpenAIWSModeInfo(
+ "ingress_ws_validate account_id=%d ingress_mode=%s transport=%s",
+ account.ID,
+ normalizeOpenAIWSLogValue(string(ingressMode)),
+ normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
+ )
+ if ingressMode == OpenAIWSIngressModeOff {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode is disabled for this account",
+ nil,
+ )
+ }
+ if ingressMode != OpenAIWSIngressModeCtxPool {
+ return NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "websocket mode only supports ctx_pool",
+ nil,
+ )
+ }
+ // Ingress ws_v2 请求天然是 Codex 会话语义,ctx_pool 是否启用仅由账号 mode 决定。
+ ctxPoolMode := ingressMode == OpenAIWSIngressModeCtxPool
+ ctxPoolSessionScope := ""
+ if ctxPoolMode {
+ ctxPoolSessionScope = openAIWSIngressSessionScopeFromContext(c)
}
if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 {
return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport)
}
- dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated
-
wsURL, err := s.buildOpenAIResponsesWSURL(account)
if err != nil {
return fmt.Errorf("build ws url: %w", err)
@@ -2349,6 +1329,15 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsPath = normalizeOpenAIWSLogValue(parsedURL.Path)
}
debugEnabled := isOpenAIWSModeDebugEnabled()
+ logOpenAIWSModeInfo(
+ "ingress_ws_session_init account_id=%d ws_host=%s ws_path=%s ctx_pool=%v session_scope=%s debug=%v",
+ account.ID,
+ wsHost,
+ wsPath,
+ ctxPoolMode,
+ truncateOpenAIWSLogValue(ctxPoolSessionScope, openAIWSIDValueMaxLen),
+ debugEnabled,
+ )
type openAIWSClientPayload struct {
payloadRaw []byte
@@ -2475,32 +1464,64 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader))
stateStore := s.getOpenAIWSStateStore()
groupID := getOpenAIGroupIDFromContext(c)
- sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash)
- if turnState == "" && stateStore != nil && sessionHash != "" {
+ fallbackSessionSeed := openAIWSIngressFallbackSessionSeedFromContext(c)
+ legacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, firstPayload.rawForHash, fallbackSessionSeed))
+ sessionHash := legacySessionHash
+ if ctxPoolMode {
+ sessionHash = openAIWSApplySessionScope(legacySessionHash, ctxPoolSessionScope)
+ }
+ resolveSessionTurnState := func() (string, bool) {
+ if stateStore == nil || sessionHash == "" {
+ return "", false
+ }
if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok {
+ return savedTurnState, true
+ }
+ if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash {
+ return "", false
+ }
+ return stateStore.GetSessionTurnState(groupID, legacySessionHash)
+ }
+ resolveSessionLastResponseID := func() (string, bool) {
+ if stateStore == nil || sessionHash == "" {
+ return "", false
+ }
+ if savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, sessionHash); ok {
+ return strings.TrimSpace(savedResponseID), true
+ }
+ if !ctxPoolMode || legacySessionHash == "" || legacySessionHash == sessionHash {
+ return "", false
+ }
+ savedResponseID, ok := stateStore.GetSessionLastResponseID(groupID, legacySessionHash)
+ return strings.TrimSpace(savedResponseID), ok
+ }
+ if turnState == "" && stateStore != nil && sessionHash != "" {
+ if savedTurnState, ok := resolveSessionTurnState(); ok {
turnState = savedTurnState
}
}
+ sessionLastResponseID := ""
+ if stateStore != nil && sessionHash != "" {
+ if savedResponseID, ok := resolveSessionLastResponseID(); ok {
+ sessionLastResponseID = savedResponseID
+ }
+ }
preferredConnID := ""
if stateStore != nil && firstPayload.previousResponseID != "" {
- if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok {
- preferredConnID = connID
- }
+ preferredConnID = openAIWSPreferredConnIDFromResponse(stateStore, firstPayload.previousResponseID)
}
storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account)
storeDisabledConnMode := s.openAIWSStoreDisabledConnMode()
- if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" {
- if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok {
- preferredConnID = connID
- }
- }
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey)
- baseAcquireReq := openAIWSAcquireRequest{
- Account: account,
+ baseAcquireReq := struct {
+ WSURL string
+ Headers http.Header
+ ProxyURL string
+ }{
WSURL: wsURL,
Headers: wsHeaders,
ProxyURL: func() string {
@@ -2509,21 +1530,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
return ""
}(),
- ForceNewConn: false,
}
- pool := s.getOpenAIWSConnPool()
- if pool == nil {
- return errors.New("openai ws conn pool is nil")
+
+ ingressCtxPool := s.getOpenAIWSIngressContextPool()
+ if ingressCtxPool == nil {
+ return errors.New("openai ws ingress context pool is nil")
}
logOpenAIWSModeInfo(
- "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
+ "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_session_hash=%v has_previous_response_id=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
wsHost,
wsPath,
normalizeOpenAIWSLogValue(ingressMode),
+ ctxPoolMode,
storeDisabled,
sessionHash != "",
firstPayload.previousResponseID != "",
@@ -2531,7 +1553,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if debugEnabled {
logOpenAIWSModeDebug(
- "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v",
+ "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v ctx_pool_mode=%v",
account.ID,
account.Type,
normalizeOpenAIWSLogValue(string(wsDecision.Transport)),
@@ -2540,6 +1562,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
sessionHash != "",
firstPayload.previousResponseID != "",
storeDisabled,
+ ctxPoolMode,
)
}
if firstPayload.previousResponseID != "" {
@@ -2566,15 +1589,37 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
acquireTimeout = 30 * time.Second
}
- acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) {
- req := cloneOpenAIWSAcquireRequest(baseAcquireReq)
- req.PreferredConnID = strings.TrimSpace(preferred)
- req.ForcePreferredConn = forcePreferredConn
- // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。
- req.ForceNewConn = dedicatedMode
+ ownerID := fmt.Sprintf("cliws_%p", clientConn)
+ acquireTurnLease := func(
+ turn int,
+ preferred string,
+ forcePreferredConn bool,
+ hasPreviousResponseID bool,
+ ) (openAIWSIngressUpstreamLease, error) {
acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout)
- lease, acquireErr := pool.Acquire(acquireCtx, req)
- acquireCancel()
+ defer acquireCancel()
+
+ var (
+ lease openAIWSIngressUpstreamLease
+ acquireErr error
+ )
+ sessionHashForCtx := strings.TrimSpace(sessionHash)
+ if sessionHashForCtx == "" {
+ sessionHashForCtx = fmt.Sprintf("conn:%s", ownerID)
+ }
+ lease, acquireErr = ingressCtxPool.Acquire(acquireCtx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: groupID,
+ SessionHash: sessionHashForCtx,
+ OwnerID: ownerID,
+ WSURL: baseAcquireReq.WSURL,
+ Headers: cloneHeader(baseAcquireReq.Headers),
+ ProxyURL: baseAcquireReq.ProxyURL,
+ Turn: turn,
+ HasPreviousResponseID: hasPreviousResponseID,
+ StrictAffinity: forcePreferredConn,
+ StoreDisabled: storeDisabled,
+ })
if acquireErr != nil {
dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr)
logOpenAIWSModeInfo(
@@ -2597,14 +1642,9 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
- if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
- return nil, NewOpenAIWSClientCloseError(
- coderws.StatusPolicyViolation,
- "upstream continuation connection is unavailable; please restart the conversation",
- acquireErr,
- )
- }
- if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) {
+ if errors.Is(acquireErr, context.DeadlineExceeded) ||
+ errors.Is(acquireErr, errOpenAIWSConnQueueFull) ||
+ errors.Is(acquireErr, errOpenAIWSIngressContextBusy) {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusTryAgainLater,
"upstream websocket is busy, please retry later",
@@ -2627,7 +1667,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
baseAcquireReq.Headers = updatedHeaders
}
logOpenAIWSModeInfo(
- "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s",
+ "ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s ctx_pool_mode=%v schedule_layer=%s stickiness_level=%s migration_used=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
@@ -2635,6 +1675,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
lease.ConnPickDuration().Milliseconds(),
lease.QueueWaitDuration().Milliseconds(),
truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen),
+ ctxPoolMode,
+ normalizeOpenAIWSLogValue(lease.ScheduleLayer()),
+ normalizeOpenAIWSLogValue(lease.StickinessLevel()),
+ lease.MigrationUsed(),
)
return lease, nil
}
@@ -2646,8 +1690,21 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
readClientMessage := func() ([]byte, error) {
- msgType, payload, readErr := clientConn.Read(ctx)
+ readCtx := ctx
+ if idleTimeout := s.openAIWSClientReadIdleTimeout(); idleTimeout > 0 {
+ var cancel context.CancelFunc
+ readCtx, cancel = context.WithTimeout(ctx, idleTimeout)
+ defer cancel()
+ }
+ msgType, payload, readErr := clientConn.Read(readCtx)
if readErr != nil {
+ if readCtx != nil && readCtx.Err() == context.DeadlineExceeded && (ctx == nil || ctx.Err() == nil) {
+ return nil, NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ "client websocket idle timeout",
+ readErr,
+ )
+ }
return nil, readErr
}
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
@@ -2660,7 +1717,32 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return payload, nil
}
- sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) {
+ // 持久客户端读取 goroutine:将客户端消息推送到 channel,
+ // 使 sendAndRelay 可以通过 select 同时监听上游事件和客户端新请求。
+ var nextClientPreemptedPayload []byte
+ var pendingClientReadErr error
+ clientMsgCh := make(chan []byte, 1)
+ clientReadErrCh := make(chan error, 1)
+ go func() {
+ defer close(clientMsgCh)
+ for {
+ msg, err := readClientMessage()
+ if err != nil {
+ select {
+ case clientReadErrCh <- err:
+ case <-ctx.Done():
+ }
+ return
+ }
+ select {
+ case clientMsgCh <- msg:
+ case <-ctx.Done():
+ return
+ }
+ }
+ }()
+
+ sendAndRelay := func(turn int, lease openAIWSIngressUpstreamLease, payload []byte, payloadBytes int, originalModel string, expectedPreviousResponseID string) (*OpenAIForwardResult, error) {
if lease == nil {
return nil, errors.New("upstream websocket lease is nil")
}
@@ -2689,9 +1771,12 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true)
turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID)
+ turnExpectedPreviousResponseID := strings.TrimSpace(expectedPreviousResponseID)
turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key")
turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
- turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+ turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0
+ turnPendingFunctionCallIDSet := make(map[string]struct{}, 4)
eventCount := 0
tokenEventCount := 0
terminalEventCount := 0
@@ -2699,8 +1784,30 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
lastEventType := ""
needModelReplace := false
clientDisconnected := false
+ clientDisconnectDrainDeadline := time.Time{}
+ terminateOnErrorEvent := false
+ terminateOnErrorMessage := ""
mappedModel := ""
var mappedModelBytes []byte
+ buildPartialResult := func(terminalEventType string) *OpenAIForwardResult {
+ if usage.InputTokens <= 0 &&
+ usage.OutputTokens <= 0 &&
+ usage.CacheCreationInputTokens <= 0 &&
+ usage.CacheReadInputTokens <= 0 {
+ return nil
+ }
+ return &OpenAIForwardResult{
+ RequestID: responseID,
+ Usage: usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ Duration: time.Since(turnStart),
+ FirstTokenMs: firstTokenMs,
+ TerminalEventType: strings.TrimSpace(terminalEventType),
+ }
+ }
if originalModel != "" {
mappedModel = account.GetMappedModel(originalModel)
if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" {
@@ -2711,18 +1818,190 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
mappedModelBytes = []byte(mappedModel)
}
}
+ // 启动上游事件读取泵:解耦上游读取和客户端写入,允许二者并发执行。
+ // 读取 goroutine 将上游事件推送到缓冲 channel,主 goroutine 从 channel 消费并处理/转发。
+ // 缓冲 channel 允许上游在客户端写入阻塞时继续读取后续事件,降低端到端延迟。
+ pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize)
+ pumpCtx, pumpCancel := context.WithCancel(ctx)
+ defer pumpCancel()
+ pumpStartedAt := time.Now()
+ go func() {
+ defer func() {
+ close(pumpEventCh)
+ if pumpCtx.Err() == nil {
+ return
+ }
+ pumpAlive := time.Since(pumpStartedAt)
+ if pumpAlive >= openAIWSUpstreamPumpInfoMinAlive {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ pumpAlive.Milliseconds(),
+ )
+ return
+ }
+ logOpenAIWSModeDebug(
+ "ingress_ws_upstream_pump_exit account_id=%d turn=%d conn_id=%s reason=context_cancelled pump_alive_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ pumpAlive.Milliseconds(),
+ )
+ }()
+ for {
+ msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, s.openAIWSReadTimeout())
+ select {
+ case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}:
+ case <-pumpCtx.Done():
+ return
+ }
+ if readErr != nil {
+ return
+ }
+ // 检测终端/错误事件以终止读取泵。
+ evtType, _ := parseOpenAIWSEventType(msg)
+ if isOpenAIWSTerminalEvent(evtType) || evtType == "error" {
+ return
+ }
+ }
+ }()
+ var drainTimer *time.Timer
+ defer func() {
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+ }()
for {
- upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout())
- if readErr != nil {
+ var evt openAIWSUpstreamPumpEvent
+ var evtOk bool
+ select {
+ case evt, evtOk = <-pumpEventCh:
+ if !evtOk {
+ goto pumpClosed
+ }
+ case preemptMsg, ok := <-clientMsgCh:
+ if !ok {
+ // 客户端读取 goroutine 退出,置空 channel 防止再次 select
+ clientMsgCh = nil
+ continue
+ }
+ // 客户端抢占:暂存新请求,取消上游转发,返回让外层切换到下一 turn
+ nextClientPreemptedPayload = preemptMsg
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_preempt account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ )
+ pumpCancel()
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_preempted", errOpenAIWSClientPreempted,
+ wroteDownstream, buildPartialResult("client_preempted"))
+ case readErr := <-clientReadErrCh:
+ // 客户端断连:立即取消上游 pump 并释放连接。
+ // Codex CLI 在 ESC 取消后会关闭旧 WebSocket 并新建连接发送下一条消息,
+ // 继续排水只会延迟新连接获取上游 lease,因此这里直接终止。
+ if isOpenAIWSClientDisconnectError(readErr) {
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ closeStatus,
+ truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
+ )
+ pumpCancel()
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_immediate",
+ fmt.Errorf("client disconnected (read): %w", readErr),
+ wroteDownstream,
+ buildPartialResult("client_disconnected"),
+ )
+ }
+ pendingClientReadErr = readErr
+ cause := "-"
+ if readErr != nil {
+ cause = truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen)
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_read_error_deferred account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ cause,
+ )
+ clientMsgCh = nil
+ clientReadErrCh = nil
+ continue
+ }
+ // 排水超时检查:客户端已断连且排水截止时间已过,终止读取。
+ if clientDisconnected && !clientDisconnectDrainDeadline.IsZero() && time.Now().After(clientDisconnectDrainDeadline) {
+ pumpCancel()
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(),
+ )
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+ wroteDownstream,
+ buildPartialResult("client_disconnected_drain_timeout"),
+ )
+ }
+ upstreamMessage := evt.message
+ if evt.err != nil {
+ readErr := evt.err
+ if clientDisconnected {
+ // 排水期间读取失败(上游关闭或读取泵被取消),按排水超时处理。
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_disconnected_drain_timeout account_id=%d turn=%d conn_id=%s timeout_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ openAIWSIngressClientDisconnectDrainTimeout.Milliseconds(),
+ )
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+ wroteDownstream,
+ buildPartialResult("client_disconnected_drain_timeout"),
+ )
+ }
+ readErrClass := classifyOpenAIWSIngressReadErrorClass(readErr)
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_read_error account_id=%d turn=%d conn_id=%s class=%s close_status=%s close_reason=%s events_received=%d wrote_downstream=%v response_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(readErrClass),
+ closeStatus,
+ closeReason,
+ eventCount,
+ wroteDownstream,
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
+ )
lease.MarkBroken()
- return nil, wrapOpenAIWSIngressTurnError(
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
"read_upstream",
fmt.Errorf("read upstream websocket event: %w", readErr),
wroteDownstream,
+ buildPartialResult("read_upstream"),
)
}
- eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage)
+ eventType, eventResponseID := parseOpenAIWSEventType(upstreamMessage)
if responseID == "" && eventResponseID != "" {
responseID = eventResponseID
}
@@ -2737,15 +2016,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
+ recoveryEnabled := s.openAIWSIngressPreviousResponseRecoveryEnabled()
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
+ recoveryEnabled &&
+ (turnPreviousResponseID != "" || (turnHasFunctionCallOutput && turnExpectedPreviousResponseID != "")) &&
+ !wroteDownstream
+ // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call
+ // (用户在 Codex CLI 按 ESC 取消后重新发送消息),需要移除 previous_response_id 后重放。
+ recoverableToolOutputNotFound := fallbackReason == openAIWSIngressStageToolOutputNotFound &&
+ recoveryEnabled &&
turnPreviousResponseID != "" &&
- !turnHasFunctionCallOutput &&
- s.openAIWSIngressPreviousResponseRecoveryEnabled() &&
!wroteDownstream
- if recoverablePrevNotFound {
+ recoverableContextMismatch := recoverablePrevNotFound || recoverableToolOutputNotFound
+ if recoverableContextMismatch {
// 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。
logOpenAIWSModeInfo(
- "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
+ "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
@@ -2757,12 +2043,17 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(ingressMode),
+ ctxPoolMode,
turnStoreDisabled,
turnPromptCacheKey != "",
+ turnHasFunctionCallOutput,
+ recoveryEnabled,
+ wroteDownstream,
)
} else {
logOpenAIWSModeInfo(
- "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v",
+ "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s ws_mode=%s ctx_pool_mode=%v store_disabled=%v has_prompt_cache_key=%v has_function_call_output=%v recovery_enabled=%v wrote_downstream=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
@@ -2774,24 +2065,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen),
normalizeOpenAIWSLogValue(turnPreviousResponseIDKind),
truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(ingressMode),
+ ctxPoolMode,
turnStoreDisabled,
turnPromptCacheKey != "",
+ turnHasFunctionCallOutput,
+ recoveryEnabled,
+ wroteDownstream,
)
}
- // previous_response_not_found 在 ingress 模式支持单次恢复重试:
+ // previous_response_not_found / tool_output_not_found 在 ingress 模式支持单次恢复重试:
// 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。
- if recoverablePrevNotFound {
+ if recoverableContextMismatch {
lease.MarkBroken()
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
- errMsg = "previous response not found"
+ if fallbackReason == openAIWSIngressStageToolOutputNotFound {
+ errMsg = "no tool output found for function call"
+ } else {
+ errMsg = "previous response not found"
+ }
}
- return nil, wrapOpenAIWSIngressTurnError(
- openAIWSIngressStagePreviousResponseNotFound,
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ fallbackReason,
errors.New(errMsg),
false,
+ buildPartialResult(fallbackReason),
)
}
+ terminateOnErrorEvent = true
+ terminateOnErrorMessage = strings.TrimSpace(errMsgRaw)
+ if terminateOnErrorMessage == "" {
+ terminateOnErrorMessage = "upstream websocket error"
+ }
}
isTokenEvent := isOpenAIWSTokenEvent(eventType)
if isTokenEvent {
@@ -2808,6 +2114,11 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if openAIWSEventShouldParseUsage(eventType) {
parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage)
}
+ if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) {
+ for _, callID := range openAIWSExtractPendingFunctionCallIDsFromEvent(upstreamMessage) {
+ turnPendingFunctionCallIDSet[callID] = struct{}{}
+ }
+ }
if !clientDisconnected {
if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) {
@@ -2820,27 +2131,47 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
if err := writeClientMessage(upstreamMessage); err != nil {
if isOpenAIWSClientDisconnectError(err) {
- clientDisconnected = true
+ // 客户端断连:立即取消上游 pump 并释放连接。
+ // 不再排水等待,以便新连接能尽快获取上游 lease。
closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err)
logOpenAIWSModeInfo(
- "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
+ "ingress_ws_client_disconnected_immediate_cancel account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
closeStatus,
truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
)
+ pumpCancel()
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_immediate",
+ fmt.Errorf("client disconnected (write): %w", err),
+ wroteDownstream,
+ buildPartialResult("client_disconnected"),
+ )
} else {
- return nil, wrapOpenAIWSIngressTurnError(
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
"write_client",
fmt.Errorf("write client websocket event: %w", err),
wroteDownstream,
+ buildPartialResult("write_client"),
)
}
} else {
wroteDownstream = true
}
}
+ if terminateOnErrorEvent {
+ // WS ingress 中的 error 事件应立即终止当前 turn,避免继续阻塞在下一次上游 read。
+ lease.MarkBroken()
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ "upstream_error_event",
+ errors.New(terminateOnErrorMessage),
+ wroteDownstream,
+ buildPartialResult("upstream_error_event"),
+ )
+ }
if isTerminalEvent {
// 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。
if clientDisconnected {
@@ -2852,7 +2183,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
if debugEnabled {
logOpenAIWSModeDebug(
- "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v",
+ "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v has_function_call_output=%v pending_function_call_ids=%d",
account.ID,
turn,
truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen),
@@ -2865,20 +2196,46 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen),
firstTokenMsValue,
clientDisconnected,
+ turnHasFunctionCallOutput,
+ len(turnPendingFunctionCallIDSet),
)
}
+ pendingFunctionCallIDs := make([]string, 0, len(turnPendingFunctionCallIDSet))
+ for callID := range turnPendingFunctionCallIDSet {
+ pendingFunctionCallIDs = append(pendingFunctionCallIDs, callID)
+ }
+ sort.Strings(pendingFunctionCallIDs)
return &OpenAIForwardResult{
- RequestID: responseID,
- Usage: usage,
- Model: originalModel,
- ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
- Stream: reqStream,
- OpenAIWSMode: true,
- Duration: time.Since(turnStart),
- FirstTokenMs: firstTokenMs,
+ RequestID: responseID,
+ Usage: usage,
+ Model: originalModel,
+ ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel),
+ Stream: reqStream,
+ OpenAIWSMode: true,
+ Duration: time.Since(turnStart),
+ FirstTokenMs: firstTokenMs,
+ TerminalEventType: strings.TrimSpace(eventType),
+ PendingFunctionCallIDs: pendingFunctionCallIDs,
}, nil
}
}
+ pumpClosed:
+ // 读取泵 channel 关闭但未收到终端事件:
+ // - 客户端已断连:按排水超时收尾,避免误判为 read_upstream。
+ // - 其他场景:按上游读取异常处理。
+ lease.MarkBroken()
+ if clientDisconnected {
+ return nil, openAIWSIngressPumpClosedTurnError(
+ true,
+ wroteDownstream,
+ buildPartialResult("client_disconnected_drain_timeout"),
+ )
+ }
+ return nil, openAIWSIngressPumpClosedTurnError(
+ false,
+ wroteDownstream,
+ buildPartialResult("read_upstream"),
+ )
}
currentPayload := firstPayload.payloadRaw
@@ -2890,50 +2247,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != ""
}
- var sessionLease *openAIWSConnLease
+ var sessionLease openAIWSIngressUpstreamLease
sessionConnID := ""
- pinnedSessionConnID := ""
- unpinSessionConn := func(connID string) {
- connID = strings.TrimSpace(connID)
- if connID == "" || pinnedSessionConnID != connID {
- return
- }
- pool.UnpinConn(account.ID, connID)
- pinnedSessionConnID = ""
- }
- pinSessionConn := func(connID string) {
- if !storeDisabled {
- return
- }
- connID = strings.TrimSpace(connID)
- if connID == "" || pinnedSessionConnID == connID {
+ unpinSessionConn := func(_ string) {}
+ pinSessionConn := func(_ string) {}
+ releaseSessionLease := func() {
+ if sessionLease == nil {
return
}
- if pinnedSessionConnID != "" {
- pool.UnpinConn(account.ID, pinnedSessionConnID)
- pinnedSessionConnID = ""
- }
- if pool.PinConn(account.ID, connID) {
- pinnedSessionConnID = connID
+ unpinSessionConn(sessionConnID)
+ sessionLease.Release()
+ if debugEnabled {
+ logOpenAIWSModeDebug(
+ "ingress_ws_upstream_released account_id=%d conn_id=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ )
}
}
- releaseSessionLease := func() {
+ yieldSessionLease := func() {
if sessionLease == nil {
return
}
- if dedicatedMode {
- // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。
- sessionLease.MarkBroken()
- }
unpinSessionConn(sessionConnID)
- sessionLease.Release()
+ sessionLease.Yield()
if debugEnabled {
logOpenAIWSModeDebug(
- "ingress_ws_upstream_released account_id=%d conn_id=%s",
+ "ingress_ws_upstream_yielded account_id=%d conn_id=%s",
account.ID,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
)
}
+ sessionLease = nil
+ sessionConnID = ""
}
defer releaseSessionLease()
@@ -2941,7 +2287,17 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnRetry := 0
turnPrevRecoveryTried := false
lastTurnFinishedAt := time.Time{}
- lastTurnResponseID := ""
+ lastTurnResponseID := sessionLastResponseID
+ clearSessionLastResponseID := func() {
+ lastTurnResponseID = ""
+ if stateStore == nil || sessionHash == "" {
+ return
+ }
+ stateStore.DeleteSessionLastResponseID(groupID, sessionHash)
+ if ctxPoolMode && legacySessionHash != "" && legacySessionHash != sessionHash {
+ stateStore.DeleteSessionLastResponseID(groupID, legacySessionHash)
+ }
+ }
lastTurnPayload := []byte(nil)
var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState
lastTurnReplayInput := []json.RawMessage(nil)
@@ -2953,6 +2309,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if sessionLease == nil {
return
}
+ resetStart := time.Now()
+ resetConnID := sessionConnID
if markBroken {
sessionLease.MarkBroken()
}
@@ -2960,12 +2318,192 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
sessionLease = nil
sessionConnID = ""
preferredConnID = ""
+ if elapsed := time.Since(resetStart); elapsed > 100*time.Millisecond {
+ logOpenAIWSModeInfo(
+ "ingress_ws_reset_session_lease_slow account_id=%d conn_id=%s mark_broken=%v elapsed_ms=%d",
+ account.ID,
+ truncateOpenAIWSLogValue(resetConnID, openAIWSIDValueMaxLen),
+ markBroken,
+ elapsed.Milliseconds(),
+ )
+ }
}
recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool {
- if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) {
+ isPrevNotFound := isOpenAIWSIngressPreviousResponseNotFound(relayErr)
+ isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(relayErr)
+ if !isPrevNotFound && !isToolOutputMissing {
return false
}
if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
+ skipReason := "already_tried"
+ if !s.openAIWSIngressPreviousResponseRecoveryEnabled() {
+ skipReason = "recovery_disabled"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skipped account_id=%d turn=%d conn_id=%s reason=%s is_prev_not_found=%v is_tool_output_missing=%v",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(skipReason),
+ isPrevNotFound,
+ isToolOutputMissing,
+ )
+ return false
+ }
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
+ // tool_output_not_found: previous_response_id 指向的 response 包含未完成的 function_call
+ // (用户在 Codex CLI 按 ESC 取消了 function_call 后重新发送消息)。
+ // 对齐/保持 previous_response_id 无法解决问题,直接跳到 drop 分支移除后重放。
+ if isToolOutputMissing {
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ )
+ turnPrevRecoveryTried = true
+ updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
+ if dropErr != nil || !removed {
+ reason := "not_removed"
+ if dropErr != nil {
+ reason = "drop_error"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(reason),
+ )
+ return false
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false
+ }
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ clearSessionLastResponseID()
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
+ if hasFunctionCallOutput {
+ turnPrevRecoveryTried = true
+ expectedPrev := strings.TrimSpace(lastTurnResponseID)
+ if currentPreviousResponseID == "" && expectedPrev != "" {
+ updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
+ if setPrevErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=set_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ }
+ }
+ alignedPayload, aligned, alignErr := alignStoreDisabledPreviousResponseID(currentPayload, expectedPrev)
+ if alignErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=align_previous_response_id_error cause=%s previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(alignErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else if aligned {
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ alignedPayload,
+ currentTurnReplayInput,
+ currentTurnReplayInputExists,
+ )
+ if setInputErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=align_previous_response_id_retry previous_response_id=%s expected_previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ )
+ currentPayload = updatedWithInput
+ currentPayloadBytes = len(updatedWithInput)
+ resetSessionLease(true)
+ skipBeforeTurn = true
+ return true
+ }
+ }
+ // function_call_output 与 previous_response_id 语义绑定:
+ // function_call_output 引用了前一个 response 中的 call_id,
+ // 移除 previous_response_id 但保留 function_call_output 会导致上游报错
+ // "No tool call found for function call output with call_id ..."。
+ // 此场景在网关层不可恢复,返回 false 走 abort 路径通知客户端,
+ // 客户端收到错误后会重置并发送完整请求(不带 previous_response_id)。
+ logOpenAIWSModeInfo(
+ "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=abort_function_call_unrecoverable previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ )
return false
}
if isStrictAffinityTurn(currentPayload) {
@@ -3018,12 +2556,26 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
currentPayload = updatedWithInput
currentPayloadBytes = len(updatedWithInput)
+ clearSessionLastResponseID()
resetSessionLease(true)
skipBeforeTurn = true
return true
}
retryIngressTurn := func(relayErr error, turn int, connID string) bool {
if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 {
+ retrySkipReason := "not_retryable"
+ if turnRetry >= 1 {
+ retrySkipReason = "retry_exhausted"
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_retry_skipped account_id=%d turn=%d conn_id=%s reason=%s retry_count=%d err_stage=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(retrySkipReason),
+ turnRetry,
+ truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen),
+ )
return false
}
if isStrictAffinityTurn(currentPayload) {
@@ -3048,6 +2600,160 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = true
return true
}
+ advanceToNextClientTurn := func(turn int, connID string) (bool, error) {
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_wait_client account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ var nextClientMessage []byte
+ if nextClientPreemptedPayload != nil {
+ nextClientMessage = nextClientPreemptedPayload
+ nextClientPreemptedPayload = nil
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_use_preempted_payload account_id=%d turn=%d conn_id=%s payload_bytes=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ len(nextClientMessage),
+ )
+ } else {
+ if pendingReadErr := openAIWSAdvanceConsumePendingClientReadErr(&pendingClientReadErr); pendingReadErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(pendingReadErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false, pendingReadErr
+ }
+ if openAIWSAdvanceClientReadUnavailable(clientMsgCh, clientReadErrCh) {
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_read_unavailable account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ return false, fmt.Errorf("read client websocket request: %w", errOpenAIWSAdvanceClientReadUnavailable)
+ }
+ select {
+ case msg, ok := <-clientMsgCh:
+ if !ok {
+ // 客户端读取 goroutine 已退出
+ return true, nil
+ }
+ nextClientMessage = msg
+ case readErr := <-clientReadErrCh:
+ if isOpenAIWSClientDisconnectError(readErr) {
+ closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
+ account.ID,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ closeStatus,
+ truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
+ )
+ return true, nil
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_advance_read_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen),
+ )
+ return false, fmt.Errorf("read client websocket request: %w", readErr)
+ }
+ }
+
+ nextPayload, parseErr := parseClientPayload(nextClientMessage)
+ if parseErr != nil {
+ return false, parseErr
+ }
+ nextStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(nextPayload.payloadRaw, account)
+ nextLegacySessionHash := strings.TrimSpace(s.GenerateSessionHashWithFallback(c, nextPayload.rawForHash, fallbackSessionSeed))
+ nextSessionHash := nextLegacySessionHash
+ if ctxPoolMode {
+ nextSessionHash = openAIWSApplySessionScope(nextLegacySessionHash, ctxPoolSessionScope)
+ }
+ if sessionHash == "" && nextSessionHash != "" {
+ sessionHash = nextSessionHash
+ legacySessionHash = nextLegacySessionHash
+ if stateStore != nil {
+ if turnState == "" {
+ if savedTurnState, ok := resolveSessionTurnState(); ok {
+ turnState = savedTurnState
+ }
+ }
+ if lastTurnResponseID == "" {
+ if savedResponseID, ok := resolveSessionLastResponseID(); ok {
+ lastTurnResponseID = savedResponseID
+ }
+ }
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_session_hash_backfill account_id=%d turn=%d next_turn=%d conn_id=%s session_hash=%s has_turn_state=%v has_last_response_id=%v store_disabled=%v",
+ account.ID,
+ turn,
+ turn+1,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(sessionHash, 12),
+ turnState != "",
+ strings.TrimSpace(lastTurnResponseID) != "",
+ nextStoreDisabled,
+ )
+ }
+ if nextPayload.promptCacheKey != "" {
+ // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
+ // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
+ updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
+ baseAcquireReq.Headers = updatedHeaders
+ }
+ if nextPayload.previousResponseID != "" {
+ expectedPrev := strings.TrimSpace(lastTurnResponseID)
+ chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
+ nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
+ logOpenAIWSModeInfo(
+ "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
+ account.ID,
+ turn,
+ turn+1,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ chainedFromLast,
+ nextPayload.promptCacheKey != "",
+ storeDisabled,
+ )
+ }
+ if stateStore != nil && nextPayload.previousResponseID != "" {
+ if stickyConnID := openAIWSPreferredConnIDFromResponse(stateStore, nextPayload.previousResponseID); stickyConnID != "" {
+ if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
+ logOpenAIWSModeInfo(
+ "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
+ )
+ } else {
+ preferredConnID = stickyConnID
+ }
+ }
+ }
+ currentPayload = nextPayload.payloadRaw
+ currentOriginalModel = nextPayload.originalModel
+ currentPayloadBytes = nextPayload.payloadBytes
+ storeDisabled = nextStoreDisabled
+ if !storeDisabled {
+ unpinSessionConn(sessionConnID)
+ }
+ return false, nil
+ }
for {
if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil {
if err := hooks.BeforeTurn(turn); err != nil {
@@ -3057,39 +2763,78 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
skipBeforeTurn = false
currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")
expectedPrev := strings.TrimSpace(lastTurnResponseID)
- hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists()
- // store=false + function_call_output 场景必须有续链锚点。
- // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。
- if shouldInferIngressFunctionCallOutputPreviousResponseID(
- storeDisabled,
+ if expectedPrev == "" && stateStore != nil && sessionHash != "" {
+ if savedResponseID, ok := resolveSessionLastResponseID(); ok {
+ expectedPrev = savedResponseID
+ if expectedPrev != "" {
+ lastTurnResponseID = expectedPrev
+ }
+ }
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_begin account_id=%d turn=%d conn_id=%s previous_response_id=%s expected_previous_response_id=%s store_disabled=%v has_session_lease=%v",
+ account.ID,
turn,
- hasFunctionCallOutput,
- currentPreviousResponseID,
- expectedPrev,
- ) {
- updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
- if setPrevErr != nil {
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ storeDisabled,
+ sessionLease != nil,
+ )
+ pendingExpectedCallIDs := []string(nil)
+ if storeDisabled && expectedPrev != "" && stateStore != nil {
+ if pendingCallIDs, ok := stateStore.GetResponsePendingToolCalls(groupID, expectedPrev); ok {
+ pendingExpectedCallIDs = openAIWSNormalizeCallIDs(pendingCallIDs)
+ }
+ }
+ normalized := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: account.ID,
+ turn: turn,
+ connID: sessionConnID,
+ currentPayload: currentPayload,
+ currentPayloadBytes: currentPayloadBytes,
+ currentPreviousResponseID: currentPreviousResponseID,
+ expectedPreviousResponse: expectedPrev,
+ pendingExpectedCallIDs: pendingExpectedCallIDs,
+ })
+ currentPayload = normalized.currentPayload
+ currentPayloadBytes = normalized.currentPayloadBytes
+ currentPreviousResponseID = normalized.currentPreviousResponseID
+ expectedPrev = normalized.expectedPreviousResponseID
+ pendingExpectedCallIDs = normalized.pendingExpectedCallIDs
+ currentFunctionCallOutputCallIDs := normalized.functionCallOutputCallIDs
+ hasFunctionCallOutput := normalized.hasFunctionCallOutputCallID
+
+ // 当客户端发送 function_call_output 但未携带 previous_response_id 时,
+ // 主动注入 Gateway 跟踪的 lastTurnResponseID。
+ // 在 store_disabled 模式下,上游需要 previous_response_id 来关联 function_call_output 与 response,
+ // 否则会返回 "No tool call found for function call output" 错误。
+ if shouldInferIngressFunctionCallOutputPreviousResponseID(storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev) {
+ injectedPayload, injectErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev)
+ if injectErr != nil {
logOpenAIWSModeInfo(
- "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s",
+ "ingress_ws_inject_prev_response_id_fail account_id=%d turn=%d conn_id=%s cause=%s expected_previous_response_id=%s",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen),
+ truncateOpenAIWSLogValue(injectErr.Error(), openAIWSLogValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
)
} else {
- currentPayload = updatedPayload
- currentPayloadBytes = len(updatedPayload)
- currentPreviousResponseID = expectedPrev
logOpenAIWSModeInfo(
- "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s",
+ "ingress_ws_inject_prev_response_id account_id=%d turn=%d conn_id=%s injected_previous_response_id=%s has_function_call_output=%v",
account.ID,
turn,
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
+ hasFunctionCallOutput,
)
+ currentPayload = injectedPayload
+ currentPayloadBytes = len(injectedPayload)
+ currentPreviousResponseID = expectedPrev
}
}
+
nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence(
lastTurnReplayInput,
lastTurnReplayInputExists,
@@ -3120,6 +2865,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentPayload,
lastTurnResponseID,
hasFunctionCallOutput,
+ pendingExpectedCallIDs,
+ currentFunctionCallOutputCallIDs,
)
} else {
shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID(
@@ -3127,6 +2874,8 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
currentPayload,
lastTurnResponseID,
hasFunctionCallOutput,
+ pendingExpectedCallIDs,
+ currentFunctionCallOutputCallIDs,
)
}
if strictErr != nil {
@@ -3196,9 +2945,22 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
}
forcePreferredConn := isStrictAffinityTurn(currentPayload)
+ hasPreviousResponseIDForAcquire := currentPreviousResponseID != ""
if sessionLease == nil {
- acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
+ acquiredLease, acquireErr := acquireTurnLease(
+ turn,
+ preferredConnID,
+ forcePreferredConn,
+ hasPreviousResponseIDForAcquire,
+ )
if acquireErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_acquire_lease_fail account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen),
+ )
return fmt.Errorf("acquire upstream websocket: %w", acquireErr)
}
sessionLease = acquiredLease
@@ -3224,64 +2986,14 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen),
)
- if forcePreferredConn {
- if !turnPrevRecoveryTried && currentPreviousResponseID != "" {
- updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
- if dropErr != nil || !removed {
- reason := "not_removed"
- if dropErr != nil {
- reason = "drop_error"
- }
- logOpenAIWSModeInfo(
- "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- normalizeOpenAIWSLogValue(reason),
- truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
- )
- } else {
- updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
- updatedPayload,
- currentTurnReplayInput,
- currentTurnReplayInputExists,
- )
- if setInputErr != nil {
- logOpenAIWSModeInfo(
- "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen),
- )
- } else {
- logOpenAIWSModeInfo(
- "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
- )
- turnPrevRecoveryTried = true
- currentPayload = updatedWithInput
- currentPayloadBytes = len(updatedWithInput)
- resetSessionLease(true)
- skipBeforeTurn = true
- continue
- }
- }
- }
- resetSessionLease(true)
- return NewOpenAIWSClientCloseError(
- coderws.StatusPolicyViolation,
- "upstream continuation connection is unavailable; please restart the conversation",
- pingErr,
- )
- }
+ // preflight ping 失败:直接重连,不修改 payload
resetSessionLease(true)
-
- acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn)
+ acquiredLease, acquireErr := acquireTurnLease(
+ turn,
+ preferredConnID,
+ forcePreferredConn,
+ currentPreviousResponseID != "",
+ )
if acquireErr != nil {
return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr)
}
@@ -3315,7 +3027,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
)
}
- result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel)
+ result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel, expectedPrev)
if relayErr != nil {
if recoverIngressPrevResponseNotFound(relayErr, turn, connID) {
continue
@@ -3327,11 +3039,107 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
if unwrapped := errors.Unwrap(relayErr); unwrapped != nil {
finalErr = unwrapped
}
+ abortReason, abortExpected := classifyOpenAIWSIngressTurnAbortReason(relayErr)
+ s.recordOpenAIWSTurnAbort(abortReason, abortExpected)
+ logOpenAIWSIngressTurnAbort(account.ID, turn, connID, abortReason, abortExpected, finalErr)
if hooks != nil && hooks.AfterTurn != nil {
hooks.AfterTurn(turn, nil, finalErr)
}
- sessionLease.MarkBroken()
- return finalErr
+ switch openAIWSIngressTurnAbortDispositionForReason(abortReason) {
+ case openAIWSIngressTurnAbortDispositionContinueTurn:
+ if abortReason == openAIWSIngressTurnAbortReasonClientPreempted {
+ // 客户端抢占:不通知 error(客户端已发出新请求,不需要旧 turn 的错误事件),
+ // 保留上一轮 response_id(被抢占的 turn 未完成,上一轮 response_id 仍有效供新 turn 续链)。
+ preemptRecoverStart := time.Now()
+ resetSessionLease(true)
+ logOpenAIWSModeInfo(
+ "ingress_ws_client_preempt_recover account_id=%d turn=%d conn_id=%s reset_elapsed_ms=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ time.Since(preemptRecoverStart).Milliseconds(),
+ )
+ } else if abortReason == openAIWSIngressTurnAbortReasonUpstreamRestart {
+ // 上游重启(1012/1013):连接级关闭,客户端未收到任何终端事件,
+ // 始终补发 error 事件(无论 wroteDownstream 状态),避免客户端永远等待响应。
+ abortMessage := "upstream service restarting, please retry"
+ if finalErr != nil {
+ abortMessage = finalErr.Error()
+ }
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+ if writeErr := writeClientMessage(errorEvent); writeErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_restart_notify_failed account_id=%d turn=%d conn_id=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_upstream_restart_notified account_id=%d turn=%d conn_id=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ }
+ resetSessionLease(true)
+ clearSessionLastResponseID()
+ } else {
+ // turn 级终止:当前 turn 失败,但客户端 WS 会话继续。
+ // 这样可与 Codex 客户端语义对齐:后续 turn 允许在新上游连接上继续进行。
+ //
+ // 关键修复:若未向客户端写入过任何数据 (wroteDownstream=false),
+ // 必须补发 error 事件通知客户端本轮失败,否则客户端会一直等待响应,
+ // 而服务端在 advanceToNextClientTurn 中等待客户端下一条消息 → 双向死锁。
+ if !openAIWSIngressTurnWroteDownstream(relayErr) {
+ abortMessage := "turn failed: " + string(abortReason)
+ if finalErr != nil {
+ abortMessage = finalErr.Error()
+ }
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` + string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+ if writeErr := writeClientMessage(errorEvent); writeErr != nil {
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_abort_notify_failed account_id=%d turn=%d conn_id=%s reason=%s cause=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(string(abortReason)),
+ truncateOpenAIWSLogValue(writeErr.Error(), openAIWSLogValueMaxLen),
+ )
+ } else {
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_abort_notified account_id=%d turn=%d conn_id=%s reason=%s",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(string(abortReason)),
+ )
+ }
+ }
+ resetSessionLease(true)
+ clearSessionLastResponseID()
+ }
+ turnRetry = 0
+ turnPrevRecoveryTried = false
+ exit, advanceErr := advanceToNextClientTurn(turn, connID)
+ if advanceErr != nil {
+ return advanceErr
+ }
+ if exit {
+ return nil
+ }
+ s.recordOpenAIWSTurnAbortRecovered()
+ turn++
+ continue
+ case openAIWSIngressTurnAbortDispositionCloseGracefully:
+ resetSessionLease(true)
+ clearSessionLastResponseID()
+ return nil
+ case openAIWSIngressTurnAbortDispositionFailRequest:
+ sessionLease.MarkBroken()
+ return finalErr
+ }
}
turnRetry = 0
turnPrevRecoveryTried = false
@@ -3343,7 +3151,23 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
return errors.New("websocket turn result is nil")
}
responseID := strings.TrimSpace(result.RequestID)
- lastTurnResponseID = responseID
+ persistLastResponseID := responseID != "" && shouldPersistOpenAIWSLastResponseID(result.TerminalEventType)
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d persist_response_id=%v has_function_call_output=%v pending_function_calls=%d",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ result.Duration.Milliseconds(),
+ persistLastResponseID,
+ hasFunctionCallOutput,
+ len(result.PendingFunctionCallIDs),
+ )
+ if persistLastResponseID {
+ lastTurnResponseID = responseID
+ } else {
+ clearSessionLastResponseID()
+ }
lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload)
lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput)
lastTurnReplayInputExists = currentTurnReplayInputExists
@@ -3361,84 +3185,39 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
lastTurnStrictState = nextStrictState
}
+ if stateStore != nil &&
+ expectedPrev != "" &&
+ currentPreviousResponseID == expectedPrev &&
+ (hasFunctionCallOutput || len(pendingExpectedCallIDs) > 0) {
+ stateStore.DeleteResponsePendingToolCalls(groupID, expectedPrev)
+ }
+
if responseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl))
- stateStore.BindResponseConn(responseID, connID, ttl)
- }
- if stateStore != nil && storeDisabled && sessionHash != "" {
- stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL())
+ if poolConnID, ok := normalizeOpenAIWSPreferredConnID(connID); ok {
+ stateStore.BindResponseConn(responseID, poolConnID, ttl)
+ }
+ if pendingFunctionCallIDs := openAIWSNormalizeCallIDs(result.PendingFunctionCallIDs); len(pendingFunctionCallIDs) > 0 {
+ stateStore.BindResponsePendingToolCalls(groupID, responseID, pendingFunctionCallIDs, ttl)
+ } else {
+ stateStore.DeleteResponsePendingToolCalls(groupID, responseID)
+ }
+ if sessionHash != "" && persistLastResponseID {
+ stateStore.BindSessionLastResponseID(groupID, sessionHash, responseID, s.openAIWSSessionStickyTTL())
+ }
}
if connID != "" {
preferredConnID = connID
}
+ yieldSessionLease()
- nextClientMessage, readErr := readClientMessage()
- if readErr != nil {
- if isOpenAIWSClientDisconnectError(readErr) {
- closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr)
- logOpenAIWSModeInfo(
- "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s",
- account.ID,
- truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
- closeStatus,
- truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen),
- )
- return nil
- }
- return fmt.Errorf("read client websocket request: %w", readErr)
- }
-
- nextPayload, parseErr := parseClientPayload(nextClientMessage)
- if parseErr != nil {
- return parseErr
- }
- if nextPayload.promptCacheKey != "" {
- // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接;
- // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。
- updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey)
- baseAcquireReq.Headers = updatedHeaders
- }
- if nextPayload.previousResponseID != "" {
- expectedPrev := strings.TrimSpace(lastTurnResponseID)
- chainedFromLast := expectedPrev != "" && nextPayload.previousResponseID == expectedPrev
- nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID)
- logOpenAIWSModeInfo(
- "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v",
- account.ID,
- turn,
- turn+1,
- truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
- normalizeOpenAIWSLogValue(nextPreviousResponseIDKind),
- truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen),
- chainedFromLast,
- nextPayload.promptCacheKey != "",
- storeDisabled,
- )
- }
- if stateStore != nil && nextPayload.previousResponseID != "" {
- if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok {
- if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID {
- logOpenAIWSModeInfo(
- "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen),
- truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen),
- )
- } else {
- preferredConnID = stickyConnID
- }
- }
+ exit, advanceErr := advanceToNextClientTurn(turn, connID)
+ if advanceErr != nil {
+ return advanceErr
}
- currentPayload = nextPayload.payloadRaw
- currentOriginalModel = nextPayload.originalModel
- currentPayloadBytes = nextPayload.payloadBytes
- storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account)
- if !storeDisabled {
- unpinSessionConn(sessionConnID)
+ if exit {
+ return nil
}
turn++
}
@@ -3452,7 +3231,7 @@ func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool {
// 预热默认关闭,仅在配置开启后生效;失败时按可恢复错误回退到 HTTP。
func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
ctx context.Context,
- lease *openAIWSConnLease,
+ lease openAIWSIngressUpstreamLease,
decision OpenAIWSProtocolDecision,
payload map[string]any,
previousResponseID string,
@@ -3540,7 +3319,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr)
}
- eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message)
+ eventType, eventResponseID := parseOpenAIWSEventType(message)
if eventType == "" {
continue
}
@@ -3595,7 +3374,9 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
if prewarmResponseID != "" && stateStore != nil {
ttl := s.openAIWSResponseStickyTTL()
logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl))
- stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl)
+ if connID, ok := normalizeOpenAIWSPreferredConnID(lease.ConnID()); ok {
+ stateStore.BindResponseConn(prewarmResponseID, connID, ttl)
+ }
}
logOpenAIWSModeInfo(
"prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d",
@@ -3609,111 +3390,6 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
return nil
}
-func payloadAsJSON(payload map[string]any) string {
- return string(payloadAsJSONBytes(payload))
-}
-
-func payloadAsJSONBytes(payload map[string]any) []byte {
- if len(payload) == 0 {
- return []byte("{}")
- }
- body, err := json.Marshal(payload)
- if err != nil {
- return []byte("{}")
- }
- return body
-}
-
-func isOpenAIWSTerminalEvent(eventType string) bool {
- switch strings.TrimSpace(eventType) {
- case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
- return true
- default:
- return false
- }
-}
-
-func isOpenAIWSTokenEvent(eventType string) bool {
- eventType = strings.TrimSpace(eventType)
- if eventType == "" {
- return false
- }
- switch eventType {
- case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
- return false
- }
- if strings.Contains(eventType, ".delta") {
- return true
- }
- if strings.HasPrefix(eventType, "response.output_text") {
- return true
- }
- if strings.HasPrefix(eventType, "response.output") {
- return true
- }
- return eventType == "response.completed" || eventType == "response.done"
-}
-
-func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
- if len(message) == 0 {
- return message
- }
- if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
- return message
- }
- if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
- return message
- }
- modelValues := gjson.GetManyBytes(message, "model", "response.model")
- replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
- replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
- if !replaceModel && !replaceResponseModel {
- return message
- }
- updated := message
- if replaceModel {
- if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
- updated = next
- }
- }
- if replaceResponseModel {
- if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
- updated = next
- }
- }
- return updated
-}
-
-func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
- if usage == nil || len(body) == 0 {
- return
- }
- values := gjson.GetManyBytes(
- body,
- "usage.input_tokens",
- "usage.output_tokens",
- "usage.input_tokens_details.cached_tokens",
- )
- usage.InputTokens = int(values[0].Int())
- usage.OutputTokens = int(values[1].Int())
- usage.CacheReadInputTokens = int(values[2].Int())
-}
-
-func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
- if c == nil {
- return 0
- }
- value, exists := c.Get("api_key")
- if !exists {
- return 0
- }
- apiKey, ok := value.(*APIKey)
- if !ok || apiKey == nil || apiKey.GroupID == nil {
- return 0
- }
- return *apiKey.GroupID
-}
-
// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。
// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。
func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
@@ -3792,164 +3468,3 @@ func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID(
}
return nil, nil
}
-
-func classifyOpenAIWSAcquireError(err error) string {
- if err == nil {
- return "acquire_conn"
- }
- var dialErr *openAIWSDialError
- if errors.As(err, &dialErr) {
- switch dialErr.StatusCode {
- case 426:
- return "upgrade_required"
- case 401, 403:
- return "auth_failed"
- case 429:
- return "upstream_rate_limited"
- }
- if dialErr.StatusCode >= 500 {
- return "upstream_5xx"
- }
- return "dial_failed"
- }
- if errors.Is(err, errOpenAIWSConnQueueFull) {
- return "conn_queue_full"
- }
- if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
- return "preferred_conn_unavailable"
- }
- if errors.Is(err, context.DeadlineExceeded) {
- return "acquire_timeout"
- }
- return "acquire_conn"
-}
-
-func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
- code := strings.ToLower(strings.TrimSpace(codeRaw))
- errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
- msg := strings.ToLower(strings.TrimSpace(msgRaw))
-
- switch code {
- case "upgrade_required":
- return "upgrade_required", true
- case "websocket_not_supported", "websocket_unsupported":
- return "ws_unsupported", true
- case "websocket_connection_limit_reached":
- return "ws_connection_limit_reached", true
- case "previous_response_not_found":
- return "previous_response_not_found", true
- }
- if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
- return "upgrade_required", true
- }
- if strings.Contains(errType, "upgrade") {
- return "upgrade_required", true
- }
- if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
- return "ws_unsupported", true
- }
- if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
- return "ws_connection_limit_reached", true
- }
- if strings.Contains(msg, "previous_response_not_found") ||
- (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
- return "previous_response_not_found", true
- }
- if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
- return "upstream_error_event", true
- }
- return "event_error", false
-}
-
-func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
- if len(message) == 0 {
- return "event_error", false
- }
- return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
-}
-
-func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
- code := strings.ToLower(strings.TrimSpace(codeRaw))
- errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
- switch {
- case strings.Contains(errType, "invalid_request"),
- strings.Contains(code, "invalid_request"),
- strings.Contains(code, "bad_request"),
- code == "previous_response_not_found":
- return http.StatusBadRequest
- case strings.Contains(errType, "authentication"),
- strings.Contains(code, "invalid_api_key"),
- strings.Contains(code, "unauthorized"):
- return http.StatusUnauthorized
- case strings.Contains(errType, "permission"),
- strings.Contains(code, "forbidden"):
- return http.StatusForbidden
- case strings.Contains(errType, "rate_limit"),
- strings.Contains(code, "rate_limit"),
- strings.Contains(code, "insufficient_quota"):
- return http.StatusTooManyRequests
- default:
- return http.StatusBadGateway
- }
-}
-
-func openAIWSErrorHTTPStatus(message []byte) int {
- if len(message) == 0 {
- return http.StatusBadGateway
- }
- codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
- return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
-}
-
-func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
- if s == nil || s.cfg == nil {
- return 30 * time.Second
- }
- seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
- if seconds <= 0 {
- return 0
- }
- return time.Duration(seconds) * time.Second
-}
-
-func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
- if s == nil || accountID <= 0 {
- return false
- }
- cooldown := s.openAIWSFallbackCooldown()
- if cooldown <= 0 {
- return false
- }
- rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
- if !ok || rawUntil == nil {
- return false
- }
- until, ok := rawUntil.(time.Time)
- if !ok || until.IsZero() {
- s.openaiWSFallbackUntil.Delete(accountID)
- return false
- }
- if time.Now().Before(until) {
- return true
- }
- s.openaiWSFallbackUntil.Delete(accountID)
- return false
-}
-
-func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
- if s == nil || accountID <= 0 {
- return
- }
- cooldown := s.openAIWSFallbackCooldown()
- if cooldown <= 0 {
- return
- }
- s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
-}
-
-func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
- if s == nil || accountID <= 0 {
- return
- }
- s.openaiWSFallbackUntil.Delete(accountID)
-}
diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go
index bd03ab5a6..b1d9ed02e 100644
--- a/backend/internal/service/openai_ws_forwarder_benchmark_test.go
+++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go
@@ -1,8 +1,10 @@
package service
import (
+ "context"
"fmt"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
@@ -125,3 +127,142 @@ func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) {
benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model")
}
}
+
+// --- Optimization benchmarks ---
+
+var benchmarkOpenAIWSConnSink openAIWSClientConn
+
+func BenchmarkTouchLease_Full(b *testing.B) {
+ ctx := &openAIWSIngressContext{}
+ ttl := 10 * time.Minute
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ctx.touchLease(time.Now(), ttl)
+ }
+}
+
+func BenchmarkMaybeTouchLease_Throttled(b *testing.B) {
+ ctx := &openAIWSIngressContext{}
+ ttl := 10 * time.Minute
+ ctx.touchLease(time.Now(), ttl) // seed the initial touch
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ ctx.maybeTouchLease(ttl)
+ }
+}
+
+func BenchmarkActiveConn_CachedPath(b *testing.B) {
+ conn := &benchmarkOpenAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn}
+ lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"}
+ // Prime the cache
+ _, _ = lease.activeConn()
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSConnSink, _ = lease.activeConn()
+ }
+}
+
+func BenchmarkActiveConn_MutexPath(b *testing.B) {
+ conn := &benchmarkOpenAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{ownerID: "bench_owner", upstream: conn}
+ lease := &openAIWSIngressContextLease{context: ctx, ownerID: "bench_owner"}
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ lease.cachedConn = nil // force mutex path each iteration
+ benchmarkOpenAIWSConnSink, _ = lease.activeConn()
+ }
+}
+
+func BenchmarkParseOpenAIWSEventType_Lightweight(b *testing.B) {
+ event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ et, rid := parseOpenAIWSEventType(event)
+ benchmarkOpenAIWSStringSink = et
+ benchmarkOpenAIWSStringSink = rid
+ }
+}
+
+func BenchmarkParseOpenAIWSEventEnvelope_Full(b *testing.B) {
+ event := []byte(`{"type":"response.output_text.delta","delta":"hello","response":{"id":"resp_1","model":"gpt-5.1"}}`)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ et, rid, resp := parseOpenAIWSEventEnvelope(event)
+ benchmarkOpenAIWSStringSink = et
+ benchmarkOpenAIWSStringSink = rid
+ benchmarkOpenAIWSBoolSink = resp.Exists()
+ }
+}
+
+func BenchmarkSessionTurnStateKey_Strconv(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = openAIWSSessionTurnStateKey(int64(i%1000+1), "session_hash_bench")
+ }
+}
+
+func BenchmarkResponseAccountCacheKey_XXHash(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = openAIWSResponseAccountCacheKey(fmt.Sprintf("resp_bench_%d", i%1000))
+ }
+}
+
+func BenchmarkIsOpenAIWSTerminalEvent(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSBoolSink = isOpenAIWSTerminalEvent("response.completed")
+ }
+}
+
+func BenchmarkIsOpenAIWSTokenEvent(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSBoolSink = isOpenAIWSTokenEvent("response.output_text.delta")
+ }
+}
+
+func BenchmarkStateStore_ShardedBindGet(b *testing.B) {
+ store := NewOpenAIWSStateStore(nil).(*defaultOpenAIWSStateStore)
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ key := fmt.Sprintf("resp_%d", i%1000)
+ store.BindResponseConn(key, "conn_bench", time.Minute)
+ benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = store.GetResponseConn(key)
+ }
+}
+
+func BenchmarkDeriveOpenAISessionHash(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = deriveOpenAISessionHash("session-id-benchmark-value")
+ }
+}
+
+func BenchmarkDeriveOpenAILegacySessionHash(b *testing.B) {
+ b.ReportAllocs()
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ benchmarkOpenAIWSStringSink = deriveOpenAILegacySessionHash("session-id-benchmark-value")
+ }
+}
+
+type benchmarkOpenAIWSNoopConn struct{}
+
+func (c *benchmarkOpenAIWSNoopConn) WriteJSON(_ context.Context, _ any) error { return nil }
+func (c *benchmarkOpenAIWSNoopConn) ReadMessage(_ context.Context) ([]byte, error) { return nil, nil }
+func (c *benchmarkOpenAIWSNoopConn) Ping(_ context.Context) error { return nil }
+func (c *benchmarkOpenAIWSNoopConn) Close() error { return nil }
diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
index 761676038..7b77641f3 100644
--- a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
+++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
)
func TestParseOpenAIWSEventEnvelope(t *testing.T) {
@@ -31,6 +32,15 @@ func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) {
require.Equal(t, 3, usage.CacheReadInputTokens)
}
+func TestOpenAIWSEventShouldParseUsage_TerminalEvents(t *testing.T) {
+ require.True(t, openAIWSEventShouldParseUsage("response.completed"))
+ require.True(t, openAIWSEventShouldParseUsage("response.done"))
+ require.True(t, openAIWSEventShouldParseUsage("response.failed"))
+ // After removing TrimSpace, callers must provide pre-trimmed input.
+ require.False(t, openAIWSEventShouldParseUsage(" response.done "))
+ require.False(t, openAIWSEventShouldParseUsage("response.in_progress"))
+}
+
func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) {
message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)
codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
@@ -58,6 +68,55 @@ func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) {
require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`)))
}
+func TestOpenAIWSExtractPendingFunctionCallIDsFromEvent(t *testing.T) {
+ callIDs := openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{
+ "type":"response.output_item.added",
+ "response":{"id":"resp_1"},
+ "item":{"type":"function_call","call_id":"call_a"}
+ }`))
+ require.Equal(t, []string{"call_a"}, callIDs)
+
+ callIDs = openAIWSExtractPendingFunctionCallIDsFromEvent([]byte(`{
+ "type":"response.completed",
+ "response":{
+ "id":"resp_2",
+ "output":[
+ {"type":"function_call","call_id":"call_b"},
+ {"type":"message","content":[{"type":"output_text","text":"ok"}]},
+ {"type":"function_call","call_id":"call_c"}
+ ]
+ }
+ }`))
+ require.Equal(t, []string{"call_b", "call_c"}, callIDs)
+}
+
+func TestOpenAIWSExtractFunctionCallOutputCallIDsFromPayload(t *testing.T) {
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload([]byte(`{
+ "input":[
+ {"type":"input_text","text":"hi"},
+ {"type":"function_call_output","call_id":"call_2","output":"ok"},
+ {"type":"function_call_output","call_id":"call_1","output":"ok"},
+ {"type":"function_call_output","call_id":"call_2","output":"dup"}
+ ]
+ }`))
+ require.Equal(t, []string{"call_1", "call_2"}, callIDs)
+}
+
+func TestOpenAIWSInjectFunctionCallOutputItems(t *testing.T) {
+ updatedPayload, injected, err := openAIWSInjectFunctionCallOutputItems(
+ []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`),
+ []string{"call_1", "call_2", "call_1"},
+ openAIWSAutoAbortedToolOutputValue,
+ )
+ require.NoError(t, err)
+ require.Equal(t, 2, injected)
+ require.Equal(t, "input_text", gjson.GetBytes(updatedPayload, "input.0.type").String())
+ require.Equal(t, "function_call_output", gjson.GetBytes(updatedPayload, "input.1.type").String())
+ require.Equal(t, "call_1", gjson.GetBytes(updatedPayload, "input.1.call_id").String())
+ require.Equal(t, openAIWSAutoAbortedToolOutputValue, gjson.GetBytes(updatedPayload, "input.1.output").String())
+ require.Equal(t, "call_2", gjson.GetBytes(updatedPayload, "input.2.call_id").String())
+}
+
func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) {
noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`)
require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model")))
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go
new file mode 100644
index 000000000..35cde6d27
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_ingress_policy_test.go
@@ -0,0 +1,154 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func runIngressProxyWithFirstPayload(
+ t *testing.T,
+ svc *OpenAIGatewayService,
+ account *Account,
+ firstPayload string,
+) error {
+ t.Helper()
+
+ serverErrCh := make(chan error, 1)
+ wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
+ CompressionMode: coderws.CompressionContextTakeover,
+ })
+ if err != nil {
+ serverErrCh <- err
+ return
+ }
+ defer func() {
+ _ = conn.CloseNow()
+ }()
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := r.Clone(r.Context())
+ req.Header = req.Header.Clone()
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
+ msgType, message, readErr := conn.Read(readCtx)
+ cancel()
+ if readErr != nil {
+ serverErrCh <- readErr
+ return
+ }
+ if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
+ serverErrCh <- errors.New("unsupported websocket client message type")
+ return
+ }
+
+ serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", message, nil)
+ }))
+ defer wsServer.Close()
+
+ dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
+ clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
+ cancelDial()
+ require.NoError(t, err)
+ defer func() {
+ _ = clientConn.CloseNow()
+ }()
+
+ writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
+ err = clientConn.Write(writeCtx, coderws.MessageText, []byte(firstPayload))
+ cancelWrite()
+ require.NoError(t, err)
+
+ select {
+ case serverErr := <-serverErrCh:
+ return serverErr
+ case <-time.After(5 * time.Second):
+ t.Fatal("等待 ingress websocket 结束超时")
+ return nil
+ }
+}
+
+func buildIngressPolicyTestConfig() *config.Config {
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeCtxPool
+ return cfg
+}
+
+func buildIngressPolicyTestService(cfg *config.Config) *OpenAIGatewayService {
+ return &OpenAIGatewayService{
+ cfg: cfg,
+ httpUpstream: &httpUpstreamRecorder{},
+ cache: &stubGatewayCache{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ }
+}
+
+func buildIngressPolicyTestAccount(extra map[string]any) *Account {
+ return &Account{
+ ID: 442,
+ Name: "openai-ingress-policy",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ Extra: extra,
+ }
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason())
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeRouterDisabledReturnsPolicyViolation(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = false
+ svc := buildIngressPolicyTestService(cfg)
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "responses_websockets_v2_enabled": true,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
+ require.Equal(t, "websocket mode requires mode_router_v2 with ctx_pool", closeErr.Reason())
+}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
deleted file mode 100644
index 5a3c12c39..000000000
--- a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go
+++ /dev/null
@@ -1,2483 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "io"
- "net/http"
- "net/http/httptest"
- "strings"
- "sync"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- coderws "github.com/coder/websocket"
- "github.com/gin-gonic/gin"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
-)
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 114,
- Name: "openai-ingress-session-lease",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- turnWSModeCh := make(chan bool, 2)
- hooks := &OpenAIWSIngressHooks{
- AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
- if turnErr == nil && result != nil {
- turnWSModeCh <- result.OpenAIWSMode
- }
- },
- }
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurnEvent := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String())
- require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`)
- secondTurnEvent := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String())
- require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String())
- require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式")
- require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式")
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
-
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- metrics := svc.SnapshotOpenAIWSPoolMetrics()
- require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease")
- require.Equal(t, 1, captureDialer.DialCount(), "同一 ingress 会话应保持同一上游连接")
- require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- upstreamConn1 := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- upstreamConn2 := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{upstreamConn1, upstreamConn2},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 441,
- Name: "openai-ingress-dedicated",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
- },
- }
-
- serverErrCh := make(chan error, 2)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- runSingleTurnSession := func(expectedResponseID string) {
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
- cancelWrite()
- require.NoError(t, err)
-
- readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second)
- msgType, event, readErr := clientConn.Read(readCtx)
- cancelRead()
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
-
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
- }
-
- runSingleTurnSession("resp_dedicated_1")
- runSingleTurnSession("resp_dedicated_2")
-
- require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: newOpenAIWSConnPool(cfg),
- }
-
- account := &Account{
- ID: 442,
- Name: "openai-ingress-off",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`))
- cancelWrite()
- require.NoError(t, err)
-
- select {
- case serverErr := <-serverErrCh:
- var closeErr *OpenAIWSClientCloseError
- require.ErrorAs(t, serverErr, &closeErr)
- require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
- require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason())
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 140,
- Name: "openai-ingress-prev-preflight-rewrite",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create")
- require.Len(t, captureConn.writes, 2)
- secondWrite := requestToJSONString(captureConn.writes[1])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create")
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文")
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) {
- gin.SetMode(gin.TestMode)
- prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
- openAIWSIngressPreflightPingIdle = 0
- defer func() {
- openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
- }()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSPreflightFailConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 142,
- Name: "openai-ingress-prev-strict-drop-before-ping",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 严格降级后预检换连超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连")
- require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮")
- require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1)
- secondWrite := requestToJSONString(secondWrites[0])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id")
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()))
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 143,
- Name: "openai-ingress-store-enabled-skip-strict",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 store=true 场景 websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 141,
- Name: "openai-ingress-prev-preflight-skip-fco",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 143,
- Name: "openai-ingress-fco-auto-prev",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_auto_prev_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 144,
- Name: "openai-ingress-fco-auto-prev-skip",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String())
- require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点")
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 1, captureDialer.DialCount())
- require.Len(t, captureConn.writes, 2)
- require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) {
- gin.SetMode(gin.TestMode)
- prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
- openAIWSIngressPreflightPingIdle = 0
- defer func() {
- openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
- }()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSPreflightFailConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 116,
- Name: "openai-ingress-preflight-ping",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
- require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连")
- require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn")
- require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) {
- gin.SetMode(gin.TestMode)
- prevPreflightPingIdle := openAIWSIngressPreflightPingIdle
- openAIWSIngressPreflightPingIdle = 0
- defer func() {
- openAIWSIngressPreflightPingIdle = prevPreflightPingIdle
- }()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSPreflightFailConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 121,
- Name: "openai-ingress-preflight-ping-strict-affinity",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放")
- require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮")
- require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping")
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1)
- secondWrite := requestToJSONString(secondWrites[0])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id")
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文")
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSWriteFailAfterFirstTurnConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 117,
- Name: "openai-ingress-write-retry",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
- var hooksMu sync.Mutex
- beforeTurnCalls := make(map[int]int)
- afterTurnCalls := make(map[int]int)
- hooks := &OpenAIWSIngressHooks{
- BeforeTurn: func(turn int) error {
- hooksMu.Lock()
- beforeTurnCalls[turn]++
- hooksMu.Unlock()
- return nil
- },
- AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) {
- hooksMu.Lock()
- afterTurnCalls[turn]++
- hooksMu.Unlock()
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
- require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连")
- hooksMu.Lock()
- beforeTurn1 := beforeTurnCalls[1]
- beforeTurn2 := beforeTurnCalls[2]
- afterTurn1 := afterTurnCalls[1]
- afterTurn2 := afterTurnCalls[2]
- hooksMu.Unlock()
- require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次")
- require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn")
- require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次")
- require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 118,
- Name: "openai-ingress-prev-recovery",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`)
- secondTurn := readMessage()
- require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String())
- require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试")
-
- firstConn.mu.Lock()
- firstWrites := append([]map[string]any(nil), firstConn.writes...)
- firstConn.mu.Unlock()
- require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求")
- require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id")
-
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求")
- require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 122,
- Name: "openai-ingress-prev-strict-layer2",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试")
-
- firstConn.mu.Lock()
- firstWrites := append([]map[string]any(nil), firstConn.writes...)
- firstConn.mu.Unlock()
- require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求")
- require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())
-
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次")
- secondWrite := requestToJSONString(secondWrites[0])
- require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id")
- require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志")
- require.False(t, gjson.Get(secondWrite, "store").Bool())
- require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文")
- require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String())
- require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String())
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- firstConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`),
- },
- }
- secondConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- dialer := &openAIWSQueueDialer{
- conns: []openAIWSClientConn{firstConn, secondConn},
- }
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(dialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 120,
- Name: "openai-ingress-prev-recovery-once",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeMessage := func(payload string) {
- writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload)))
- }
- readMessage := func() []byte {
- readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
- defer cancel()
- msgType, message, readErr := clientConn.Read(readCtx)
- require.NoError(t, readErr)
- require.Equal(t, coderws.MessageText, msgType)
- return message
- }
-
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`)
- firstTurn := readMessage()
- require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String())
-
- // duplicate previous_response_id: 恢复重试时应删除所有重复键,避免再次 previous_response_not_found。
- writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`)
- secondTurn := readMessage()
- require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String())
-
- require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done"))
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr)
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次")
-
- firstConn.mu.Lock()
- firstWrites := append([]map[string]any(nil), firstConn.writes...)
- firstConn.mu.Unlock()
- require.Len(t, firstWrites, 2)
- require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists())
-
- secondConn.mu.Lock()
- secondWrites := append([]map[string]any(nil), secondConn.writes...)
- secondConn.mu.Unlock()
- require.Len(t, secondWrites, 1)
- require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id")
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 119,
- Name: "openai-ingress-prev-validation",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
- defer func() {
- _ = clientConn.CloseNow()
- }()
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`))
- cancelWrite()
- require.NoError(t, err)
-
- select {
- case serverErr := <-serverErrCh:
- require.Error(t, serverErr)
- var closeErr *OpenAIWSClientCloseError
- require.ErrorAs(t, serverErr, &closeErr)
- require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode())
- require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id")
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-}
-
-type openAIWSQueueDialer struct {
- mu sync.Mutex
- conns []openAIWSClientConn
- dialCount int
-}
-
-func (d *openAIWSQueueDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- d.mu.Lock()
- defer d.mu.Unlock()
- d.dialCount++
- if len(d.conns) == 0 {
- return nil, 503, nil, errors.New("no test conn")
- }
- conn := d.conns[0]
- if len(d.conns) > 1 {
- d.conns = d.conns[1:]
- }
- return conn, 0, nil, nil
-}
-
-func (d *openAIWSQueueDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-type openAIWSPreflightFailConn struct {
- mu sync.Mutex
- events [][]byte
- pingFails bool
- writeCount int
- pingCount int
-}
-
-func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error {
- c.mu.Lock()
- c.writeCount++
- c.mu.Unlock()
- return nil
-}
-
-func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if len(c.events) == 0 {
- return nil, io.EOF
- }
- event := c.events[0]
- c.events = c.events[1:]
- if len(c.events) == 0 {
- c.pingFails = true
- }
- return event, nil
-}
-
-func (c *openAIWSPreflightFailConn) Ping(context.Context) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.pingCount++
- if c.pingFails {
- return errors.New("preflight ping failed")
- }
- return nil
-}
-
-func (c *openAIWSPreflightFailConn) Close() error {
- return nil
-}
-
-func (c *openAIWSPreflightFailConn) WriteCount() int {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.writeCount
-}
-
-func (c *openAIWSPreflightFailConn) PingCount() int {
- c.mu.Lock()
- defer c.mu.Unlock()
- return c.pingCount
-}
-
-type openAIWSWriteFailAfterFirstTurnConn struct {
- mu sync.Mutex
- events [][]byte
- failOnWrite bool
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.failOnWrite {
- return errors.New("write failed on stale conn")
- }
- return nil
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) {
- c.mu.Lock()
- defer c.mu.Unlock()
- if len(c.events) == 0 {
- return nil, io.EOF
- }
- event := c.events[0]
- c.events = c.events[1:]
- if len(c.events) == 0 {
- c.failOnWrite = true
- }
- return event, nil
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error {
- return nil
-}
-
-func (c *openAIWSWriteFailAfterFirstTurnConn) Close() error {
- return nil
-}
-
-func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- // 多个上游事件:前几个为非 terminal 事件,最后一个为 terminal。
- // 第一个事件延迟 250ms 让客户端 RST 有时间传播,使 writeClientMessage 可靠失败。
- captureConn := &openAIWSCaptureConn{
- readDelays: []time.Duration{250 * time.Millisecond, 0, 0},
- events: [][]byte{
- []byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`),
- []byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 115,
- Name: "openai-ingress-client-disconnect",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "model_mapping": map[string]any{
- "custom-original-model": "gpt-5.1",
- },
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- serverErrCh := make(chan error, 1)
- resultCh := make(chan *OpenAIForwardResult, 1)
- hooks := &OpenAIWSIngressHooks{
- AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) {
- if turnErr == nil && result != nil {
- resultCh <- result
- }
- },
- }
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{
- CompressionMode: coderws.CompressionContextTakeover,
- })
- if err != nil {
- serverErrCh <- err
- return
- }
- defer func() {
- _ = conn.CloseNow()
- }()
-
- rec := httptest.NewRecorder()
- ginCtx, _ := gin.CreateTestContext(rec)
- req := r.Clone(r.Context())
- req.Header = req.Header.Clone()
- req.Header.Set("User-Agent", "unit-test-agent/1.0")
- ginCtx.Request = req
-
- readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second)
- msgType, firstMessage, readErr := conn.Read(readCtx)
- cancel()
- if readErr != nil {
- serverErrCh <- readErr
- return
- }
- if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
- serverErrCh <- errors.New("unsupported websocket client message type")
- return
- }
-
- serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks)
- }))
- defer wsServer.Close()
-
- dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second)
- clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil)
- cancelDial()
- require.NoError(t, err)
-
- writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second)
- err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`))
- cancelWrite()
- require.NoError(t, err)
- // 立即关闭客户端,模拟客户端在 relay 期间断连。
- require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连")
-
- select {
- case serverErr := <-serverErrCh:
- require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束")
- case <-time.After(5 * time.Second):
- t.Fatal("等待 ingress websocket 结束超时")
- }
-
- select {
- case result := <-resultCh:
- require.Equal(t, "resp_ingress_disconnect", result.RequestID)
- require.Equal(t, 2, result.Usage.InputTokens)
- require.Equal(t, 1, result.Usage.OutputTokens)
- case <-time.After(2 * time.Second):
- t.Fatal("未收到断连后的 turn 结果回调")
- }
-}
diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go
index ff35cb01d..40514c57a 100644
--- a/backend/internal/service/openai_ws_forwarder_ingress_test.go
+++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go
@@ -6,10 +6,13 @@ import (
"errors"
"io"
"net"
+ "strings"
"testing"
+ "time"
"github.com/Wei-Shaw/sub2api/internal/config"
coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
@@ -34,6 +37,8 @@ func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true},
{name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true},
{name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true},
+ {name: "blank_message", err: errors.New(" "), want: false},
+ {name: "unmatched_message", err: errors.New("tls handshake timeout"), want: false},
}
for _, tt := range tests {
@@ -45,6 +50,413 @@ func TestIsOpenAIWSClientDisconnectError(t *testing.T) {
}
}
+func TestOpenAIWSIngressFallbackSessionSeedFromContext(t *testing.T) {
+ t.Parallel()
+
+ require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(nil))
+
+ gin.SetMode(gin.TestMode)
+ c, _ := gin.CreateTestContext(nil)
+ require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c))
+
+ c.Set("api_key", "not_api_key")
+ require.Empty(t, openAIWSIngressFallbackSessionSeedFromContext(c))
+
+ groupID := int64(99)
+ c.Set("api_key", &APIKey{
+ ID: 101,
+ GroupID: &groupID,
+ User: &User{ID: 202},
+ })
+ require.Equal(t, "openai_ws_ingress:99:202:101", openAIWSIngressFallbackSessionSeedFromContext(c))
+
+ c.Set("api_key", &APIKey{
+ ID: 303,
+ User: nil,
+ })
+ require.Equal(t, "openai_ws_ingress:0:0:303", openAIWSIngressFallbackSessionSeedFromContext(c))
+}
+
+func TestClassifyOpenAIWSIngressTurnAbortReason(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ wantReason openAIWSIngressTurnAbortReason
+ wantExpected bool
+ }{
+ {
+ name: "nil",
+ err: nil,
+ wantReason: openAIWSIngressTurnAbortReasonUnknown,
+ wantExpected: false,
+ },
+ {
+ name: "context canceled",
+ err: context.Canceled,
+ wantReason: openAIWSIngressTurnAbortReasonContextCanceled,
+ wantExpected: true,
+ },
+ {
+ name: "context deadline",
+ err: context.DeadlineExceeded,
+ wantReason: openAIWSIngressTurnAbortReasonContextDeadline,
+ wantExpected: false,
+ },
+ {
+ name: "client close",
+ err: coderws.CloseError{Code: coderws.StatusNormalClosure},
+ wantReason: openAIWSIngressTurnAbortReasonClientClosed,
+ wantExpected: true,
+ },
+ {
+ name: "client close by eof",
+ err: io.EOF,
+ wantReason: openAIWSIngressTurnAbortReasonClientClosed,
+ wantExpected: true,
+ },
+ {
+ name: "previous response not found",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonPreviousResponse,
+ wantExpected: true,
+ },
+ {
+ name: "tool output not found",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("no tool output found"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonToolOutput,
+ wantExpected: true,
+ },
+ {
+ name: "upstream error event",
+ err: wrapOpenAIWSIngressTurnError(
+ "upstream_error_event",
+ errors.New("upstream error event"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamError,
+ wantExpected: true,
+ },
+ {
+ name: "write upstream",
+ err: wrapOpenAIWSIngressTurnError(
+ "write_upstream",
+ errors.New("write upstream fail"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonWriteUpstream,
+ wantExpected: false,
+ },
+ {
+ name: "read upstream",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ errors.New("read upstream fail"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonReadUpstream,
+ wantExpected: false,
+ },
+ {
+ name: "write client",
+ err: wrapOpenAIWSIngressTurnError(
+ "write_client",
+ errors.New("write client fail"),
+ true,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonWriteClient,
+ wantExpected: false,
+ },
+ {
+ name: "unknown turn stage",
+ err: wrapOpenAIWSIngressTurnError(
+ "some_unknown_stage",
+ errors.New("unknown stage fail"),
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUnknown,
+ wantExpected: false,
+ },
+ {
+ name: "continuation unavailable close",
+ err: NewOpenAIWSClientCloseError(
+ coderws.StatusPolicyViolation,
+ openAIWSContinuationUnavailableReason,
+ nil,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonContinuationUnavailable,
+ wantExpected: true,
+ },
+ {
+ name: "upstream restart 1012",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"},
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ wantExpected: true,
+ },
+ {
+ name: "upstream try again later 1013",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ coderws.CloseError{Code: coderws.StatusTryAgainLater, Reason: "try again later"},
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ wantExpected: true,
+ },
+ {
+ name: "upstream restart 1012 with wroteDownstream",
+ err: wrapOpenAIWSIngressTurnError(
+ "read_upstream",
+ coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"},
+ true,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ wantExpected: true,
+ },
+ {
+ name: "1012 on non-read_upstream stage should not match",
+ err: wrapOpenAIWSIngressTurnError(
+ "write_upstream",
+ coderws.CloseError{Code: coderws.StatusServiceRestart, Reason: "service restart"},
+ false,
+ ),
+ wantReason: openAIWSIngressTurnAbortReasonWriteUpstream,
+ wantExpected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(tt.err)
+ require.Equal(t, tt.wantReason, reason)
+ require.Equal(t, tt.wantExpected, expected)
+ })
+ }
+}
+
+func TestClassifyOpenAIWSIngressTurnAbortReason_ClientDisconnectedDrainTimeout(t *testing.T) {
+ t.Parallel()
+
+ err := wrapOpenAIWSIngressTurnError(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(2*time.Second),
+ true,
+ )
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonContextCanceled, reason)
+ require.True(t, expected)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionCloseGracefully, openAIWSIngressTurnAbortDispositionForReason(reason))
+}
+
+func TestOpenAIWSIngressPumpClosedTurnError_ClientDisconnected(t *testing.T) {
+ t.Parallel()
+
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_partial",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ },
+ }
+ err := openAIWSIngressPumpClosedTurnError(true, true, partial)
+ require.Error(t, err)
+ require.ErrorIs(t, err, context.Canceled)
+
+ var turnErr *openAIWSIngressTurnError
+ require.ErrorAs(t, err, &turnErr)
+ require.Equal(t, "client_disconnected_drain_timeout", turnErr.stage)
+ require.True(t, turnErr.wroteDownstream)
+ require.NotNil(t, turnErr.partialResult)
+ require.Equal(t, partial.RequestID, turnErr.partialResult.RequestID)
+}
+
+func TestOpenAIWSIngressPumpClosedTurnError_ReadUpstream(t *testing.T) {
+ t.Parallel()
+
+ err := openAIWSIngressPumpClosedTurnError(false, false, nil)
+ require.Error(t, err)
+
+ var turnErr *openAIWSIngressTurnError
+ require.ErrorAs(t, err, &turnErr)
+ require.Equal(t, "read_upstream", turnErr.stage)
+ require.False(t, turnErr.wroteDownstream)
+ require.Nil(t, turnErr.partialResult)
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonReadUpstream, reason)
+ require.False(t, expected)
+}
+
+func TestOpenAIWSIngressPumpClosedTurnError_ClonesPartialResult(t *testing.T) {
+ t.Parallel()
+
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_original",
+ PendingFunctionCallIDs: []string{"call_a"},
+ }
+ err := openAIWSIngressPumpClosedTurnError(true, true, partial)
+ require.Error(t, err)
+
+ partial.RequestID = "resp_mutated"
+ partial.PendingFunctionCallIDs[0] = "call_b"
+
+ var turnErr *openAIWSIngressTurnError
+ require.ErrorAs(t, err, &turnErr)
+ require.NotNil(t, turnErr.partialResult)
+ require.Equal(t, "resp_original", turnErr.partialResult.RequestID)
+ require.Equal(t, []string{"call_a"}, turnErr.partialResult.PendingFunctionCallIDs)
+}
+
+func TestOpenAIWSIngressClientDisconnectedDrainTimeoutError_DefaultTimeout(t *testing.T) {
+ t.Parallel()
+
+ err := openAIWSIngressClientDisconnectedDrainTimeoutError(0)
+ require.Error(t, err)
+ require.Contains(t, err.Error(), openAIWSIngressClientDisconnectDrainTimeout.String())
+ require.ErrorIs(t, err, context.Canceled)
+}
+
+func TestOpenAIWSIngressResolveDrainReadTimeout(t *testing.T) {
+ t.Parallel()
+
+ now := time.Now()
+ tests := []struct {
+ name string
+ base time.Duration
+ deadline time.Time
+ want time.Duration
+ wantExpire bool
+ }{
+ {
+ name: "no_deadline_uses_base",
+ base: 15 * time.Second,
+ deadline: time.Time{},
+ want: 15 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "remaining_shorter_than_base",
+ base: 10 * time.Second,
+ deadline: now.Add(3 * time.Second),
+ want: 3 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "base_shorter_than_remaining",
+ base: 2 * time.Second,
+ deadline: now.Add(8 * time.Second),
+ want: 2 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "base_zero_uses_remaining",
+ base: 0,
+ deadline: now.Add(5 * time.Second),
+ want: 5 * time.Second,
+ wantExpire: false,
+ },
+ {
+ name: "expired_deadline",
+ base: 10 * time.Second,
+ deadline: now.Add(-time.Millisecond),
+ want: 0,
+ wantExpire: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got, expired := openAIWSIngressResolveDrainReadTimeout(tt.base, tt.deadline, now)
+ require.Equal(t, tt.want, got)
+ require.Equal(t, tt.wantExpire, expired)
+ })
+ }
+}
+
+func TestOpenAIWSIngressTurnAbortDispositionForReason(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ in openAIWSIngressTurnAbortReason
+ want openAIWSIngressTurnAbortDisposition
+ }{
+ {
+ name: "continue turn on previous response mismatch",
+ in: openAIWSIngressTurnAbortReasonPreviousResponse,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ {
+ name: "continue turn on tool output mismatch",
+ in: openAIWSIngressTurnAbortReasonToolOutput,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ {
+ name: "continue turn on upstream error event",
+ in: openAIWSIngressTurnAbortReasonUpstreamError,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ {
+ name: "close gracefully on context canceled",
+ in: openAIWSIngressTurnAbortReasonContextCanceled,
+ want: openAIWSIngressTurnAbortDispositionCloseGracefully,
+ },
+ {
+ name: "close gracefully on client closed",
+ in: openAIWSIngressTurnAbortReasonClientClosed,
+ want: openAIWSIngressTurnAbortDispositionCloseGracefully,
+ },
+ {
+ name: "default fail request on unknown reason",
+ in: openAIWSIngressTurnAbortReasonUnknown,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "default fail request on write upstream reason",
+ in: openAIWSIngressTurnAbortReasonWriteUpstream,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "default fail request on read upstream reason",
+ in: openAIWSIngressTurnAbortReasonReadUpstream,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "default fail request on write client reason",
+ in: openAIWSIngressTurnAbortReasonWriteClient,
+ want: openAIWSIngressTurnAbortDispositionFailRequest,
+ },
+ {
+ name: "continue turn on upstream restart",
+ in: openAIWSIngressTurnAbortReasonUpstreamRestart,
+ want: openAIWSIngressTurnAbortDispositionContinueTurn,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, openAIWSIngressTurnAbortDispositionForReason(tt.in))
+ })
+ }
+}
+
func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) {
t.Parallel()
@@ -254,12 +666,12 @@ func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) {
want: false,
},
{
- name: "skip_on_first_turn",
+ name: "infer_on_first_turn_when_expected_previous_exists",
storeDisabled: true,
turn: 1,
hasFunctionCallOutput: true,
expectedPrevious: "resp_1",
- want: false,
+ want: true,
},
{
name: "skip_without_function_call_output",
@@ -529,7 +941,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
}`)
t.Run("strict_incremental_keep", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.True(t, keep)
require.Equal(t, "strict_incremental_ok", reason)
@@ -537,28 +949,28 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
t.Run("missing_previous_response_id", func(t *testing.T) {
payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "missing_previous_response_id", reason)
})
t.Run("missing_last_turn_response_id", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "missing_last_turn_response_id", reason)
})
t.Run("previous_response_id_mismatch", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "previous_response_id_mismatch", reason)
})
t.Run("missing_previous_turn_payload", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "missing_previous_turn_payload", reason)
@@ -573,7 +985,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
"previous_response_id":"resp_turn_1",
"input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]
}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.False(t, keep)
require.Equal(t, "non_input_changed", reason)
@@ -588,7 +1000,7 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
"previous_response_id":"resp_turn_1",
"input":[{"type":"input_text","text":"different"}]
}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false, nil, nil)
require.NoError(t, err)
require.True(t, keep)
require.Equal(t, "strict_incremental_ok", reason)
@@ -602,21 +1014,63 @@ func TestShouldKeepIngressPreviousResponseID(t *testing.T) {
"previous_response_id":"resp_external",
"input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
}`)
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true, nil, []string{"call_1"})
require.NoError(t, err)
require.True(t, keep)
require.Equal(t, "has_function_call_output", reason)
})
+ t.Run("function_call_output_pending_call_id_match_keeps_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "previous_response_id":"resp_turn_1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(
+ previousPayload,
+ payload,
+ "resp_turn_1",
+ true,
+ []string{"call_1"},
+ []string{"call_1"},
+ )
+ require.NoError(t, err)
+ require.True(t, keep)
+ require.Equal(t, "function_call_output_call_id_match", reason)
+ })
+
+ t.Run("function_call_output_pending_call_id_mismatch_drops_previous_response_id", func(t *testing.T) {
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "store":false,
+ "previous_response_id":"resp_turn_1",
+ "input":[{"type":"function_call_output","call_id":"call_other","output":"ok"}]
+ }`)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(
+ previousPayload,
+ payload,
+ "resp_turn_1",
+ true,
+ []string{"call_1"},
+ []string{"call_other"},
+ )
+ require.NoError(t, err)
+ require.False(t, keep)
+ require.Equal(t, "function_call_output_call_id_mismatch", reason)
+ })
+
t.Run("non_input_compare_error", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false, nil, nil)
require.Error(t, err)
require.False(t, keep)
require.Equal(t, "non_input_compare_error", reason)
})
t.Run("current_payload_compare_error", func(t *testing.T) {
- keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false)
+ keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false, nil, nil)
require.Error(t, err)
require.False(t, keep)
require.Equal(t, "non_input_compare_error", reason)
@@ -670,6 +1124,171 @@ func TestBuildOpenAIWSReplayInputSequence(t *testing.T) {
require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String())
require.Equal(t, "world", gjson.GetBytes(items[1], "text").String())
})
+
+ t.Run("replay_input_limited_by_bytes_keeps_newest_items", func(t *testing.T) {
+ makeItem := func(text string) json.RawMessage {
+ raw, err := json.Marshal(map[string]any{
+ "type": "input_text",
+ "text": text,
+ })
+ require.NoError(t, err)
+ return json.RawMessage(raw)
+ }
+ largeA := strings.Repeat("a", openAIWSIngressReplayInputMaxBytes/2)
+ largeB := strings.Repeat("b", openAIWSIngressReplayInputMaxBytes/2)
+ largeC := strings.Repeat("c", openAIWSIngressReplayInputMaxBytes/2)
+ previousLarge := []json.RawMessage{
+ makeItem(largeA),
+ makeItem(largeB),
+ }
+ currentPayload, err := json.Marshal(map[string]any{
+ "previous_response_id": "resp_1",
+ "input": []map[string]any{
+ {"type": "input_text", "text": largeC},
+ },
+ })
+ require.NoError(t, err)
+
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ previousLarge,
+ true,
+ currentPayload,
+ true,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.GreaterOrEqual(t, len(items), 1)
+ require.Equal(t, largeC, gjson.GetBytes(items[len(items)-1], "text").String(), "latest item should always be preserved")
+ require.Less(t, len(items), 3, "oversized replay input should be truncated")
+ })
+
+ t.Run("replay_input_limited_by_bytes_still_keeps_single_oversized_latest_item", func(t *testing.T) {
+ tooLargeText := strings.Repeat("z", openAIWSIngressReplayInputMaxBytes+1024)
+ currentPayload, err := json.Marshal(map[string]any{
+ "input": []map[string]any{
+ {"type": "input_text", "text": tooLargeText},
+ },
+ })
+ require.NoError(t, err)
+
+ items, exists, err := buildOpenAIWSReplayInputSequence(
+ nil,
+ false,
+ currentPayload,
+ false,
+ )
+ require.NoError(t, err)
+ require.True(t, exists)
+ require.Len(t, items, 1)
+ require.Equal(t, tooLargeText, gjson.GetBytes(items[0], "text").String())
+ })
+}
+
+func TestOpenAIWSInputAppearsEditedFromPreviousFullInput(t *testing.T) {
+ t.Parallel()
+
+ makeItems := func(values ...string) []json.RawMessage {
+ items := make([]json.RawMessage, 0, len(values))
+ for _, v := range values {
+ raw, err := json.Marshal(map[string]any{
+ "type": "input_text",
+ "text": v,
+ })
+ require.NoError(t, err)
+ items = append(items, json.RawMessage(raw))
+ }
+ return items
+ }
+
+ previous := makeItems("hello", "world")
+
+ t.Run("skip_when_no_previous_response_id", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`),
+ false,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_previous_full_input_missing", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ nil,
+ false,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("error_when_current_payload_invalid", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[}`),
+ true,
+ )
+ require.Error(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_current_input_missing", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1"}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_previous_len_lt_2", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ makeItems("hello"),
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_current_shorter_than_previous", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("skip_when_current_has_previous_prefix", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"},{"type":"input_text","text":"new"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.False(t, edited)
+ })
+
+ t.Run("detect_when_current_is_full_snapshot_edit", func(t *testing.T) {
+ edited, err := openAIWSInputAppearsEditedFromPreviousFullInput(
+ previous,
+ true,
+ []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"HELLO_EDITED"},{"type":"input_text","text":"world"}]}`),
+ true,
+ )
+ require.NoError(t, err)
+ require.True(t, edited)
+ })
}
func TestSetOpenAIWSPayloadInputSequence(t *testing.T) {
@@ -712,3 +1331,147 @@ func TestCloneOpenAIWSRawMessages(t *testing.T) {
require.Len(t, cloned, 0)
})
}
+
+// ---------------------------------------------------------------------------
+// TestInjectPreviousResponseIDForFunctionCallOutput
+// 端到端测试:当客户端发送 function_call_output 但未携带 previous_response_id 时,
+// Gateway 应主动注入 lastTurnResponseID,避免上游返回 tool_output_not_found 错误。
+// ---------------------------------------------------------------------------
+
+func TestInjectPreviousResponseIDForFunctionCallOutput(t *testing.T) {
+ t.Parallel()
+
+ // 辅助函数:模拟 forwarder 中的注入逻辑
+ // 返回 (注入后的 payload, 注入后的 previousResponseID, 是否执行了注入)
+ simulateInject := func(
+ storeDisabled bool,
+ turn int,
+ payload []byte,
+ expectedPrev string,
+ ) ([]byte, string, bool) {
+ currentPreviousResponseID := ""
+ prev := gjson.GetBytes(payload, "previous_response_id")
+ if prev.Exists() {
+ currentPreviousResponseID = strings.TrimSpace(prev.String())
+ }
+ hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+
+ if shouldInferIngressFunctionCallOutputPreviousResponseID(
+ storeDisabled, turn, hasFunctionCallOutput, currentPreviousResponseID, expectedPrev,
+ ) {
+ injected, err := setPreviousResponseIDToRawPayload(payload, expectedPrev)
+ if err != nil {
+ return payload, currentPreviousResponseID, false
+ }
+ return injected, expectedPrev, true
+ }
+ return payload, currentPreviousResponseID, false
+ }
+
+ t.Run("inject_when_function_call_output_without_prev_id", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"function_call_output","call_id":"call_abc123","output":"result"}]}`)
+ updated, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn")
+
+ require.True(t, injected, "应该执行注入")
+ require.Equal(t, "resp_last_turn", prevID)
+ require.Equal(t, "resp_last_turn", gjson.GetBytes(updated, "previous_response_id").String())
+ // 验证原始 input 保持不变
+ require.Equal(t, "call_abc123", gjson.GetBytes(updated, `input.0.call_id`).String())
+ require.Equal(t, "function_call_output", gjson.GetBytes(updated, `input.0.type`).String())
+ require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String())
+ })
+
+ t.Run("skip_when_prev_id_already_present", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_client","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, prevID, injected := simulateInject(true, 2, payload, "resp_last_turn")
+
+ require.False(t, injected, "客户端已携带 previous_response_id,不应注入")
+ require.Equal(t, "resp_client", prevID)
+ })
+
+ t.Run("skip_when_store_enabled", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(false, 2, payload, "resp_last_turn")
+
+ require.False(t, injected, "store 未禁用时不应注入")
+ })
+
+ t.Run("skip_when_no_function_call_output", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`)
+ _, _, injected := simulateInject(true, 2, payload, "resp_last_turn")
+
+ require.False(t, injected, "没有 function_call_output 时不应注入")
+ })
+
+ t.Run("skip_when_expected_prev_empty", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 2, payload, "")
+
+ require.False(t, injected, "没有 expectedPrev 时不应注入")
+ })
+
+ t.Run("inject_preserves_multiple_function_call_outputs", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"a"},{"type":"function_call_output","call_id":"call_2","output":"b"}]}`)
+ updated, prevID, injected := simulateInject(true, 5, payload, "resp_multi")
+
+ require.True(t, injected)
+ require.Equal(t, "resp_multi", prevID)
+ require.Equal(t, "resp_multi", gjson.GetBytes(updated, "previous_response_id").String())
+ outputs := gjson.GetBytes(updated, `input.#(type=="function_call_output")#.call_id`).Array()
+ require.Len(t, outputs, 2)
+ require.Equal(t, "call_1", outputs[0].String())
+ require.Equal(t, "call_2", outputs[1].String())
+ })
+
+ t.Run("inject_on_first_turn_with_expected_prev", func(t *testing.T) {
+ t.Parallel()
+ // turn=1 但有 expectedPrev(可能来自 session state store 恢复),应注入
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 1, payload, "resp_restored")
+
+ require.True(t, injected, "turn=1 且有 expectedPrev 时应注入")
+ })
+
+ t.Run("inject_updates_payload_bytes_correctly", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ updated, _, injected := simulateInject(true, 3, payload, "resp_check_size")
+
+ require.True(t, injected)
+ // 注入后 payload 长度应增加(包含了新的 previous_response_id 字段)
+ require.Greater(t, len(updated), len(payload))
+ // 验证 JSON 合法性
+ require.True(t, json.Valid(updated), "注入后的 payload 应为合法 JSON")
+ })
+
+ t.Run("skip_when_turn_zero", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 0, payload, "resp_1")
+
+ require.False(t, injected, "turn=0 时不应注入")
+ })
+
+ t.Run("inject_with_whitespace_expected_prev", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ // shouldInfer 内部会 trim,所以带空格的 expectedPrev 仍然有效
+ _, _, injected := simulateInject(true, 2, payload, " resp_trimmed ")
+
+ require.True(t, injected, "trim 后非空的 expectedPrev 应触发注入")
+ })
+
+ t.Run("skip_when_prev_id_is_whitespace_only", func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ _, _, injected := simulateInject(true, 2, payload, " ")
+
+ require.False(t, injected, "纯空白的 expectedPrev 不应触发注入")
+ })
+}
diff --git a/backend/internal/service/openai_ws_forwarder_panic_test.go b/backend/internal/service/openai_ws_forwarder_panic_test.go
new file mode 100644
index 000000000..0ceacb187
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_panic_test.go
@@ -0,0 +1,107 @@
+package service
+
+import (
+ "context"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ coderws "github.com/coder/websocket"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+type openAIWSPanicResolver struct{}
+
+func (openAIWSPanicResolver) Resolve(account *Account) OpenAIWSProtocolDecision {
+ panic("resolver panic")
+}
+
+type openAIWSPanicStateStore struct {
+ OpenAIWSStateStore
+}
+
+func (openAIWSPanicStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) {
+ panic("state_store panic")
+}
+
+func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PanicRecovered(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := buildIngressPolicyTestConfig()
+ cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
+ svc := buildIngressPolicyTestService(cfg)
+ svc.openaiWSResolver = openAIWSPanicResolver{}
+ account := buildIngressPolicyTestAccount(map[string]any{
+ "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ })
+
+ serverErr := runIngressProxyWithFirstPayload(t, svc, account, `{"type":"response.create","model":"gpt-5.1","stream":false}`)
+ var closeErr *OpenAIWSClientCloseError
+ require.ErrorAs(t, serverErr, &closeErr)
+ require.Equal(t, coderws.StatusInternalError, closeErr.StatusCode())
+ require.Equal(t, "internal websocket proxy panic", closeErr.Reason())
+}
+
+func TestOpenAIGatewayService_ForwardOpenAIWSV2_PanicRecovered(t *testing.T) {
+ gin.SetMode(gin.TestMode)
+
+ cfg := &config.Config{}
+ cfg.Security.URLAllowlist.Enabled = false
+ cfg.Security.URLAllowlist.AllowInsecureHTTP = true
+ cfg.Gateway.OpenAIWS.Enabled = true
+ cfg.Gateway.OpenAIWS.APIKeyEnabled = true
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true
+ cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
+
+ svc := &OpenAIGatewayService{
+ cfg: cfg,
+ openaiWSStateStore: openAIWSPanicStateStore{},
+ openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
+ toolCorrector: NewCodexToolCorrector(),
+ cache: &stubGatewayCache{},
+ }
+
+ account := &Account{
+ ID: 445,
+ Name: "openai-forwarder-panic",
+ Platform: PlatformOpenAI,
+ Type: AccountTypeAPIKey,
+ Status: StatusActive,
+ Schedulable: true,
+ Concurrency: 1,
+ Credentials: map[string]any{
+ "api_key": "sk-test",
+ },
+ }
+
+ rec := httptest.NewRecorder()
+ ginCtx, _ := gin.CreateTestContext(rec)
+ req := httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
+ req.Header.Set("session_id", "sess-panic-check")
+ req.Header.Set("User-Agent", "unit-test-agent/1.0")
+ ginCtx.Request = req
+
+ _, err := svc.forwardOpenAIWSV2(
+ context.Background(),
+ ginCtx,
+ account,
+ map[string]any{
+ "model": "gpt-5.1",
+ "input": []any{},
+ },
+ "sk-test",
+ OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
+ true,
+ true,
+ "gpt-5.1",
+ "gpt-5.1",
+ time.Now(),
+ 1,
+ "",
+ )
+ require.Error(t, err)
+ require.ErrorContains(t, err, "panic recovered")
+}
diff --git a/backend/internal/service/openai_ws_forwarder_recovery_test.go b/backend/internal/service/openai_ws_forwarder_recovery_test.go
new file mode 100644
index 000000000..3cbf4842f
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_recovery_test.go
@@ -0,0 +1,691 @@
+package service
+
+import (
+ "encoding/json"
+ "errors"
+ "strconv"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// ---------------------------------------------------------------------------
+// openAIWSIngressTurnWroteDownstream 辅助函数测试
+// ---------------------------------------------------------------------------
+
+func TestOpenAIWSIngressTurnWroteDownstream(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ err error
+ want bool
+ }{
+ {
+ name: "nil_error_returns_false",
+ err: nil,
+ want: false,
+ },
+ {
+ name: "plain_error_returns_false",
+ err: errors.New("some random error"),
+ want: false,
+ },
+ {
+ name: "turn_error_wrote_downstream_false",
+ err: wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ false,
+ ),
+ want: false,
+ },
+ {
+ name: "turn_error_wrote_downstream_true",
+ err: wrapOpenAIWSIngressTurnError(
+ "upstream_error_event",
+ errors.New("upstream error"),
+ true,
+ ),
+ want: true,
+ },
+ {
+ name: "turn_error_with_partial_result_wrote_downstream_true",
+ err: wrapOpenAIWSIngressTurnErrorWithPartial(
+ "read_upstream",
+ errors.New("connection reset"),
+ true,
+ &OpenAIForwardResult{RequestID: "resp_partial"},
+ ),
+ want: true,
+ },
+ {
+ name: "turn_error_with_partial_result_wrote_downstream_false",
+ err: wrapOpenAIWSIngressTurnErrorWithPartial(
+ "read_upstream",
+ errors.New("connection reset"),
+ false,
+ &OpenAIForwardResult{RequestID: "resp_partial"},
+ ),
+ want: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ require.Equal(t, tt.want, openAIWSIngressTurnWroteDownstream(tt.err))
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// previous_response_not_found 错误与 ContinueTurn 处置测试
+// ---------------------------------------------------------------------------
+
+func TestPreviousResponseNotFound_ClassifiesAsContinueTurn(t *testing.T) {
+ t.Parallel()
+
+ // previous_response_not_found(wroteDownstream=false)应被归类为 ContinueTurn
+ err := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ false,
+ )
+
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonPreviousResponse, reason)
+ require.True(t, expected)
+
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+}
+
+func TestToolOutputNotFound_ClassifiesAsContinueTurn(t *testing.T) {
+ t.Parallel()
+
+ // tool_output_not_found(wroteDownstream=false)应被归类为 ContinueTurn
+ err := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("no tool call found for function call output"),
+ false,
+ )
+
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason)
+ require.True(t, expected)
+
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+}
+
+// ---------------------------------------------------------------------------
+// function_call_output 与 previous_response_id 语义绑定测试
+// 验证核心修复:带 function_call_output 时不能 drop previous_response_id
+// ---------------------------------------------------------------------------
+
+func TestFunctionCallOutputPayload_HasFunctionCallOutput(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ payload string
+ want bool
+ }{
+ {
+ name: "payload_with_function_call_output",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`,
+ want: true,
+ },
+ {
+ name: "payload_with_mixed_input_including_function_call_output",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"function_call_output","call_id":"call_abc","output":"ok"}]}`,
+ want: true,
+ },
+ {
+ name: "payload_without_function_call_output",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"}]}`,
+ want: false,
+ },
+ {
+ name: "payload_with_empty_input",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[]}`,
+ want: false,
+ },
+ {
+ name: "payload_without_input",
+ payload: `{"type":"response.create","model":"gpt-5.1"}`,
+ want: false,
+ },
+ {
+ name: "multiple_function_call_outputs",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`,
+ want: true,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ got := gjson.GetBytes([]byte(tt.payload), `input.#(type=="function_call_output")`).Exists()
+ require.Equal(t, tt.want, got)
+ })
+ }
+}
+
+func TestDropPreviousResponseID_BreaksFunctionCallOutput(t *testing.T) {
+ t.Parallel()
+
+ // 核心回归测试:验证 drop previous_response_id 后 function_call_output 会变成孤立引用
+ //
+ // 场景:客户端发送 {previous_response_id: "resp_1", input: [{type: "function_call_output", call_id: "call_abc"}]}
+ // 如果 drop 了 previous_response_id,上游会创建全新上下文,找不到 call_abc 对应的 tool call
+ // 结果:上游报 "No tool call found for function call output with call_id call_abc"
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale_or_lost",
+ "input":[
+ {"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"},
+ {"type":"function_call_output","call_id":"call_def","output":"{\"result\":\"done\"}"}
+ ]
+ }`)
+
+ // 1. 验证原始 payload 有 previous_response_id 和 function_call_output
+ require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists())
+ require.True(t, gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists())
+
+ // 2. drop previous_response_id
+ dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, err)
+ require.True(t, removed)
+
+ // 3. 验证 drop 后的状态:previous_response_id 被移除但 function_call_output 仍然存在
+ require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists(),
+ "previous_response_id 应该被移除")
+ require.True(t, gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists(),
+ "function_call_output 仍然存在,但此时它引用的 call_id 没有了上下文 — 这就是 bug 的根因")
+
+ // 4. 验证 call_id 仍然在 payload 中(说明 drop 不会清理 function_call_output)
+ callIDs := make([]string, 0)
+ gjson.GetBytes(dropped, "input").ForEach(func(_, item gjson.Result) bool {
+ if item.Get("type").String() == "function_call_output" {
+ callIDs = append(callIDs, item.Get("call_id").String())
+ }
+ return true
+ })
+ require.ElementsMatch(t, []string{"call_abc", "call_def"}, callIDs,
+ "function_call_output 的 call_id 未被清除,但上游已无法匹配")
+}
+
+func TestRecoveryStrategy_FunctionCallOutput_ShouldNotDrop(t *testing.T) {
+ t.Parallel()
+
+ // 此测试验证修复的核心逻辑:
+ // 当 hasFunctionCallOutput=true 且 set/align 策略均失败时,
+ // 正确行为是放弃恢复(return false),而非 drop previous_response_id
+ //
+ // 因为:function_call_output 语义绑定 previous_response_id
+ // drop previous_response_id 但保留 function_call_output → 上游报错
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_lost",
+ "input":[{"type":"function_call_output","call_id":"call_JDKR","output":"ok"}]
+ }`)
+
+ hasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+ require.True(t, hasFunctionCallOutput, "payload 必须包含 function_call_output")
+
+ // 模拟 set 策略失败(currentPreviousResponseID 不为空,不满足 set 条件)
+ currentPreviousResponseID := "resp_lost"
+ expectedPrev := "resp_expected"
+ require.NotEmpty(t, currentPreviousResponseID, "set 策略需要 currentPreviousResponseID 为空")
+
+ // 模拟 align 策略失败
+ _, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev)
+ if alignErr == nil && aligned {
+ // align 成功了,更新 payload 中的 previous_response_id
+ t.Log("align 策略成功,此场景不触发 abort 路径")
+ }
+ // 注意:align 通常会成功(替换 resp_lost → resp_expected)。
+ // 但在真实场景中,如果 align 后的 previous_response_id 仍然在上游不存在,
+ // 上游会再次返回 previous_response_not_found,此时二次进入恢复函数,
+ // 但 turnPrevRecoveryTried=true 会阻止二次恢复,直接走 abort。
+
+ // 验证关键断言:即使 drop 技术上可行,也不应该执行
+ // 因为这会导致 "No tool call found for function call output" 错误
+ droppedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ require.True(t, removed, "drop 操作本身可以成功")
+
+ // 但 drop 后的 payload 仍有 function_call_output —— 这就是为什么不能 drop
+ hasFCOAfterDrop := gjson.GetBytes(droppedPayload, `input.#(type=="function_call_output")`).Exists()
+ require.True(t, hasFCOAfterDrop,
+ "drop previous_response_id 不会移除 function_call_output,"+
+ "导致上游报 'No tool call found for function call output'")
+}
+
+// ---------------------------------------------------------------------------
+// ContinueTurn abort 路径错误通知测试
+// ---------------------------------------------------------------------------
+
+func TestContinueTurnAbort_ErrorEventFormat(t *testing.T) {
+ t.Parallel()
+
+ // 验证 ContinueTurn abort 时生成的 error 事件格式正确
+ abortReason := openAIWSIngressTurnAbortReasonPreviousResponse
+ abortMessage := "turn failed: " + string(abortReason)
+
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+ string(abortReason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+
+ // 验证 JSON 格式有效
+ var parsed map[string]any
+ err := json.Unmarshal(errorEvent, &parsed)
+ require.NoError(t, err, "error 事件应为有效 JSON")
+
+ // 验证事件结构
+ require.Equal(t, "error", parsed["type"])
+ errorObj, ok := parsed["error"].(map[string]any)
+ require.True(t, ok)
+ require.Equal(t, "server_error", errorObj["type"])
+ require.Equal(t, string(abortReason), errorObj["code"])
+ require.Contains(t, errorObj["message"], string(abortReason))
+}
+
+func TestContinueTurnAbort_ErrorEventWithSpecialChars(t *testing.T) {
+ t.Parallel()
+
+ // 验证包含特殊字符的错误消息不会破坏 JSON 格式
+ specialMessages := []string{
+ `No tool call found for function call output with call_id call_JDKR0SzNTARIsGb0L3hofFWd.`,
+ `error with "quotes" and \backslash`,
+ "error with\nnewline",
+ `error with & entities`,
+ "", // 空消息
+ }
+
+ for i, msg := range specialMessages {
+ msg := msg
+ t.Run("special_message_"+strconv.Itoa(i), func(t *testing.T) {
+ t.Parallel()
+ abortReason := openAIWSIngressTurnAbortReasonToolOutput
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+ string(abortReason) + `","message":` + strconv.Quote(msg) + `}}`)
+
+ var parsed map[string]any
+ err := json.Unmarshal(errorEvent, &parsed)
+ require.NoError(t, err, "error event with special chars should be valid JSON: %q", msg)
+
+ errorObj := parsed["error"].(map[string]any)
+ require.Equal(t, msg, errorObj["message"])
+ })
+ }
+}
+
+func TestContinueTurnAbort_WroteDownstreamDeterminesNotification(t *testing.T) {
+ t.Parallel()
+
+ // 验证 wroteDownstream 标志如何影响错误通知策略
+ tests := []struct {
+ name string
+ wroteDownstream bool
+ shouldSendErrorToClient bool
+ }{
+ {
+ name: "not_wrote_downstream_should_send_error",
+ wroteDownstream: false,
+ shouldSendErrorToClient: true,
+ },
+ {
+ name: "wrote_downstream_should_not_send_error",
+ wroteDownstream: true,
+ shouldSendErrorToClient: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ err := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ tt.wroteDownstream,
+ )
+ wroteDownstream := openAIWSIngressTurnWroteDownstream(err)
+ require.Equal(t, tt.wroteDownstream, wroteDownstream)
+
+ // 只有当 wroteDownstream=false 时才需要补发 error 事件
+ shouldNotify := !wroteDownstream
+ require.Equal(t, tt.shouldSendErrorToClient, shouldNotify)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// previous_response_id 恢复策略:set / align / abort 完整流程测试
+// ---------------------------------------------------------------------------
+
+func TestRecoveryStrategy_SetPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ // 场景:客户端未发送 previous_response_id,但 session 中有记录
+ // 此时应该通过 set 策略注入 previous_response_id
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ expectedPrev := "resp_expected"
+
+ // set 策略:当 currentPreviousResponseID 为空时,注入 expectedPrev
+ updated, err := setPreviousResponseIDToRawPayload(payload, expectedPrev)
+ require.NoError(t, err)
+ require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String())
+
+ // function_call_output 保持不变
+ require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists())
+ require.Equal(t, "call_1", gjson.GetBytes(updated, `input.#(type=="function_call_output").call_id`).String())
+}
+
+func TestRecoveryStrategy_AlignPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ // 场景:客户端发送了过时的 previous_response_id,需要 align 到最新
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ expectedPrev := "resp_latest"
+
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, expectedPrev)
+ require.NoError(t, err)
+ require.True(t, changed)
+ require.Equal(t, expectedPrev, gjson.GetBytes(updated, "previous_response_id").String())
+
+ // function_call_output 保持不变
+ require.True(t, gjson.GetBytes(updated, `input.#(type=="function_call_output")`).Exists())
+}
+
+func TestRecoveryStrategy_AlignFailsWhenNoExpectedPrev(t *testing.T) {
+ t.Parallel()
+
+ // 场景:没有预期的 previous_response_id,align 无法执行
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "")
+ require.NoError(t, err)
+ require.False(t, changed, "align 应该在 expectedPrev 为空时不执行")
+ require.Equal(t, string(payload), string(updated))
+}
+
+// ---------------------------------------------------------------------------
+// isOpenAIWSIngressPreviousResponseNotFound 边界条件测试
+// ---------------------------------------------------------------------------
+
+func TestIsOpenAIWSIngressPreviousResponseNotFound_WroteDownstreamBlocks(t *testing.T) {
+ t.Parallel()
+
+ // wroteDownstream=true 时,即使 stage 是 previous_response_not_found,
+ // 也不应被识别为可恢复的 previous_response_not_found
+ // (因为已经向客户端写入了数据,无法安全重试)
+ err := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ true, // wroteDownstream = true
+ )
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err),
+ "wroteDownstream=true 时不应识别为可恢复的 previous_response_not_found")
+}
+
+func TestIsOpenAIWSIngressPreviousResponseNotFound_DifferentStageReturns_False(t *testing.T) {
+ t.Parallel()
+
+ stages := []string{
+ "read_upstream",
+ "write_upstream",
+ "upstream_error_event",
+ openAIWSIngressStageToolOutputNotFound,
+ "unknown",
+ "",
+ }
+
+ for _, stage := range stages {
+ stage := stage
+ t.Run("stage_"+stage, func(t *testing.T) {
+ t.Parallel()
+ err := wrapOpenAIWSIngressTurnError(stage, errors.New("some error"), false)
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err),
+ "stage=%q 不应被识别为 previous_response_not_found", stage)
+ })
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 端到端场景测试:function_call_output 恢复链路
+// ---------------------------------------------------------------------------
+
+func TestEndToEnd_FunctionCallOutputRecoveryChain(t *testing.T) {
+ t.Parallel()
+
+ // 完整场景:
+ // 1. 客户端发送带 function_call_output 的请求
+ // 2. 上游返回 previous_response_not_found
+ // 3. 恢复策略尝试 set/align
+ // 4. 如果都失败,应该 abort(而非 drop previous_response_id)
+ // 5. 客户端收到 error 事件
+ // 6. 客户端重置并发送完整请求
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_lost",
+ "input":[
+ {"type":"function_call_output","call_id":"call_JDKR0SzNTARIsGb0L3hofFWd","output":"{\"ok\":true}"}
+ ]
+ }`)
+
+ // Step 1: 检测 function_call_output
+ hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+ require.True(t, hasFCO, "payload 包含 function_call_output")
+
+ // Step 2: 模拟 previous_response_not_found 错误
+ turnErr := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ false,
+ )
+ require.True(t, isOpenAIWSIngressPreviousResponseNotFound(turnErr))
+
+ // Step 3: 验证 ContinueTurn 处置
+ reason, _ := classifyOpenAIWSIngressTurnAbortReason(turnErr)
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition)
+
+ // Step 4: set 策略 — 失败(currentPreviousResponseID 不为空)
+ currentPrevID := gjson.GetBytes(payload, "previous_response_id").String()
+ require.NotEmpty(t, currentPrevID, "set 策略前提条件不满足(需要 currentPreviousResponseID 为空)")
+
+ // Step 5: align 策略 — 假设 expectedPrev 为空(session 中无记录)
+ expectedPrev := ""
+ _, aligned, alignErr := alignStoreDisabledPreviousResponseID(payload, expectedPrev)
+ require.NoError(t, alignErr)
+ require.False(t, aligned, "expectedPrev 为空时 align 应失败")
+
+ // Step 6: 此时应该 abort(return false)而非 drop
+ // 验证:如果错误地执行 drop,会导致 function_call_output 成为孤立引用
+ dropped, removed, _ := dropPreviousResponseIDFromRawPayload(payload)
+ if removed {
+ hasFCOAfterDrop := gjson.GetBytes(dropped, `input.#(type=="function_call_output")`).Exists()
+ require.True(t, hasFCOAfterDrop,
+ "drop 后 function_call_output 仍存在,上游会报 'No tool call found'")
+ }
+
+ // Step 7: 正确行为——abort 后生成 error 事件通知客户端
+ wroteDownstream := openAIWSIngressTurnWroteDownstream(turnErr)
+ require.False(t, wroteDownstream, "abort 前未向客户端写入数据")
+
+ abortMessage := "turn failed: " + string(reason)
+ errorEvent := []byte(`{"type":"error","error":{"type":"server_error","code":"` +
+ string(reason) + `","message":` + strconv.Quote(abortMessage) + `}}`)
+
+ var parsed map[string]any
+ require.NoError(t, json.Unmarshal(errorEvent, &parsed))
+ require.Equal(t, "error", parsed["type"])
+}
+
+func TestEndToEnd_NonFunctionCallOutput_CanDrop(t *testing.T) {
+ t.Parallel()
+
+ // 对照场景:没有 function_call_output 的 payload 可以安全 drop previous_response_id
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_old",
+ "input":[{"type":"input_text","text":"hello"}]
+ }`)
+
+ hasFCO := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists()
+ require.False(t, hasFCO, "此 payload 不包含 function_call_output")
+
+ // drop 是安全的
+ dropped, removed, err := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, err)
+ require.True(t, removed)
+ require.False(t, gjson.GetBytes(dropped, "previous_response_id").Exists())
+
+ // input 仍然有效(input_text 不依赖 previous_response_id)
+ require.Equal(t, "hello", gjson.GetBytes(dropped, "input.0.text").String())
+}
+
+// ---------------------------------------------------------------------------
+// shouldKeepIngressPreviousResponseID 与 function_call_output 的交互测试
+// ---------------------------------------------------------------------------
+
+func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMatch(t *testing.T) {
+ t.Parallel()
+
+ // 当 function_call_output 的 call_id 与 pending call_id 匹配时,应保留 previous_response_id
+ previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+ currentPayload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_1",
+ "input":[{"type":"function_call_output","call_id":"call_match","output":"ok"}]
+ }`)
+
+ keep, reason, err := shouldKeepIngressPreviousResponseID(
+ previousPayload,
+ currentPayload,
+ "resp_1",
+ true, // hasFunctionCallOutput
+ []string{"call_match"}, // pendingCallIDs
+ []string{"call_match"}, // requestCallIDs
+ )
+ require.NoError(t, err)
+ require.True(t, keep, "call_id 匹配时应保留 previous_response_id")
+ require.Equal(t, "function_call_output_call_id_match", reason)
+}
+
+func TestShouldKeepIngressPreviousResponseID_FunctionCallOutputCallIDMismatch(t *testing.T) {
+ t.Parallel()
+
+ // 当 function_call_output 的 call_id 与 pending call_id 不匹配时,
+ // 应放弃 previous_response_id
+ previousPayload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+ currentPayload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_1",
+ "input":[{"type":"function_call_output","call_id":"call_wrong","output":"ok"}]
+ }`)
+
+ keep, reason, err := shouldKeepIngressPreviousResponseID(
+ previousPayload,
+ currentPayload,
+ "resp_1",
+ true, // hasFunctionCallOutput
+ []string{"call_real"}, // pendingCallIDs
+ []string{"call_wrong"}, // requestCallIDs
+ )
+ require.NoError(t, err)
+ require.False(t, keep, "call_id 不匹配时应放弃 previous_response_id")
+ require.Equal(t, "function_call_output_call_id_mismatch", reason)
+}
+
+// ---------------------------------------------------------------------------
+// isOpenAIWSIngressTurnRetryable 与 function_call_output 场景的交互
+// ---------------------------------------------------------------------------
+
+func TestIsOpenAIWSIngressTurnRetryable_PreviousResponseNotFound(t *testing.T) {
+ t.Parallel()
+
+ // previous_response_not_found 不应被标记为 retryable(因为有专门的恢复路径)
+ err := wrapOpenAIWSIngressTurnError(
+ openAIWSIngressStagePreviousResponseNotFound,
+ errors.New("previous response not found"),
+ false,
+ )
+ require.False(t, isOpenAIWSIngressTurnRetryable(err),
+ "previous_response_not_found 有专门的恢复逻辑,不走通用重试")
+}
+
+func TestIsOpenAIWSIngressTurnRetryable_WroteDownstreamBlocksRetry(t *testing.T) {
+ t.Parallel()
+
+ // wroteDownstream=true 时,任何 stage 都不应 retryable
+ err := wrapOpenAIWSIngressTurnError(
+ "write_upstream",
+ errors.New("write failed"),
+ true, // wroteDownstream
+ )
+ require.False(t, isOpenAIWSIngressTurnRetryable(err),
+ "wroteDownstream=true 时不应重试")
+}
+
+// ---------------------------------------------------------------------------
+// normalizeOpenAIWSIngressPayloadBeforeSend 与恢复的集成测试
+// ---------------------------------------------------------------------------
+
+func TestNormalizePayload_FunctionCallOutputPassthrough(t *testing.T) {
+ t.Parallel()
+
+ // 透传模式:normalizer 不再注入 previous_response_id,原样传递
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 1,
+ turn: 2,
+ connID: "conn_test",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "resp_expected",
+ pendingExpectedCallIDs: []string{"call_1"},
+ })
+
+ // 透传模式:previous_response_id 保持客户端原值(空),由下游 recovery 处理
+ require.Empty(t, out.currentPreviousResponseID,
+ "透传模式不应注入 previous_response_id")
+ require.True(t, out.hasFunctionCallOutputCallID)
+ require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs)
+}
diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go
deleted file mode 100644
index 592801f66..000000000
--- a/backend/internal/service/openai_ws_forwarder_success_test.go
+++ /dev/null
@@ -1,1306 +0,0 @@
-package service
-
-import (
- "context"
- "encoding/json"
- "errors"
- "io"
- "net/http"
- "net/http/httptest"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/gin-gonic/gin"
- "github.com/gorilla/websocket"
- "github.com/stretchr/testify/require"
- "github.com/tidwall/gjson"
-)
-
-func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- type receivedPayload struct {
- Type string
- PreviousResponseID string
- StreamExists bool
- Stream bool
- }
- receivedCh := make(chan receivedPayload, 1)
-
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- t.Errorf("read ws request failed: %v", err)
- return
- }
- requestJSON := requestToJSONString(request)
- receivedCh <- receivedPayload{
- Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()),
- PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()),
- StreamExists: gjson.Get(requestJSON, "stream").Exists(),
- Stream: gjson.Get(requestJSON, "stream").Bool(),
- }
-
- if err := conn.WriteJSON(map[string]any{
- "type": "response.created",
- "response": map[string]any{
- "id": "resp_new_1",
- "model": "gpt-5.1",
- },
- }); err != nil {
- t.Errorf("write response.created failed: %v", err)
- return
- }
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": "resp_new_1",
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 12,
- "output_tokens": 7,
- "input_tokens_details": map[string]any{
- "cached_tokens": 3,
- },
- },
- },
- }); err != nil {
- t.Errorf("write response.completed failed: %v", err)
- return
- }
- }))
- defer wsServer.Close()
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "unit-test-agent/1.0")
- groupID := int64(1001)
- c.Set("api_key", &APIKey{GroupID: &groupID})
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
- cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600
-
- upstream := &httpUpstreamRecorder{
- resp: &http.Response{
- StatusCode: http.StatusOK,
- Header: http.Header{"Content-Type": []string{"application/json"}},
- Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
- },
- }
-
- cache := &stubGatewayCache{}
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: upstream,
- cache: cache,
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 9,
- Name: "openai-ws",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, 12, result.Usage.InputTokens)
- require.Equal(t, 7, result.Usage.OutputTokens)
- require.Equal(t, 3, result.Usage.CacheReadInputTokens)
- require.Equal(t, "resp_new_1", result.RequestID)
- require.True(t, result.OpenAIWSMode)
- require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游")
-
- received := <-receivedCh
- require.Equal(t, "response.create", received.Type)
- require.Equal(t, "resp_prev_1", received.PreviousResponseID)
- require.True(t, received.StreamExists, "WS 请求应携带 stream 字段")
- require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义")
-
- store := svc.getOpenAIWSStateStore()
- mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1")
- require.NoError(t, getErr)
- require.Equal(t, account.ID, mappedAccountID)
- connID, ok := store.GetResponseConn("resp_new_1")
- require.True(t, ok)
- require.NotEmpty(t, connID)
-
- responseBody := rec.Body.Bytes()
- require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String())
-}
-
-func requestToJSONString(payload map[string]any) string {
- if len(payload) == 0 {
- return "{}"
- }
- b, err := json.Marshal(payload)
- if err != nil {
- return "{}"
- }
- return string(b)
-}
-
-func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) {
- require.NotPanics(t, func() {
- logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil)
- })
- require.NotPanics(t, func() {
- logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed"))
- })
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
- groupID := int64(3001)
- c.Set("api_key", &APIKey{GroupID: &groupID})
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 1301,
- Name: "openai-rewrite",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "model_mapping": map[string]any{
- "custom-original-model": "gpt-5.1",
- },
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_model_tool_1", result.RequestID)
- require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型")
- require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范")
-}
-
-func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) {
- payload := map[string]any{
- "type": nil,
- "model": 123,
- "prompt_cache_key": " cache-key ",
- "previous_response_id": []byte(" resp_1 "),
- }
-
- require.Equal(t, "", openAIWSPayloadString(payload, "type"))
- require.Equal(t, "", openAIWSPayloadString(payload, "model"))
- require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key"))
- require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id"))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var upgradeCount atomic.Int64
- var sequence atomic.Int64
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- upgradeCount.Add(1)
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- for {
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- return
- }
- idx := sequence.Add(1)
- responseID := "resp_reuse_" + strconv.FormatInt(idx, 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.created",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- },
- }); err != nil {
- return
- }
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 2,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- return
- }
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
- account := &Account{
- ID: 19,
- Name: "openai-ws",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- for i := 0; i < 2; i++ {
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
- groupID := int64(2001)
- c.Set("api_key", &APIKey{GroupID: &groupID})
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_"))
- }
-
- require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链")
- metrics := svc.SnapshotOpenAIWSPoolMetrics()
- require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1))
- require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
- c.Request.Header.Set("session_id", "sess-oauth-1")
- c.Request.Header.Set("conversation_id", "conv-oauth-1")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.AllowStoreRecovery = false
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
- account := &Account{
- ID: 29,
- Name: "openai-oauth",
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "oauth-token-1",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_oauth_1", result.RequestID)
-
- require.NotNil(t, captureConn.lastWrite)
- requestJSON := requestToJSONString(captureConn.lastWrite)
- require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段")
- require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false")
- require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段")
- require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true")
- require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta"))
- require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id"))
- require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id"))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
- account := &Account{
- ID: 31,
- Name: "openai-oauth",
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "access_token": "oauth-token-1",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_prompt_cache_key", result.RequestID)
-
- require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id"))
- require.Empty(t, captureDialer.lastHeaders.Get("conversation_id"))
- require.NotNil(t, captureConn.lastWrite)
- require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists())
-}
-
-func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsockets = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false
-
- upstream := &httpUpstreamRecorder{
- resp: &http.Response{
- StatusCode: http.StatusOK,
- Header: http.Header{"Content-Type": []string{"application/json"}},
- Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)),
- },
- }
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: upstream,
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 39,
- Name: "openai-ws-v1",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": "https://api.openai.com/v1/responses",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.Error(t, err)
- require.Nil(t, result)
- require.Contains(t, err.Error(), "ws v1")
- require.Equal(t, http.StatusBadRequest, rec.Code)
- require.Contains(t, rec.Body.String(), "WSv1")
- require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var connIndex atomic.Int64
- headersCh := make(chan http.Header, 4)
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- idx := connIndex.Add(1)
- headersCh <- cloneHeader(r.Header)
-
- respHeader := http.Header{}
- if idx == 1 {
- respHeader.Set("x-codex-turn-state", "turn_state_first")
- }
- conn, err := upgrader.Upgrade(w, r, respHeader)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- t.Errorf("read ws request failed: %v", err)
- return
- }
- responseID := "resp_turn_" + strconv.FormatInt(idx, 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 2,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- t.Errorf("write response.completed failed: %v", err)
- return
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 49,
- Name: "openai-turn-state",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session_turn_state")
- c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1")
- result1, err := svc.Forward(context.Background(), c1, account, reqBody)
- require.NoError(t, err)
- require.NotNil(t, result1)
-
- sessionHash := svc.GenerateSessionHash(c1, reqBody)
- store := svc.getOpenAIWSStateStore()
- turnState, ok := store.GetSessionTurnState(0, sessionHash)
- require.True(t, ok)
- require.Equal(t, "turn_state_first", turnState)
-
- // 主动淘汰连接,模拟下一次请求发生重连。
- connID, hasConn := store.GetResponseConn(result1.RequestID)
- require.True(t, hasConn)
- svc.getOpenAIWSConnPool().evictConn(account.ID, connID)
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session_turn_state")
- c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2")
- result2, err := svc.Forward(context.Background(), c2, account, reqBody)
- require.NoError(t, err)
- require.NotNil(t, result2)
-
- firstHandshakeHeaders := <-headersCh
- secondHandshakeHeaders := <-headersCh
- require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
- require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata"))
- require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State"))
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("session_id", "session-prewarm")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 59,
- Name: "openai-prewarm",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_main_1", result.RequestID)
-
- require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求")
- firstWrite := requestToJSONString(captureConn.writes[0])
- secondWrite := requestToJSONString(captureConn.writes[1])
- require.True(t, gjson.Get(firstWrite, "generate").Exists())
- require.False(t, gjson.Get(firstWrite, "generate").Bool())
- require.False(t, gjson.Get(secondWrite, "generate").Exists())
-}
-
-func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- toolCorrector: NewCodexToolCorrector(),
- }
- account := &Account{
- ID: 601,
- Name: "openai-prewarm-timeout",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- }
- conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{
- readDelay: 200 * time.Millisecond,
- }, nil)
- lease := &openAIWSConnLease{
- accountID: account.ID,
- conn: conn,
- }
- payload := map[string]any{
- "type": "response.create",
- "model": "gpt-5.1",
- }
-
- ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond)
- defer cancel()
- start := time.Now()
- err := svc.performOpenAIWSGeneratePrewarm(
- ctx,
- lease,
- OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2},
- payload,
- "",
- map[string]any{"model": "gpt-5.1"},
- account,
- nil,
- 0,
- )
- elapsed := time.Since(start)
- require.Error(t, err)
- require.Contains(t, err.Error(), "prewarm_read_event")
- require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
-
- captureConn := &openAIWSCaptureConn{
- events: [][]byte{
- []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 69,
- Name: "openai-turn-metadata",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
-
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session-metadata-reuse")
- c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1")
- result1, err := svc.Forward(context.Background(), c1, account, body)
- require.NoError(t, err)
- require.NotNil(t, result1)
- require.Equal(t, "resp_meta_1", result1.RequestID)
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session-metadata-reuse")
- c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2")
- result2, err := svc.Forward(context.Background(), c2, account, body)
- require.NoError(t, err)
- require.NotNil(t, result2)
- require.Equal(t, "resp_meta_2", result2.RequestID)
-
- require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接")
- require.Len(t, captureConn.writes, 2)
-
- firstWrite := requestToJSONString(captureConn.writes[0])
- secondWrite := requestToJSONString(captureConn.writes[1])
- require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String())
- require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String())
-}
-
-func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var upgradeCount atomic.Int64
- var sequence atomic.Int64
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- upgradeCount.Add(1)
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- for {
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- return
- }
- responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 1,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- return
- }
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
- cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 79,
- Name: "openai-store-false",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
-
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session_store_false_a")
- result1, err := svc.Forward(context.Background(), c1, account, body)
- require.NoError(t, err)
- require.NotNil(t, result1)
- require.Equal(t, int64(1), upgradeCount.Load())
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session_store_false_a")
- result2, err := svc.Forward(context.Background(), c2, account, body)
- require.NoError(t, err)
- require.NotNil(t, result2)
- require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接")
-
- rec3 := httptest.NewRecorder()
- c3, _ := gin.CreateTestContext(rec3)
- c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c3.Request.Header.Set("session_id", "session_store_false_b")
- result3, err := svc.Forward(context.Background(), c3, account, body)
- require.NoError(t, err)
- require.NotNil(t, result3)
- require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- var upgradeCount atomic.Int64
- var sequence atomic.Int64
- upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}
- wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
- upgradeCount.Add(1)
- conn, err := upgrader.Upgrade(w, r, nil)
- if err != nil {
- t.Errorf("upgrade websocket failed: %v", err)
- return
- }
- defer func() {
- _ = conn.Close()
- }()
-
- for {
- var request map[string]any
- if err := conn.ReadJSON(&request); err != nil {
- return
- }
- responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10)
- if err := conn.WriteJSON(map[string]any{
- "type": "response.completed",
- "response": map[string]any{
- "id": responseID,
- "model": "gpt-5.1",
- "usage": map[string]any{
- "input_tokens": 1,
- "output_tokens": 1,
- },
- },
- }); err != nil {
- return
- }
- }
- }))
- defer wsServer.Close()
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: &httpUpstreamRecorder{},
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- }
-
- account := &Account{
- ID: 80,
- Name: "openai-store-false-reuse",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 2,
- Credentials: map[string]any{
- "api_key": "sk-test",
- "base_url": wsServer.URL,
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`)
-
- rec1 := httptest.NewRecorder()
- c1, _ := gin.CreateTestContext(rec1)
- c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c1.Request.Header.Set("session_id", "session_store_false_reuse_a")
- result1, err := svc.Forward(context.Background(), c1, account, body)
- require.NoError(t, err)
- require.NotNil(t, result1)
- require.Equal(t, int64(1), upgradeCount.Load())
-
- rec2 := httptest.NewRecorder()
- c2, _ := gin.CreateTestContext(rec2)
- c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c2.Request.Header.Set("session_id", "session_store_false_reuse_b")
- result2, err := svc.Forward(context.Background(), c2, account, body)
- require.NoError(t, err)
- require.NotNil(t, result2)
- require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接")
-}
-
-func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) {
- gin.SetMode(gin.TestMode)
-
- rec := httptest.NewRecorder()
- c, _ := gin.CreateTestContext(rec)
- c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
- c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0")
-
- cfg := &config.Config{}
- cfg.Security.URLAllowlist.Enabled = false
- cfg.Security.URLAllowlist.AllowInsecureHTTP = true
- cfg.Gateway.OpenAIWS.Enabled = true
- cfg.Gateway.OpenAIWS.OAuthEnabled = true
- cfg.Gateway.OpenAIWS.APIKeyEnabled = true
- cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3
- cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1
- cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3
-
- captureConn := &openAIWSCaptureConn{
- readDelays: []time.Duration{
- 700 * time.Millisecond,
- 700 * time.Millisecond,
- },
- events: [][]byte{
- []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`),
- []byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`),
- },
- }
- captureDialer := &openAIWSCaptureDialer{conn: captureConn}
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(captureDialer)
-
- upstream := &httpUpstreamRecorder{
- resp: &http.Response{
- StatusCode: http.StatusOK,
- Header: http.Header{"Content-Type": []string{"application/json"}},
- Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)),
- },
- }
-
- svc := &OpenAIGatewayService{
- cfg: cfg,
- httpUpstream: upstream,
- cache: &stubGatewayCache{},
- openaiWSResolver: NewOpenAIWSProtocolResolver(cfg),
- toolCorrector: NewCodexToolCorrector(),
- openaiWSPool: pool,
- }
-
- account := &Account{
- ID: 81,
- Name: "openai-read-timeout",
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Status: StatusActive,
- Schedulable: true,
- Concurrency: 1,
- Credentials: map[string]any{
- "api_key": "sk-test",
- },
- Extra: map[string]any{
- "responses_websockets_v2_enabled": true,
- },
- }
-
- body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`)
- result, err := svc.Forward(context.Background(), c, account, body)
- require.NoError(t, err)
- require.NotNil(t, result)
- require.Equal(t, "resp_timeout_ok", result.RequestID)
- require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP")
-}
-
-type openAIWSCaptureDialer struct {
- mu sync.Mutex
- conn *openAIWSCaptureConn
- lastHeaders http.Header
- handshake http.Header
- dialCount int
-}
-
-func (d *openAIWSCaptureDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = proxyURL
- d.mu.Lock()
- d.lastHeaders = cloneHeader(headers)
- d.dialCount++
- respHeaders := cloneHeader(d.handshake)
- d.mu.Unlock()
- return d.conn, 0, respHeaders, nil
-}
-
-func (d *openAIWSCaptureDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-type openAIWSCaptureConn struct {
- mu sync.Mutex
- readDelays []time.Duration
- events [][]byte
- lastWrite map[string]any
- writes []map[string]any
- closed bool
-}
-
-func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
- _ = ctx
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return errOpenAIWSConnClosed
- }
- switch payload := value.(type) {
- case map[string]any:
- c.lastWrite = cloneMapStringAny(payload)
- c.writes = append(c.writes, cloneMapStringAny(payload))
- case json.RawMessage:
- var parsed map[string]any
- if err := json.Unmarshal(payload, &parsed); err == nil {
- c.lastWrite = cloneMapStringAny(parsed)
- c.writes = append(c.writes, cloneMapStringAny(parsed))
- }
- case []byte:
- var parsed map[string]any
- if err := json.Unmarshal(payload, &parsed); err == nil {
- c.lastWrite = cloneMapStringAny(parsed)
- c.writes = append(c.writes, cloneMapStringAny(parsed))
- }
- }
- return nil
-}
-
-func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
- if ctx == nil {
- ctx = context.Background()
- }
- c.mu.Lock()
- if c.closed {
- c.mu.Unlock()
- return nil, errOpenAIWSConnClosed
- }
- if len(c.events) == 0 {
- c.mu.Unlock()
- return nil, io.EOF
- }
- delay := time.Duration(0)
- if len(c.readDelays) > 0 {
- delay = c.readDelays[0]
- c.readDelays = c.readDelays[1:]
- }
- event := c.events[0]
- c.events = c.events[1:]
- c.mu.Unlock()
- if delay > 0 {
- timer := time.NewTimer(delay)
- defer timer.Stop()
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-timer.C:
- }
- }
- return event, nil
-}
-
-func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
- _ = ctx
- return nil
-}
-
-func (c *openAIWSCaptureConn) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.closed = true
- return nil
-}
-
-func cloneMapStringAny(src map[string]any) map[string]any {
- if src == nil {
- return nil
- }
- dst := make(map[string]any, len(src))
- for k, v := range src {
- dst[k] = v
- }
- return dst
-}
diff --git a/backend/internal/service/openai_ws_forwarder_turn_error_test.go b/backend/internal/service/openai_ws_forwarder_turn_error_test.go
new file mode 100644
index 000000000..d2e93d8be
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_turn_error_test.go
@@ -0,0 +1,53 @@
+package service
+
+import (
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSIngressTurnPartialResult_NotTurnError(t *testing.T) {
+ result, ok := OpenAIWSIngressTurnPartialResult(errors.New("plain error"))
+ require.False(t, ok)
+ require.Nil(t, result)
+}
+
+func TestOpenAIWSIngressTurnPartialResult_DeepCopy(t *testing.T) {
+ partial := &OpenAIForwardResult{
+ RequestID: "resp_partial",
+ Usage: OpenAIUsage{
+ InputTokens: 12,
+ OutputTokens: 34,
+ },
+ PendingFunctionCallIDs: []string{"call_1", "call_2"},
+ }
+ err := wrapOpenAIWSIngressTurnErrorWithPartial("read_upstream", errors.New("boom"), false, partial)
+
+ got, ok := OpenAIWSIngressTurnPartialResult(err)
+ require.True(t, ok)
+ require.NotNil(t, got)
+ require.Equal(t, partial.RequestID, got.RequestID)
+ require.Equal(t, partial.Usage, got.Usage)
+ require.Equal(t, partial.PendingFunctionCallIDs, got.PendingFunctionCallIDs)
+
+ // mutate returned copy should not affect stored partial result
+ got.PendingFunctionCallIDs[0] = "changed"
+ again, ok := OpenAIWSIngressTurnPartialResult(err)
+ require.True(t, ok)
+ require.Equal(t, "call_1", again.PendingFunctionCallIDs[0])
+}
+
+func TestOpenAIWSClientReadIdleTimeout_DefaultAndConfig(t *testing.T) {
+ svc := &OpenAIGatewayService{}
+ require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout())
+
+ svc.cfg = &config.Config{}
+ svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 1800
+ require.Equal(t, 30*time.Minute, svc.openAIWSClientReadIdleTimeout())
+
+ svc.cfg.Gateway.OpenAIWS.ClientReadIdleTimeoutSeconds = 120
+ require.Equal(t, 120*time.Second, svc.openAIWSClientReadIdleTimeout())
+}
diff --git a/backend/internal/service/openai_ws_hotpath_perf_test.go b/backend/internal/service/openai_ws_hotpath_perf_test.go
new file mode 100644
index 000000000..81c5d5521
--- /dev/null
+++ b/backend/internal/service/openai_ws_hotpath_perf_test.go
@@ -0,0 +1,931 @@
+package service
+
+import (
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+ "unsafe"
+
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Mock conn for hotpath performance tests
+// ---------------------------------------------------------------------------
+
+type openAIWSNoopConn struct{}
+
+func (c *openAIWSNoopConn) WriteJSON(context.Context, any) error { return nil }
+func (c *openAIWSNoopConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil }
+func (c *openAIWSNoopConn) Ping(context.Context) error { return nil }
+func (c *openAIWSNoopConn) Close() error { return nil }
+
+// openAIWSIdentityConn is a distinct conn instance used to verify pointer identity.
+type openAIWSIdentityConn struct{ tag string }
+
+func (c *openAIWSIdentityConn) WriteJSON(context.Context, any) error { return nil }
+func (c *openAIWSIdentityConn) ReadMessage(context.Context) ([]byte, error) { return nil, nil }
+func (c *openAIWSIdentityConn) Ping(context.Context) error { return nil }
+func (c *openAIWSIdentityConn) Close() error { return nil }
+
+// ===================================================================
+// 1. maybeTouchLease throttle
+// ===================================================================
+
+func TestMaybeTouchLease_NilReceiverDoesNotPanic(t *testing.T) {
+ var c *openAIWSIngressContext
+ require.NotPanics(t, func() {
+ c.maybeTouchLease(time.Minute)
+ })
+}
+
+func TestMaybeTouchLease_FirstCallAlwaysTouches(t *testing.T) {
+ c := &openAIWSIngressContext{}
+ require.Zero(t, c.lastTouchUnixNano.Load(), "precondition: lastTouchUnixNano should be zero")
+
+ c.maybeTouchLease(5 * time.Minute)
+
+ require.NotZero(t, c.lastTouchUnixNano.Load(), "first maybeTouchLease must update lastTouchUnixNano")
+ require.False(t, c.expiresAt().IsZero(), "first maybeTouchLease must set expiresAt")
+}
+
+func TestMaybeTouchLease_SecondCallWithin1sIsSkipped(t *testing.T) {
+ c := &openAIWSIngressContext{}
+
+ // First touch
+ c.maybeTouchLease(5 * time.Minute)
+ firstExpiry := c.expiresAt()
+ firstTouch := c.lastTouchUnixNano.Load()
+ require.NotZero(t, firstTouch)
+
+ // Second touch immediately -- within 1s, should be skipped
+ c.maybeTouchLease(10 * time.Minute)
+ secondExpiry := c.expiresAt()
+ secondTouch := c.lastTouchUnixNano.Load()
+
+ require.Equal(t, firstTouch, secondTouch, "lastTouchUnixNano should NOT change within 1s")
+ require.Equal(t, firstExpiry, secondExpiry, "expiresAt should NOT change within 1s")
+}
+
+func TestMaybeTouchLease_CallAfter1sActuallyTouches(t *testing.T) {
+ c := &openAIWSIngressContext{}
+
+ c.maybeTouchLease(5 * time.Minute)
+ firstExpiry := c.expiresAt()
+
+ // Simulate 1s+ passing by backdating the lastTouchUnixNano
+ backdated := time.Now().Add(-2 * time.Second).UnixNano()
+ c.lastTouchUnixNano.Store(backdated)
+ // Also backdate expiresAt so we can observe the change
+ c.setExpiresAt(time.Now().Add(-time.Minute))
+ expiryAfterBackdate := c.expiresAt()
+ require.True(t, expiryAfterBackdate.Before(firstExpiry), "precondition: expiresAt should be backdated")
+
+ c.maybeTouchLease(5 * time.Minute)
+ touchAfter := c.lastTouchUnixNano.Load()
+ secondExpiry := c.expiresAt()
+
+ require.Greater(t, touchAfter, backdated, "lastTouchUnixNano should advance past the backdated value")
+ require.True(t, secondExpiry.After(expiryAfterBackdate), "expiresAt should advance after 1s+ gap")
+}
+
+func TestTouchLease_NilReceiverDoesNotPanic(t *testing.T) {
+ var c *openAIWSIngressContext
+ require.NotPanics(t, func() {
+ c.touchLease(time.Now(), 5*time.Minute)
+ })
+}
+
+func TestTouchLease_AlwaysUpdatesLastTouchUnixNano(t *testing.T) {
+ c := &openAIWSIngressContext{}
+
+ now := time.Now()
+ c.touchLease(now, 5*time.Minute)
+ first := c.lastTouchUnixNano.Load()
+ require.NotZero(t, first)
+
+ // touchLease (non-throttled) always updates, even if called again immediately.
+ time.Sleep(time.Millisecond) // ensure clock moves forward
+ now2 := time.Now()
+ c.touchLease(now2, 5*time.Minute)
+ second := c.lastTouchUnixNano.Load()
+ require.Greater(t, second, first, "touchLease must always update lastTouchUnixNano")
+}
+
+// ===================================================================
+// 2. activeConn cached connection
+// ===================================================================
+
+func TestActiveConn_NilLeaseReturnsError(t *testing.T) {
+ var lease *openAIWSIngressContextLease
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_NilContextReturnsError(t *testing.T) {
+ lease := &openAIWSIngressContextLease{context: nil}
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_ReleasedLeaseReturnsError(t *testing.T) {
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner",
+ upstream: &openAIWSNoopConn{},
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner",
+ }
+ lease.released.Store(true)
+
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_FirstCallPopulatesCachedConn(t *testing.T) {
+ upstream := &openAIWSIdentityConn{tag: "primary"}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_1",
+ upstream: upstream,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_1",
+ }
+
+ require.Nil(t, lease.cachedConn, "precondition: cachedConn should be nil")
+
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream, conn, "should return the upstream conn")
+ require.Equal(t, upstream, lease.cachedConn, "should populate cachedConn")
+}
+
+func TestActiveConn_SecondCallReturnsCachedDirectly(t *testing.T) {
+ upstream1 := &openAIWSIdentityConn{tag: "first"}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_cache",
+ upstream: upstream1,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_cache",
+ }
+
+ // First call populates cache
+ conn1, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream1, conn1)
+
+ // Swap the upstream -- cached path should NOT see the swap
+ upstream2 := &openAIWSIdentityConn{tag: "second"}
+ ctx.mu.Lock()
+ ctx.upstream = upstream2
+ ctx.mu.Unlock()
+
+ conn2, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream1, conn2, "second call should return cachedConn, not the swapped upstream")
+}
+
+func TestActiveConn_OwnerMismatchReturnsError(t *testing.T) {
+ ctx := &openAIWSIngressContext{
+ ownerID: "other_owner",
+ upstream: &openAIWSNoopConn{},
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "my_owner",
+ }
+
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_NilUpstreamReturnsError(t *testing.T) {
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner",
+ upstream: nil,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner",
+ }
+
+ conn, err := lease.activeConn()
+ require.Nil(t, conn)
+ require.ErrorIs(t, err, errOpenAIWSConnClosed)
+}
+
+func TestActiveConn_MarkBrokenClearsCachedConn(t *testing.T) {
+ upstream := &openAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_mb",
+ upstream: upstream,
+ }
+ pool := &openAIWSIngressContextPool{
+ idleTTL: 10 * time.Minute,
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctx,
+ ownerID: "owner_mb",
+ }
+
+ // Populate cache
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.NotNil(t, conn)
+ require.NotNil(t, lease.cachedConn)
+
+ lease.MarkBroken()
+ require.Nil(t, lease.cachedConn, "MarkBroken must clear cachedConn")
+}
+
+func TestActiveConn_ReleaseClearsCachedConn(t *testing.T) {
+ upstream := &openAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_rel",
+ upstream: upstream,
+ }
+ pool := &openAIWSIngressContextPool{
+ idleTTL: 10 * time.Minute,
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctx,
+ ownerID: "owner_rel",
+ }
+
+ // Populate cache
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.NotNil(t, conn)
+ require.NotNil(t, lease.cachedConn)
+
+ lease.Release()
+ require.Nil(t, lease.cachedConn, "Release must clear cachedConn")
+}
+
+func TestActiveConn_AfterClearCachedConn_ReacquiresViaMutex(t *testing.T) {
+ upstream1 := &openAIWSIdentityConn{tag: "v1"}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_reacq",
+ upstream: upstream1,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_reacq",
+ }
+
+ // Populate cache with upstream1
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream1, conn)
+
+ // Simulate a cleared cache (e.g., after recovery)
+ lease.cachedConn = nil
+
+ // Swap upstream
+ upstream2 := &openAIWSIdentityConn{tag: "v2"}
+ ctx.mu.Lock()
+ ctx.upstream = upstream2
+ ctx.mu.Unlock()
+
+ // Should now re-acquire via mutex and return upstream2
+ conn2, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream2, conn2, "after clearing cachedConn, next call must re-acquire via mutex")
+ require.Equal(t, upstream2, lease.cachedConn, "should re-populate cachedConn with new upstream")
+}
+
+// ===================================================================
+// 3. Event type TrimSpace-free functions
+// ===================================================================
+
+func TestIsOpenAIWSTerminalEvent(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ {"response.completed", true},
+ {"response.done", true},
+ {"response.failed", true},
+ {"response.incomplete", true},
+ {"response.cancelled", true},
+ {"response.canceled", true},
+ {"response.created", false},
+ {"response.in_progress", false},
+ {"response.output_text.delta", false},
+ {"", false},
+ {"unknown_event", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, isOpenAIWSTerminalEvent(tt.eventType))
+ })
+ }
+}
+
+func TestShouldPersistOpenAIWSLastResponseID_HotpathPerf(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ {"response.completed", true},
+ {"response.done", true},
+ {"response.failed", false},
+ {"response.incomplete", false},
+ {"response.cancelled", false},
+ {"", false},
+ {"unknown_event", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, shouldPersistOpenAIWSLastResponseID(tt.eventType))
+ })
+ }
+}
+
+func TestIsOpenAIWSTokenEvent(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ // Known false: structural events
+ {"response.created", false},
+ {"response.in_progress", false},
+ {"response.output_item.added", false},
+ {"response.output_item.done", false},
+ // Delta events
+ {"response.output_text.delta", true},
+ {"response.content_part.delta", true},
+ {"response.audio.delta", true},
+ {"response.function_call_arguments.delta", true},
+ // output_text prefix
+ {"response.output_text.done", true},
+ {"response.output_text.annotation.added", true},
+ // output prefix (but not output_item)
+ {"response.output.done", true},
+ // Terminal events that are also token events
+ {"response.completed", true},
+ {"response.done", true},
+ // Empty and unknown
+ {"", false},
+ {"unknown_event", false},
+ {"session.created", false},
+ {"session.updated", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, isOpenAIWSTokenEvent(tt.eventType))
+ })
+ }
+}
+
+func TestOpenAIWSEventShouldParseUsage(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ {"response.completed", true},
+ {"response.done", true},
+ {"response.failed", true},
+ {"", false},
+ {"unknown", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, openAIWSEventShouldParseUsage(tt.eventType))
+ })
+ }
+}
+
+func TestOpenAIWSEventMayContainToolCalls(t *testing.T) {
+ tests := []struct {
+ eventType string
+ want bool
+ }{
+ // Explicit function_call / tool_call in name
+ {"response.function_call_arguments.delta", true},
+ {"response.function_call_arguments.done", true},
+ {"response.tool_call.delta", true},
+ // Structural events that may contain tool output items
+ {"response.output_item.added", true},
+ {"response.output_item.done", true},
+ {"response.completed", true},
+ {"response.done", true},
+ // Non-tool events
+ {"response.output_text.delta", false},
+ {"response.created", false},
+ {"response.in_progress", false},
+ {"", false},
+ {"unknown", false},
+ }
+ for _, tt := range tests {
+ t.Run(tt.eventType, func(t *testing.T) {
+ require.Equal(t, tt.want, openAIWSEventMayContainToolCalls(tt.eventType))
+ })
+ }
+}
+
+// ===================================================================
+// 4. parseOpenAIWSEventType (lightweight version)
+// ===================================================================
+
+func TestParseOpenAIWSEventType_EmptyMessage(t *testing.T) {
+ eventType, responseID := parseOpenAIWSEventType(nil)
+ require.Empty(t, eventType)
+ require.Empty(t, responseID)
+
+ eventType, responseID = parseOpenAIWSEventType([]byte{})
+ require.Empty(t, eventType)
+ require.Empty(t, responseID)
+}
+
+func TestParseOpenAIWSEventType_ResponseIDExtracted(t *testing.T) {
+ msg := []byte(`{"type":"response.completed","response":{"id":"resp_abc123"}}`)
+ eventType, responseID := parseOpenAIWSEventType(msg)
+ require.Equal(t, "response.completed", eventType)
+ require.Equal(t, "resp_abc123", responseID)
+}
+
+func TestParseOpenAIWSEventType_FallbackToID(t *testing.T) {
+ msg := []byte(`{"type":"response.output_text.delta","id":"evt_fallback_id"}`)
+ eventType, responseID := parseOpenAIWSEventType(msg)
+ require.Equal(t, "response.output_text.delta", eventType)
+ require.Equal(t, "evt_fallback_id", responseID)
+}
+
+func TestParseOpenAIWSEventType_ConsistentWithEnvelope(t *testing.T) {
+ testMessages := [][]byte{
+ []byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`),
+ []byte(`{"type":"response.output_text.delta","id":"evt_2"}`),
+ []byte(`{"type":"response.created","response":{"id":"resp_3"}}`),
+ []byte(`{"type":"error","error":{"message":"bad request"}}`),
+ []byte(`{"type":"response.done","id":"resp_4","response":{"id":"resp_4_inner"}}`),
+ []byte(`{}`),
+ []byte(`{"type":"session.created"}`),
+ }
+ for i, msg := range testMessages {
+ t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
+ typeLight, idLight := parseOpenAIWSEventType(msg)
+ typeEnv, idEnv, _ := parseOpenAIWSEventEnvelope(msg)
+ require.Equal(t, typeEnv, typeLight, "eventType must match parseOpenAIWSEventEnvelope")
+ require.Equal(t, idEnv, idLight, "responseID must match parseOpenAIWSEventEnvelope")
+ })
+ }
+}
+
+// ===================================================================
+// 6. openAIWSResponseAccountCacheKey (xxhash, v2 prefix)
+// ===================================================================
+
+func TestOpenAIWSResponseAccountCacheKey_Deterministic(t *testing.T) {
+ key1 := openAIWSResponseAccountCacheKey("resp_deterministic_test")
+ key2 := openAIWSResponseAccountCacheKey("resp_deterministic_test")
+ require.Equal(t, key1, key2, "same responseID must produce the same key")
+}
+
+func TestOpenAIWSResponseAccountCacheKey_DifferentIDsDifferentKeys(t *testing.T) {
+ key1 := openAIWSResponseAccountCacheKey("resp_alpha")
+ key2 := openAIWSResponseAccountCacheKey("resp_beta")
+ require.NotEqual(t, key1, key2, "different responseIDs must produce different keys")
+}
+
+func TestOpenAIWSResponseAccountCacheKey_V2Prefix(t *testing.T) {
+ key := openAIWSResponseAccountCacheKey("resp_v2_check")
+ require.True(t, strings.Contains(key, "v2:"), "key must contain v2: prefix for version compatibility")
+}
+
+func TestOpenAIWSResponseAccountCacheKey_StartsWithCachePrefix(t *testing.T) {
+ key := openAIWSResponseAccountCacheKey("resp_prefix_check")
+ require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix),
+ "key must start with the standard cache prefix %q, got %q", openAIWSResponseAccountCachePrefix, key)
+}
+
+func TestOpenAIWSResponseAccountCacheKey_HexLength(t *testing.T) {
+ key := openAIWSResponseAccountCacheKey("resp_hex_length")
+ // Expected format: "openai:response:v2:<16 hex chars>"
+ prefix := openAIWSResponseAccountCachePrefix + "v2:"
+ require.True(t, strings.HasPrefix(key, prefix))
+ hexPart := strings.TrimPrefix(key, prefix)
+ require.Len(t, hexPart, 16, "xxhash hex digest should be zero-padded to 16 chars, got %q", hexPart)
+}
+
+func TestOpenAIWSResponseAccountCacheKey_ManyInputs_AllPaddedTo16(t *testing.T) {
+ // Verify that all inputs produce exactly 16-char hex, testing many variations.
+ prefix := openAIWSResponseAccountCachePrefix + "v2:"
+ for i := 0; i < 1000; i++ {
+ responseID := fmt.Sprintf("resp_%d", i)
+ key := openAIWSResponseAccountCacheKey(responseID)
+ hexPart := strings.TrimPrefix(key, prefix)
+ require.Len(t, hexPart, 16, "responseID=%q produced hex %q (len %d)", responseID, hexPart, len(hexPart))
+ }
+}
+
+// ===================================================================
+// 7. openAIWSSessionTurnStateKey uses strconv
+// ===================================================================
+
+func TestOpenAIWSSessionTurnStateKey_NormalCase(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(123, "abc_hash")
+ require.Equal(t, "123:abc_hash", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_EmptySessionHash(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(123, "")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_WhitespaceOnlySessionHash(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(123, " ")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_NegativeGroupID(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(-1, "hash")
+ require.Equal(t, "-1:hash", key)
+}
+
+func TestOpenAIWSSessionTurnStateKey_ZeroGroupID(t *testing.T) {
+ key := openAIWSSessionTurnStateKey(0, "hash")
+ require.Equal(t, "0:hash", key)
+}
+
+// ===================================================================
+// 8. openAIWSIngressContextSessionKey uses strconv
+// ===================================================================
+
+func TestOpenAIWSIngressContextSessionKey_NormalCase(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(456, "session_xyz")
+ require.Equal(t, "456:session_xyz", key)
+}
+
+func TestOpenAIWSIngressContextSessionKey_EmptySessionHash(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(456, "")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSIngressContextSessionKey_WhitespaceOnlySessionHash(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(456, " \t ")
+ require.Equal(t, "", key)
+}
+
+func TestOpenAIWSIngressContextSessionKey_LargeGroupID(t *testing.T) {
+ key := openAIWSIngressContextSessionKey(9223372036854775807, "h")
+ require.Equal(t, "9223372036854775807:h", key)
+}
+
+// ===================================================================
+// 9. deriveOpenAISessionHash and deriveOpenAILegacySessionHash
+// ===================================================================
+
+func TestDeriveOpenAISessionHash_EmptyReturnsEmpty(t *testing.T) {
+ require.Equal(t, "", deriveOpenAISessionHash(""))
+ require.Equal(t, "", deriveOpenAISessionHash(" "))
+}
+
+func TestDeriveOpenAISessionHash_ProducesXXHash16Chars(t *testing.T) {
+ hash := deriveOpenAISessionHash("test_session_id")
+ require.Len(t, hash, 16, "xxhash hex should be exactly 16 chars, got %q", hash)
+ // Verify it's valid hex
+ for _, ch := range hash {
+ require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'),
+ "hash should be lowercase hex, got char %c in %q", ch, hash)
+ }
+}
+
+func TestDeriveOpenAISessionHash_Deterministic(t *testing.T) {
+ h1 := deriveOpenAISessionHash("session_abc")
+ h2 := deriveOpenAISessionHash("session_abc")
+ require.Equal(t, h1, h2)
+}
+
+func TestDeriveOpenAILegacySessionHash_EmptyReturnsEmpty(t *testing.T) {
+ require.Equal(t, "", deriveOpenAILegacySessionHash(""))
+ require.Equal(t, "", deriveOpenAILegacySessionHash(" "))
+}
+
+func TestDeriveOpenAILegacySessionHash_ProducesSHA256_64Chars(t *testing.T) {
+ hash := deriveOpenAILegacySessionHash("test_session_id")
+ require.Len(t, hash, 64, "SHA-256 hex should be exactly 64 chars, got %q", hash)
+ for _, ch := range hash {
+ require.True(t, (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f'),
+ "hash should be lowercase hex, got char %c in %q", ch, hash)
+ }
+}
+
+func TestDeriveOpenAILegacySessionHash_Deterministic(t *testing.T) {
+ h1 := deriveOpenAILegacySessionHash("session_xyz")
+ h2 := deriveOpenAILegacySessionHash("session_xyz")
+ require.Equal(t, h1, h2)
+}
+
+func TestDeriveOpenAISessionHashes_MatchesIndividualFunctions(t *testing.T) {
+ sessionID := "test_combined_session"
+ currentHash, legacyHash := deriveOpenAISessionHashes(sessionID)
+
+ require.Equal(t, deriveOpenAISessionHash(sessionID), currentHash)
+ require.Equal(t, deriveOpenAILegacySessionHash(sessionID), legacyHash)
+}
+
+func TestDeriveOpenAISessionHashes_EmptyReturnsEmpty(t *testing.T) {
+ currentHash, legacyHash := deriveOpenAISessionHashes("")
+ require.Equal(t, "", currentHash)
+ require.Equal(t, "", legacyHash)
+}
+
+func TestDeriveOpenAISessionHashes_DifferentInputsDifferentOutputs(t *testing.T) {
+ h1Current, h1Legacy := deriveOpenAISessionHashes("session_A")
+ h2Current, h2Legacy := deriveOpenAISessionHashes("session_B")
+ require.NotEqual(t, h1Current, h2Current)
+ require.NotEqual(t, h1Legacy, h2Legacy)
+}
+
+func TestDeriveOpenAISessionHash_DifferentFromLegacy(t *testing.T) {
+ // xxhash and SHA-256 produce completely different outputs for the same input
+ currentHash := deriveOpenAISessionHash("same_input")
+ legacyHash := deriveOpenAILegacySessionHash("same_input")
+ require.NotEqual(t, currentHash, legacyHash, "xxhash and SHA-256 should produce different results")
+ require.Len(t, currentHash, 16)
+ require.Len(t, legacyHash, 64)
+}
+
+// ===================================================================
+// 10. State store sharded lock (responseToConn)
+// ===================================================================
+
+func TestConnShard_DistributesAcrossShards(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil).(*defaultOpenAIWSStateStore)
+
+ shardHits := make(map[int]int)
+ for i := 0; i < 256; i++ {
+ key := fmt.Sprintf("resp_%d", i)
+ shard := store.connShard(key)
+ // Find which shard index this is
+ for j := 0; j < openAIWSStateStoreConnShards; j++ {
+ if shard == &store.responseToConnShards[j] {
+ shardHits[j]++
+ break
+ }
+ }
+ }
+
+ // With 256 keys and 16 shards, each shard should get some keys.
+ // We don't require perfect uniformity, just that keys aren't all in one shard.
+ require.Greater(t, len(shardHits), 1, "keys must be distributed across multiple shards, got %d shards used", len(shardHits))
+ require.GreaterOrEqual(t, len(shardHits), openAIWSStateStoreConnShards/2,
+ "keys should hit at least half the shards for reasonable distribution")
+}
+
+func TestStateStore_ShardedBindGetDelete(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+
+ store.BindResponseConn("resp_shard_1", "conn_a", time.Minute)
+ store.BindResponseConn("resp_shard_2", "conn_b", time.Minute)
+
+ conn1, ok1 := store.GetResponseConn("resp_shard_1")
+ require.True(t, ok1)
+ require.Equal(t, "conn_a", conn1)
+
+ conn2, ok2 := store.GetResponseConn("resp_shard_2")
+ require.True(t, ok2)
+ require.Equal(t, "conn_b", conn2)
+
+ store.DeleteResponseConn("resp_shard_1")
+ _, ok1After := store.GetResponseConn("resp_shard_1")
+ require.False(t, ok1After)
+
+ // resp_shard_2 should still be accessible
+ conn2After, ok2After := store.GetResponseConn("resp_shard_2")
+ require.True(t, ok2After)
+ require.Equal(t, "conn_b", conn2After)
+}
+
+func TestStateStore_ShardedConcurrentAccessNoRace(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ const goroutines = 32
+ const opsPerGoroutine = 200
+
+ var wg sync.WaitGroup
+ wg.Add(goroutines)
+
+ for g := 0; g < goroutines; g++ {
+ g := g
+ go func() {
+ defer wg.Done()
+ for i := 0; i < opsPerGoroutine; i++ {
+ key := fmt.Sprintf("resp_conc_%d_%d", g, i)
+ connID := fmt.Sprintf("conn_%d_%d", g, i)
+
+ store.BindResponseConn(key, connID, time.Minute)
+ got, ok := store.GetResponseConn(key)
+ if ok {
+ _ = got
+ }
+ store.DeleteResponseConn(key)
+ }
+ }()
+ }
+
+ wg.Wait()
+}
+
+// ===================================================================
+// 11. State store: Get paths don't call maybeCleanup
+// ===================================================================
+
+func TestStateStore_GetPaths_DoNotTriggerCleanup(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store := raw.(*defaultOpenAIWSStateStore)
+
+ // Seed some data so Get paths have something to read
+ store.BindResponseConn("resp_get_noclean", "conn_1", time.Minute)
+ store.BindResponsePendingToolCalls(0, "resp_get_noclean", []string{"call_1"}, time.Minute)
+ store.BindSessionTurnState(1, "session_get_noclean", "state_1", time.Minute)
+ store.BindSessionConn(1, "session_get_noclean", "conn_1", time.Minute)
+
+ // Record the lastCleanupUnixNano after the Bind calls
+ cleanupBefore := store.lastCleanupUnixNano.Load()
+
+ // Set lastCleanup to the future to ensure no cleanup triggers from Binds
+ store.lastCleanupUnixNano.Store(time.Now().Add(time.Hour).UnixNano())
+ cleanupFrozen := store.lastCleanupUnixNano.Load()
+
+ // Perform many Get calls
+ for i := 0; i < 100; i++ {
+ store.GetResponseConn("resp_get_noclean")
+ store.GetResponsePendingToolCalls(0, "resp_get_noclean")
+ store.GetSessionTurnState(1, "session_get_noclean")
+ store.GetSessionConn(1, "session_get_noclean")
+ }
+
+ cleanupAfterGets := store.lastCleanupUnixNano.Load()
+ require.Equal(t, cleanupFrozen, cleanupAfterGets,
+ "Get paths must NOT change lastCleanupUnixNano (was %d before, %d after)", cleanupBefore, cleanupAfterGets)
+}
+
+func TestStateStore_MaybeCleanup_NilReceiverDoesNotPanic(t *testing.T) {
+ var store *defaultOpenAIWSStateStore
+ require.NotPanics(t, func() {
+ store.maybeCleanup()
+ })
+}
+
+func TestStateStore_BindPaths_MayTriggerCleanup(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store := raw.(*defaultOpenAIWSStateStore)
+
+ // Set lastCleanup to long ago to ensure cleanup triggers on next Bind
+ pastNano := time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()
+ store.lastCleanupUnixNano.Store(pastNano)
+
+ store.BindResponseConn("resp_bind_trigger", "conn_trigger", time.Minute)
+
+ cleanupAfterBind := store.lastCleanupUnixNano.Load()
+ require.NotEqual(t, pastNano, cleanupAfterBind,
+ "Bind paths should trigger maybeCleanup when interval has elapsed")
+}
+
+// ===================================================================
+// 12. GetResponsePendingToolCalls returns internal slice directly
+// ===================================================================
+
+func TestGetResponsePendingToolCalls_ReturnsInternalSlice(t *testing.T) {
+ raw := NewOpenAIWSStateStore(nil)
+ store := raw.(*defaultOpenAIWSStateStore)
+
+ store.BindResponsePendingToolCalls(0, "resp_slice_identity", []string{"call_x", "call_y"}, time.Minute)
+
+ callIDs, ok := store.GetResponsePendingToolCalls(0, "resp_slice_identity")
+ require.True(t, ok)
+ require.Equal(t, []string{"call_x", "call_y"}, callIDs)
+
+ // Verify it's the same underlying slice as stored in the binding (pointer equality).
+ // The binding stores callIDs as a copied slice at bind time, but Get returns it directly.
+ id := openAIWSResponsePendingToolCallsBindingKey(0, "resp_slice_identity")
+ store.responsePendingToolMu.RLock()
+ binding, exists := store.responsePendingTool[id]
+ store.responsePendingToolMu.RUnlock()
+ require.True(t, exists)
+
+ // Check pointer equality of the underlying array via unsafe
+ gotHeader := (*[3]uintptr)(unsafe.Pointer(&callIDs))
+ internalHeader := (*[3]uintptr)(unsafe.Pointer(&binding.callIDs))
+ require.Equal(t, gotHeader[0], internalHeader[0],
+ "returned slice should share the same underlying array pointer as the internal binding (zero-copy)")
+}
+
+// ===================================================================
+// Additional edge-case tests for completeness
+// ===================================================================
+
+func TestParseOpenAIWSEventType_MalformedJSON(t *testing.T) {
+ // Should not panic on malformed JSON
+ eventType, responseID := parseOpenAIWSEventType([]byte(`{not valid json`))
+ // gjson returns empty for invalid JSON
+ _ = eventType
+ _ = responseID
+}
+
+func TestOpenAIWSResponseAccountCacheKey_EmptyInput(t *testing.T) {
+ // Even empty string should produce a valid key
+ key := openAIWSResponseAccountCacheKey("")
+ require.True(t, strings.HasPrefix(key, openAIWSResponseAccountCachePrefix+"v2:"))
+ hexPart := strings.TrimPrefix(key, openAIWSResponseAccountCachePrefix+"v2:")
+ require.Len(t, hexPart, 16)
+}
+
+func TestConnShard_SameKeyAlwaysSameShard(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil).(*defaultOpenAIWSStateStore)
+ shard1 := store.connShard("resp_stable_key")
+ shard2 := store.connShard("resp_stable_key")
+ require.Equal(t, shard1, shard2, "same key must always map to the same shard")
+}
+
+func TestMaybeTouchLease_ConcurrentSafe(t *testing.T) {
+ c := &openAIWSIngressContext{}
+ var wg sync.WaitGroup
+ const goroutines = 16
+
+ wg.Add(goroutines)
+ for i := 0; i < goroutines; i++ {
+ go func() {
+ defer wg.Done()
+ for j := 0; j < 100; j++ {
+ c.maybeTouchLease(5 * time.Minute)
+ }
+ }()
+ }
+ wg.Wait()
+
+ require.NotZero(t, c.lastTouchUnixNano.Load())
+ require.False(t, c.expiresAt().IsZero())
+}
+
+func TestActiveConn_SingleOwnerSequentialAccess(t *testing.T) {
+ // activeConn uses a non-synchronized cachedConn field by design.
+ // A lease is only used by a single goroutine (the forwarding loop).
+ // This test verifies sequential repeated calls from the same goroutine
+ // always return the same cached conn without error.
+ upstream := &openAIWSNoopConn{}
+ ctx := &openAIWSIngressContext{
+ ownerID: "owner_seq",
+ upstream: upstream,
+ }
+ lease := &openAIWSIngressContextLease{
+ context: ctx,
+ ownerID: "owner_seq",
+ }
+
+ for i := 0; i < 1000; i++ {
+ conn, err := lease.activeConn()
+ require.NoError(t, err)
+ require.Equal(t, upstream, conn)
+ }
+}
+
+func TestOpenAIWSIngressContextSessionKey_ConsistentWithTurnStateKey(t *testing.T) {
+ // Both functions use the same pattern: strconv.FormatInt(groupID, 10) + ":" + hash
+ groupID := int64(42)
+ sessionHash := "test_hash"
+
+ sessionKey := openAIWSIngressContextSessionKey(groupID, sessionHash)
+ turnStateKey := openAIWSSessionTurnStateKey(groupID, sessionHash)
+
+ require.Equal(t, sessionKey, turnStateKey,
+ "openAIWSIngressContextSessionKey and openAIWSSessionTurnStateKey should produce identical keys for the same inputs")
+}
+
+func TestStateStore_ShardedBindOverwrite(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+
+ store.BindResponseConn("resp_overwrite", "conn_old", time.Minute)
+ store.BindResponseConn("resp_overwrite", "conn_new", time.Minute)
+
+ conn, ok := store.GetResponseConn("resp_overwrite")
+ require.True(t, ok)
+ require.Equal(t, "conn_new", conn, "later bind should overwrite earlier bind")
+}
+
+func TestStateStore_ShardedTTLExpiry(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+
+ store.BindResponseConn("resp_ttl_shard", "conn_ttl", 30*time.Millisecond)
+ conn, ok := store.GetResponseConn("resp_ttl_shard")
+ require.True(t, ok)
+ require.Equal(t, "conn_ttl", conn)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetResponseConn("resp_ttl_shard")
+ require.False(t, ok, "entry should be expired after TTL")
+}
diff --git a/backend/internal/service/openai_ws_ingress_context_pool.go b/backend/internal/service/openai_ws_ingress_context_pool.go
new file mode 100644
index 000000000..4b0daf13d
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_context_pool.go
@@ -0,0 +1,1586 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "net/http"
+ "strconv"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+var (
+ errOpenAIWSIngressContextBusy = errors.New("openai ws ingress context is busy")
+)
+
+const (
+ openAIWSIngressScheduleLayerExact = "l0_exact"
+ openAIWSIngressScheduleLayerNew = "l1_new_context"
+ openAIWSIngressScheduleLayerMigration = "l2_migration"
+ openAIWSIngressAcquireMaxWaitRetries = 4096
+ openAIWSIngressAcquireMaxQueueWait = 30 * time.Minute
+
+ // openAIWSUpstreamConnMaxAge 是上游 WebSocket 连接的默认最大存活时间。
+ // OpenAI 在 60 分钟后强制关闭连接,此处默认 55 分钟主动轮换以避免中途断连。
+ openAIWSUpstreamConnMaxAge = 55 * time.Minute
+
+ // openAIWSIngressDelayedPingAfterYield 是 yield 后延迟 Ping 探测的等待时间。
+ // 在会话暂时空闲后提前发现死连接,避免下次 Acquire 时才发现。
+ openAIWSIngressDelayedPingAfterYield = 5 * time.Second
+
+ // openAIWSIngressPingTimeout 是后台 Ping 探测的超时时间。
+ openAIWSIngressPingTimeout = 5 * time.Second
+)
+
+const (
+ openAIWSIngressStickinessWeak = "weak"
+ openAIWSIngressStickinessBalanced = "balanced"
+ openAIWSIngressStickinessStrong = "strong"
+)
+
+type openAIWSIngressContextAcquireRequest struct {
+ Account *Account
+ GroupID int64
+ SessionHash string
+ OwnerID string
+ WSURL string
+ Headers http.Header
+ ProxyURL string
+ Turn int
+
+ HasPreviousResponseID bool
+ StrictAffinity bool
+ StoreDisabled bool
+}
+
+type openAIWSIngressContextPool struct {
+ cfg *config.Config
+ dialer openAIWSClientDialer
+
+ idleTTL time.Duration
+ sweepInterval time.Duration
+ upstreamMaxAge time.Duration
+
+ seq atomic.Uint64
+
+ // schedulerStats provides load-aware signals (error rate, circuit breaker
+ // state) for migration scoring. When nil, scoring falls back to the
+ // existing idle-time + failure-streak heuristic.
+ schedulerStats *openAIAccountRuntimeStats
+
+ mu sync.Mutex
+ accounts map[int64]*openAIWSIngressAccountPool
+
+ stopCh chan struct{}
+ stopOnce sync.Once
+ workerWg sync.WaitGroup
+ closeOnce sync.Once
+}
+
+type openAIWSIngressAccountPool struct {
+ mu sync.Mutex
+
+ refs atomic.Int64
+
+ // dynamicCap 动态容量:初始 1,按需增长(L1 新建时 +1),空闲超时后缩减。
+ // 实际容量为 min(dynamicCap, effectiveContextCapacity)。
+ dynamicCap atomic.Int32
+
+ contexts map[string]*openAIWSIngressContext
+ bySession map[string]string
+}
+
+type openAIWSIngressContext struct {
+ id string
+ groupID int64
+ accountID int64
+ sessionHash string
+ sessionKey string
+
+ mu sync.Mutex
+ dialing bool
+ dialDone chan struct{}
+ releaseDone chan struct{} // ownerID 释放时发送单信号,唤醒一个等待者
+ ownerID string
+ lastUsedAtUnix atomic.Int64
+ expiresAtUnix atomic.Int64
+ lastTouchUnixNano atomic.Int64 // throttle: skip touchLease if < 1s since last
+ broken bool
+ failureStreak int
+ lastFailureAt time.Time
+ migrationCount int
+ lastMigrationAt time.Time
+ upstream openAIWSClientConn
+ upstreamConnID string
+ upstreamConnCreatedAt atomic.Int64 // UnixNano; 0 表示未设置
+ handshakeHeaders http.Header
+ prewarmed atomic.Bool
+ pendingPingTimer *time.Timer // 延迟 Ping 去重:同一 context 仅保留一个 pending ping
+}
+
+type openAIWSIngressContextLease struct {
+ pool *openAIWSIngressContextPool
+ context *openAIWSIngressContext
+ ownerID string
+ queueWait time.Duration
+ connPick time.Duration
+ reused bool
+ scheduleLayer string
+ stickiness string
+ migrationUsed bool
+ released atomic.Bool
+ cachedConnMu sync.RWMutex
+ cachedConn openAIWSClientConn // fast path: avoid mutex on every activeConn() call
+}
+
+func openAIWSTimeToUnixNano(ts time.Time) int64 {
+ if ts.IsZero() {
+ return 0
+ }
+ return ts.UnixNano()
+}
+
+func openAIWSUnixNanoToTime(ns int64) time.Time {
+ if ns <= 0 {
+ return time.Time{}
+ }
+ return time.Unix(0, ns)
+}
+
+func (c *openAIWSIngressContext) setLastUsedAt(ts time.Time) {
+ if c == nil {
+ return
+ }
+ c.lastUsedAtUnix.Store(openAIWSTimeToUnixNano(ts))
+}
+
+func (c *openAIWSIngressContext) lastUsedAt() time.Time {
+ if c == nil {
+ return time.Time{}
+ }
+ return openAIWSUnixNanoToTime(c.lastUsedAtUnix.Load())
+}
+
+func (c *openAIWSIngressContext) setExpiresAt(ts time.Time) {
+ if c == nil {
+ return
+ }
+ c.expiresAtUnix.Store(openAIWSTimeToUnixNano(ts))
+}
+
+func (c *openAIWSIngressContext) expiresAt() time.Time {
+ if c == nil {
+ return time.Time{}
+ }
+ return openAIWSUnixNanoToTime(c.expiresAtUnix.Load())
+}
+
+// upstreamConnAge 返回上游连接已存活的时长。
+// 若 createdAt 未设置(零值)或 now 早于 createdAt(时钟回拨),返回 0。
+func (c *openAIWSIngressContext) upstreamConnAge(now time.Time) time.Duration {
+ if c == nil {
+ return 0
+ }
+ ns := c.upstreamConnCreatedAt.Load()
+ if ns <= 0 {
+ return 0
+ }
+ age := now.Sub(time.Unix(0, ns))
+ if age < 0 {
+ return 0
+ }
+ return age
+}
+
+func (c *openAIWSIngressContext) touchLease(now time.Time, ttl time.Duration) {
+ if c == nil {
+ return
+ }
+ nowUnix := openAIWSTimeToUnixNano(now)
+ c.lastUsedAtUnix.Store(nowUnix)
+ c.expiresAtUnix.Store(openAIWSTimeToUnixNano(now.Add(ttl)))
+ c.lastTouchUnixNano.Store(nowUnix)
+}
+
+// maybeTouchLease is a throttled version of touchLease.
+// It skips the update if less than 1 second has passed since the last touch,
+// avoiding redundant time.Now() + atomic stores on every hot-path message.
+func (c *openAIWSIngressContext) maybeTouchLease(ttl time.Duration) {
+ if c == nil {
+ return
+ }
+ now := time.Now()
+ nowNano := now.UnixNano()
+ lastNano := c.lastTouchUnixNano.Load()
+ if lastNano > 0 && nowNano-lastNano < int64(time.Second) {
+ return
+ }
+ c.touchLease(now, ttl)
+}
+
+func newOpenAIWSIngressContextPool(cfg *config.Config) *openAIWSIngressContextPool {
+ pool := &openAIWSIngressContextPool{
+ cfg: cfg,
+ dialer: newDefaultOpenAIWSClientDialer(),
+ idleTTL: 10 * time.Minute,
+ sweepInterval: 30 * time.Second,
+ upstreamMaxAge: openAIWSUpstreamConnMaxAge,
+ accounts: make(map[int64]*openAIWSIngressAccountPool),
+ stopCh: make(chan struct{}),
+ }
+ if cfg != nil && cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 {
+ pool.idleTTL = time.Duration(cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second
+ }
+ if cfg != nil && cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds >= 0 {
+ // 配置语义:0 表示禁用超龄轮换。
+ pool.upstreamMaxAge = time.Duration(cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds) * time.Second
+ }
+ pool.startWorker()
+ return pool
+}
+
+func (p *openAIWSIngressContextPool) setClientDialerForTest(dialer openAIWSClientDialer) {
+ if p == nil || dialer == nil {
+ return
+ }
+ p.dialer = dialer
+}
+
+func (p *openAIWSIngressContextPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
+ if p == nil {
+ return OpenAIWSTransportMetricsSnapshot{}
+ }
+ if dialer, ok := p.dialer.(openAIWSTransportMetricsDialer); ok {
+ return dialer.SnapshotTransportMetrics()
+ }
+ return OpenAIWSTransportMetricsSnapshot{}
+}
+
+func (p *openAIWSIngressContextPool) maxConnsHardCap() int {
+ if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 {
+ return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount
+ }
+ return 8
+}
+
+func (p *openAIWSIngressContextPool) effectiveContextCapacity(account *Account) int {
+ if account == nil || account.Concurrency <= 0 {
+ return 0
+ }
+ capacity := account.Concurrency
+ hardCap := p.maxConnsHardCap()
+ if hardCap > 0 && capacity > hardCap {
+ return hardCap
+ }
+ return capacity
+}
+
+func (p *openAIWSIngressContextPool) Close() {
+ if p == nil {
+ return
+ }
+ p.closeOnce.Do(func() {
+ p.stopOnce.Do(func() {
+ close(p.stopCh)
+ })
+ p.workerWg.Wait()
+
+ var toClose []openAIWSClientConn
+ p.mu.Lock()
+ accountPools := make([]*openAIWSIngressAccountPool, 0, len(p.accounts))
+ for _, ap := range p.accounts {
+ if ap != nil {
+ accountPools = append(accountPools, ap)
+ }
+ }
+ p.accounts = make(map[int64]*openAIWSIngressAccountPool)
+ p.mu.Unlock()
+
+ for _, ap := range accountPools {
+ ap.mu.Lock()
+ for _, ctx := range ap.contexts {
+ if ctx == nil {
+ continue
+ }
+ ctx.mu.Lock()
+ if ctx.upstream != nil {
+ toClose = append(toClose, ctx.upstream)
+ }
+ ctx.upstream = nil
+ ctx.upstreamConnCreatedAt.Store(0)
+ ctx.broken = true
+ ctx.ownerID = ""
+ ctx.handshakeHeaders = nil
+ ctx.mu.Unlock()
+ }
+ ap.contexts = make(map[string]*openAIWSIngressContext)
+ ap.bySession = make(map[string]string)
+ ap.mu.Unlock()
+ }
+
+ for _, conn := range toClose {
+ if conn != nil {
+ _ = conn.Close()
+ }
+ }
+ })
+}
+
+func (p *openAIWSIngressContextPool) startWorker() {
+ if p == nil {
+ return
+ }
+ interval := p.sweepInterval
+ if interval <= 0 {
+ interval = 30 * time.Second
+ }
+ p.workerWg.Add(1)
+ go func() {
+ defer p.workerWg.Done()
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-p.stopCh:
+ return
+ case <-ticker.C:
+ p.sweepExpiredIdleContexts()
+ }
+ }
+ }()
+}
+
+func (p *openAIWSIngressContextPool) Acquire(
+ ctx context.Context,
+ req openAIWSIngressContextAcquireRequest,
+) (*openAIWSIngressContextLease, error) {
+ if p == nil {
+ return nil, errors.New("openai ws ingress context pool is nil")
+ }
+ if req.Account == nil || req.Account.ID <= 0 {
+ return nil, errors.New("invalid account in ingress context acquire request")
+ }
+ ownerID := strings.TrimSpace(req.OwnerID)
+ if ownerID == "" {
+ return nil, errors.New("owner id is empty")
+ }
+ if strings.TrimSpace(req.WSURL) == "" {
+ return nil, errors.New("ws url is empty")
+ }
+ capacity := p.effectiveContextCapacity(req.Account)
+ if capacity <= 0 {
+ return nil, errOpenAIWSConnQueueFull
+ }
+
+ sessionHash := strings.TrimSpace(req.SessionHash)
+ if sessionHash == "" {
+ // 无会话信号时退化为连接级上下文,避免跨连接串会话。
+ sessionHash = "conn:" + ownerID
+ }
+ sessionKey := openAIWSIngressContextSessionKey(req.GroupID, sessionHash)
+ accountID := req.Account.ID
+
+ start := time.Now()
+ queueWait := time.Duration(0)
+ waitRetries := 0
+
+ p.mu.Lock()
+ ap := p.getOrCreateAccountPoolLocked(accountID)
+ ap.refs.Add(1)
+ p.mu.Unlock()
+ defer ap.refs.Add(-1)
+
+ calcConnPick := func() time.Duration {
+ connPick := time.Since(start) - queueWait
+ if connPick < 0 {
+ return 0
+ }
+ return connPick
+ }
+
+ for {
+ now := time.Now()
+ var (
+ selected *openAIWSIngressContext
+ reusedContext bool
+ newlyCreated bool
+ ownerAssigned bool
+ migrationUsed bool
+ scheduleLayer string
+ oldUpstream openAIWSClientConn
+ deferredClose []openAIWSClientConn
+ )
+
+ ap.mu.Lock()
+
+ stickiness := p.resolveStickinessLevelLocked(ap, sessionKey, req, now)
+ allowMigration, minMigrationScore := openAIWSIngressMigrationPolicyByStickiness(stickiness)
+
+ if existingID, ok := ap.bySession[sessionKey]; ok {
+ if existing := ap.contexts[existingID]; existing != nil {
+ existing.mu.Lock()
+ switch existing.ownerID {
+ case "":
+ if existing.releaseDone != nil {
+ select {
+ case <-existing.releaseDone:
+ default:
+ }
+ }
+ existing.ownerID = ownerID
+ ownerAssigned = true
+ existing.touchLease(now, p.idleTTL)
+ selected = existing
+ reusedContext = true
+ scheduleLayer = openAIWSIngressScheduleLayerExact
+ case ownerID:
+ existing.touchLease(now, p.idleTTL)
+ selected = existing
+ reusedContext = true
+ scheduleLayer = openAIWSIngressScheduleLayerExact
+ default:
+ // 当前 context 被其他 owner 占用,等待其释放后重试(循环重试替代递归)。
+ if existing.releaseDone == nil {
+ existing.releaseDone = make(chan struct{}, 1)
+ }
+ releaseDone := existing.releaseDone
+ existing.mu.Unlock()
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(deferredClose)
+
+ waitStart := time.Now()
+ select {
+ case <-releaseDone:
+ queueWait += time.Since(waitStart)
+ waitRetries++
+ if waitRetries >= openAIWSIngressAcquireMaxWaitRetries || queueWait >= openAIWSIngressAcquireMaxQueueWait {
+ logOpenAIWSModeInfo(
+ "ctx_pool_owner_wait_exhausted account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d",
+ accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(),
+ )
+ return nil, errOpenAIWSIngressContextBusy
+ }
+ continue
+ case <-ctx.Done():
+ queueWait += time.Since(waitStart)
+ logOpenAIWSModeInfo(
+ "ctx_pool_owner_wait_canceled account_id=%d ctx_id=%s owner_id=%s wait_retries=%d queue_wait_ms=%d",
+ accountID, existing.id, ownerID, waitRetries, queueWait.Milliseconds(),
+ )
+ return nil, errOpenAIWSIngressContextBusy
+ }
+ }
+ existing.mu.Unlock()
+ }
+ }
+
+ if selected == nil {
+ dynCap := p.effectiveDynamicCapacity(ap, capacity)
+ if len(ap.contexts) >= dynCap {
+ deferredClose = append(deferredClose, p.evictExpiredIdleLocked(ap, now)...)
+ }
+ if len(ap.contexts) >= dynCap {
+ if dynCap < capacity {
+ // 动态扩容:尚未达到硬上限,增长 1 后创建新 context
+ ap.dynamicCap.Add(1)
+ } else if !allowMigration {
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(deferredClose)
+ logOpenAIWSModeInfo(
+ "ctx_pool_full_no_migration account_id=%d capacity=%d stickiness=%s",
+ accountID, capacity, stickiness,
+ )
+ return nil, errOpenAIWSConnQueueFull
+ } else {
+ recycle := p.pickMigrationCandidateLocked(ap, minMigrationScore, now)
+ if recycle == nil {
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(deferredClose)
+ logOpenAIWSModeInfo(
+ "ctx_pool_no_migration_candidate account_id=%d capacity=%d min_score=%.1f",
+ accountID, capacity, minMigrationScore,
+ )
+ return nil, errOpenAIWSConnQueueFull
+ }
+ recycle.mu.Lock()
+ oldSessionKey := recycle.sessionKey
+ oldUpstream = recycle.upstream
+ recycle.sessionHash = sessionHash
+ recycle.sessionKey = sessionKey
+ recycle.groupID = req.GroupID
+ if recycle.releaseDone != nil {
+ select {
+ case <-recycle.releaseDone:
+ default:
+ }
+ }
+ recycle.ownerID = ownerID
+ recycle.touchLease(now, p.idleTTL)
+ // 会话被回收复用时关闭旧上游,避免跨会话污染。
+ recycle.upstream = nil
+ recycle.upstreamConnID = ""
+ recycle.upstreamConnCreatedAt.Store(0)
+ recycle.handshakeHeaders = nil
+ recycle.broken = false
+ recycle.migrationCount++
+ recycle.lastMigrationAt = now
+ recycle.mu.Unlock()
+
+ if oldSessionKey != "" {
+ if mapped, ok := ap.bySession[oldSessionKey]; ok && mapped == recycle.id {
+ delete(ap.bySession, oldSessionKey)
+ }
+ }
+ ap.bySession[sessionKey] = recycle.id
+ selected = recycle
+ reusedContext = true
+ migrationUsed = true
+ scheduleLayer = openAIWSIngressScheduleLayerMigration
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(deferredClose)
+ if oldUpstream != nil {
+ _ = oldUpstream.Close()
+ }
+ reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req)
+ if ensureErr != nil {
+ p.releaseContext(selected, ownerID)
+ return nil, ensureErr
+ }
+ logOpenAIWSModeInfo(
+ "ctx_pool_migration account_id=%d ctx_id=%s old_session=%s new_session=%s migration_count=%d",
+ accountID, selected.id, truncateOpenAIWSLogValue(oldSessionKey, openAIWSIDValueMaxLen),
+ truncateOpenAIWSLogValue(sessionKey, openAIWSIDValueMaxLen), selected.migrationCount,
+ )
+ return &openAIWSIngressContextLease{
+ pool: p,
+ context: selected,
+ ownerID: ownerID,
+ queueWait: queueWait,
+ connPick: calcConnPick(),
+ reused: reusedContext && reusedConn,
+ scheduleLayer: scheduleLayer,
+ stickiness: stickiness,
+ migrationUsed: migrationUsed,
+ }, nil
+ }
+ }
+
+ ctxID := fmt.Sprintf("ctx_%d_%d", accountID, p.seq.Add(1))
+ created := &openAIWSIngressContext{
+ id: ctxID,
+ groupID: req.GroupID,
+ accountID: accountID,
+ sessionHash: sessionHash,
+ sessionKey: sessionKey,
+ ownerID: ownerID,
+ }
+ created.touchLease(now, p.idleTTL)
+ ap.contexts[ctxID] = created
+ ap.bySession[sessionKey] = ctxID
+ selected = created
+ newlyCreated = true
+ ownerAssigned = true
+ scheduleLayer = openAIWSIngressScheduleLayerNew
+ }
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(deferredClose)
+
+ reusedConn, ensureErr := p.ensureContextUpstream(ctx, selected, req)
+ if ensureErr != nil {
+ if newlyCreated {
+ ap.mu.Lock()
+ delete(ap.contexts, selected.id)
+ if mapped, ok := ap.bySession[sessionKey]; ok && mapped == selected.id {
+ delete(ap.bySession, sessionKey)
+ }
+ ap.mu.Unlock()
+ } else if ownerAssigned {
+ p.releaseContext(selected, ownerID)
+ }
+ return nil, ensureErr
+ }
+
+ return &openAIWSIngressContextLease{
+ pool: p,
+ context: selected,
+ ownerID: ownerID,
+ queueWait: queueWait,
+ connPick: calcConnPick(),
+ reused: reusedContext && reusedConn,
+ scheduleLayer: scheduleLayer,
+ stickiness: stickiness,
+ migrationUsed: migrationUsed,
+ }, nil
+ }
+}
+
+func (p *openAIWSIngressContextPool) resolveStickinessLevelLocked(
+ ap *openAIWSIngressAccountPool,
+ sessionKey string,
+ req openAIWSIngressContextAcquireRequest,
+ now time.Time,
+) string {
+ if req.StrictAffinity {
+ return openAIWSIngressStickinessStrong
+ }
+
+ level := openAIWSIngressStickinessWeak
+ switch {
+ case req.HasPreviousResponseID:
+ level = openAIWSIngressStickinessStrong
+ case req.StoreDisabled || req.Turn > 1:
+ level = openAIWSIngressStickinessBalanced
+ }
+
+ if ap == nil {
+ return level
+ }
+ ctxID, ok := ap.bySession[sessionKey]
+ if !ok {
+ return level
+ }
+ existing := ap.contexts[ctxID]
+ if existing == nil {
+ return level
+ }
+
+ existing.mu.Lock()
+ broken := existing.broken
+ failureStreak := existing.failureStreak
+ lastFailureAt := existing.lastFailureAt
+ lastUsedAt := existing.lastUsedAt()
+ existing.mu.Unlock()
+
+ recentFailure := failureStreak > 0 && !lastFailureAt.IsZero() && now.Sub(lastFailureAt) <= 2*time.Minute
+ if broken || recentFailure {
+ return openAIWSIngressStickinessDowngrade(level)
+ }
+ if failureStreak == 0 && !lastUsedAt.IsZero() && now.Sub(lastUsedAt) <= 20*time.Second {
+ return openAIWSIngressStickinessUpgrade(level)
+ }
+ return level
+}
+
+func openAIWSIngressMigrationPolicyByStickiness(stickiness string) (bool, float64) {
+ switch stickiness {
+ case openAIWSIngressStickinessStrong:
+ return false, 80 // was 85; lowered to allow migration away from degraded accounts
+ case openAIWSIngressStickinessBalanced:
+ return true, 65 // was 68; lowered to allow more aggressive migration to healthier accounts
+ default:
+ return true, 40 // was 45; lowered for weak stickiness
+ }
+}
+
+func openAIWSIngressStickinessDowngrade(level string) string {
+ switch level {
+ case openAIWSIngressStickinessStrong:
+ return openAIWSIngressStickinessBalanced
+ case openAIWSIngressStickinessBalanced:
+ return openAIWSIngressStickinessWeak
+ default:
+ return openAIWSIngressStickinessWeak
+ }
+}
+
+func openAIWSIngressStickinessUpgrade(level string) string {
+ switch level {
+ case openAIWSIngressStickinessWeak:
+ return openAIWSIngressStickinessBalanced
+ case openAIWSIngressStickinessBalanced:
+ return openAIWSIngressStickinessStrong
+ default:
+ return openAIWSIngressStickinessStrong
+ }
+}
+
+func (p *openAIWSIngressContextPool) pickMigrationCandidateLocked(
+ ap *openAIWSIngressAccountPool,
+ minScore float64,
+ now time.Time,
+) *openAIWSIngressContext {
+ if ap == nil {
+ return nil
+ }
+ var (
+ selected *openAIWSIngressContext
+ selectedScore float64
+ selectedAt time.Time
+ hasSelected bool
+ )
+
+ for _, ctx := range ap.contexts {
+ if ctx == nil {
+ continue
+ }
+ score, lastUsed, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, p.schedulerStats)
+ if !ok || score < minScore {
+ continue
+ }
+ if !hasSelected || score > selectedScore || (score == selectedScore && lastUsed.Before(selectedAt)) {
+ selected = ctx
+ selectedScore = score
+ selectedAt = lastUsed
+ hasSelected = true
+ }
+ }
+ return selected
+}
+
+func scoreOpenAIWSIngressMigrationCandidate(c *openAIWSIngressContext, now time.Time, stats *openAIAccountRuntimeStats) (float64, time.Time, bool) {
+ if c == nil {
+ return 0, time.Time{}, false
+ }
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if strings.TrimSpace(c.ownerID) != "" {
+ return 0, time.Time{}, false
+ }
+
+ score := 100.0
+ if c.broken {
+ score -= 30
+ }
+ if c.failureStreak > 0 {
+ score -= float64(minInt(c.failureStreak*12, 40))
+ }
+ if !c.lastFailureAt.IsZero() && now.Sub(c.lastFailureAt) <= 2*time.Minute {
+ score -= 18
+ }
+ if !c.lastMigrationAt.IsZero() && now.Sub(c.lastMigrationAt) <= time.Minute {
+ score -= 10
+ }
+ if c.migrationCount > 0 {
+ score -= float64(minInt(c.migrationCount*4, 20))
+ }
+
+ lastUsedAt := c.lastUsedAt()
+ idleDuration := now.Sub(lastUsedAt)
+ switch {
+ case idleDuration <= 15*time.Second:
+ score -= 15
+ case idleDuration >= 3*time.Minute:
+ score += 16
+ default:
+ score += idleDuration.Seconds() / 12.0
+ }
+
+ // Load-aware factors: penalize contexts bound to accounts that the
+ // scheduler has flagged as degraded or circuit-open. When stats is nil
+ // (e.g. during tests or before scheduler init), these adjustments are
+ // silently skipped so existing behaviour is preserved.
+ if stats != nil && c.accountID > 0 {
+ errorRate, _, _ := stats.snapshot(c.accountID)
+ // errorRate is in [0,1]; a fully-erroring account subtracts up to 30
+ // points, making it significantly harder for a migration to land on
+ // an unhealthy account.
+ score -= errorRate * 30
+
+ // Circuit-open accounts receive a harsh penalty (-50) that in
+ // practice drops the score below any reasonable minimum threshold,
+ // effectively blocking migration to that account.
+ if stats.isCircuitOpen(c.accountID) {
+ score -= 50
+ }
+ }
+
+ return score, lastUsedAt, true
+}
+
+func minInt(a, b int) int {
+ if a <= b {
+ return a
+ }
+ return b
+}
+
+func closeOpenAIWSClientConns(conns []openAIWSClientConn) {
+ for _, conn := range conns {
+ if conn != nil {
+ _ = conn.Close()
+ }
+ }
+}
+
+func (p *openAIWSIngressContextPool) ensureContextUpstream(
+ ctx context.Context,
+ c *openAIWSIngressContext,
+ req openAIWSIngressContextAcquireRequest,
+) (bool, error) {
+ if p == nil || c == nil {
+ return false, errOpenAIWSConnClosed
+ }
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ for {
+ c.mu.Lock()
+ if c.upstream != nil && !c.broken {
+ now := time.Now()
+ connAge := c.upstreamConnAge(now)
+ if p.upstreamMaxAge > 0 && connAge > 0 && connAge >= p.upstreamMaxAge {
+ // 主动轮换:关闭旧连接,不设 broken、不增 failureStreak
+ oldUpstream, oldConnID := c.upstream, c.upstreamConnID
+ c.upstream = nil
+ c.upstreamConnID = ""
+ c.upstreamConnCreatedAt.Store(0)
+ c.handshakeHeaders = nil
+ c.prewarmed.Store(false)
+ c.mu.Unlock()
+ _ = oldUpstream.Close()
+ logOpenAIWSModeInfo(
+ "ctx_pool_upstream_max_age_rotate account_id=%d ctx_id=%s conn_id=%s conn_age_min=%.1f max_age_min=%.1f",
+ c.accountID, c.id, oldConnID,
+ connAge.Minutes(), p.upstreamMaxAge.Minutes(),
+ )
+ continue // 回到 for 循环走 dialing 路径
+ }
+ c.touchLease(now, p.idleTTL)
+ c.mu.Unlock()
+ return true, nil
+ }
+ if c.dialing {
+ dialDone := c.dialDone
+ c.mu.Unlock()
+ if dialDone == nil {
+ if err := ctx.Err(); err != nil {
+ return false, err
+ }
+ continue
+ }
+ select {
+ case <-dialDone:
+ continue
+ case <-ctx.Done():
+ return false, ctx.Err()
+ }
+ }
+ oldUpstream := c.upstream
+ c.upstream = nil
+ c.upstreamConnCreatedAt.Store(0)
+ c.handshakeHeaders = nil
+ c.upstreamConnID = ""
+ c.prewarmed.Store(false)
+ c.broken = false
+ c.dialing = true
+ dialDone := make(chan struct{})
+ c.dialDone = dialDone
+ c.mu.Unlock()
+
+ if oldUpstream != nil {
+ _ = oldUpstream.Close()
+ }
+
+ dialer := p.dialer
+ if dialer == nil {
+ c.mu.Lock()
+ c.broken = true
+ c.failureStreak++
+ c.lastFailureAt = time.Now()
+ c.dialing = false
+ if c.dialDone == dialDone {
+ c.dialDone = nil
+ }
+ close(dialDone)
+ c.mu.Unlock()
+ return false, errors.New("openai ws ingress context dialer is nil")
+ }
+ conn, statusCode, handshakeHeaders, err := dialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
+ if err != nil {
+ wrappedErr := err
+ var dialErr *openAIWSDialError
+ if !errors.As(err, &dialErr) {
+ wrappedErr = &openAIWSDialError{
+ StatusCode: statusCode,
+ ResponseHeaders: cloneHeader(handshakeHeaders),
+ Err: err,
+ }
+ }
+ c.mu.Lock()
+ c.broken = true
+ c.failureStreak++
+ c.lastFailureAt = time.Now()
+ c.dialing = false
+ if c.dialDone == dialDone {
+ c.dialDone = nil
+ }
+ close(dialDone)
+ failureStreak := c.failureStreak
+ c.mu.Unlock()
+ logOpenAIWSModeInfo(
+ "ctx_pool_dial_fail account_id=%d ctx_id=%s status_code=%d failure_streak=%d cause=%s",
+ c.accountID, c.id, statusCode, failureStreak, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ )
+ return false, wrappedErr
+ }
+
+ c.mu.Lock()
+ now := time.Now()
+ c.upstream = conn
+ c.upstreamConnID = fmt.Sprintf("ctxws_%d_%d", c.accountID, p.seq.Add(1))
+ c.upstreamConnCreatedAt.Store(now.UnixNano())
+ c.handshakeHeaders = cloneHeader(handshakeHeaders)
+ c.prewarmed.Store(false)
+ c.touchLease(now, p.idleTTL)
+ c.broken = false
+ c.failureStreak = 0
+ c.lastFailureAt = time.Time{}
+ c.dialing = false
+ if c.dialDone == dialDone {
+ c.dialDone = nil
+ }
+ close(dialDone)
+ connID := c.upstreamConnID
+ c.mu.Unlock()
+ logOpenAIWSModeInfo(
+ "ctx_pool_dial_ok account_id=%d ctx_id=%s conn_id=%s",
+ c.accountID, c.id, connID,
+ )
+ return false, nil
+ }
+}
+
+func (p *openAIWSIngressContextPool) yieldContext(c *openAIWSIngressContext, ownerID string) {
+ p.releaseContextWithPolicy(c, ownerID, false)
+ // yield 后延迟 Ping,提前发现死连接
+ p.scheduleDelayedPing(c, openAIWSIngressDelayedPingAfterYield)
+}
+
+func (p *openAIWSIngressContextPool) releaseContext(c *openAIWSIngressContext, ownerID string) {
+ p.releaseContextWithPolicy(c, ownerID, true)
+}
+
+func (p *openAIWSIngressContextPool) releaseContextWithPolicy(
+ c *openAIWSIngressContext,
+ ownerID string,
+ closeUpstream bool,
+) {
+ if p == nil || c == nil {
+ return
+ }
+ var upstream openAIWSClientConn
+ c.mu.Lock()
+ if c.ownerID == ownerID {
+ if closeUpstream {
+ // 会话结束或链路损坏时销毁上游连接,避免污染后续请求。
+ upstream = c.upstream
+ c.upstream = nil
+ c.upstreamConnCreatedAt.Store(0)
+ c.handshakeHeaders = nil
+ c.upstreamConnID = ""
+ c.prewarmed.Store(false)
+ }
+ c.ownerID = ""
+ // 通知一个等待中的 Acquire 请求,避免 close 广播导致惊群。
+ if c.releaseDone != nil {
+ select {
+ case c.releaseDone <- struct{}{}:
+ default:
+ }
+ }
+ now := time.Now()
+ c.setLastUsedAt(now)
+ c.setExpiresAt(now.Add(p.idleTTL))
+ c.broken = false
+ }
+ c.mu.Unlock()
+ if upstream != nil {
+ _ = upstream.Close()
+ }
+}
+
+func (p *openAIWSIngressContextPool) markContextBroken(c *openAIWSIngressContext) {
+ if c == nil {
+ return
+ }
+ c.mu.Lock()
+ upstream := c.upstream
+ c.upstream = nil
+ c.upstreamConnCreatedAt.Store(0)
+ c.handshakeHeaders = nil
+ c.upstreamConnID = ""
+ c.prewarmed.Store(false)
+ c.broken = true
+ c.failureStreak++
+ c.lastFailureAt = time.Now()
+ // 注意:此处不发送 releaseDone 信号。ownerID 仍被占用,等待者被唤醒后
+ // 会发现 owner 未释放而重新阻塞,造成信号浪费。实际释放由 Release() 完成。
+ failureStreak := c.failureStreak
+ c.mu.Unlock()
+ logOpenAIWSModeInfo(
+ "ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d",
+ c.accountID, c.id, failureStreak,
+ )
+ if upstream != nil {
+ _ = upstream.Close()
+ }
+}
+
+// markContextBrokenIfConnMatch 仅在连接代次(connID)匹配时标记 broken。
+// 后台 Ping 在解锁期间执行,期间连接可能已被重建为新连接;
+// 若 connID 已变则说明旧连接已被替换,放弃标记以避免误杀新连接。
+func (p *openAIWSIngressContextPool) markContextBrokenIfConnMatch(c *openAIWSIngressContext, expectedConnID string) {
+ if c == nil {
+ return
+ }
+ c.mu.Lock()
+ actualConnID := c.upstreamConnID
+ if actualConnID != expectedConnID {
+ // 连接已被重建(connID 变了),放弃标记
+ c.mu.Unlock()
+ logOpenAIWSModeInfo(
+ "ctx_pool_bg_ping_skip_stale account_id=%d ctx_id=%s expected_conn_id=%s actual_conn_id=%s",
+ c.accountID, c.id, expectedConnID, actualConnID,
+ )
+ return
+ }
+ ownerID := c.ownerID
+ dialing := c.dialing
+ if ownerID != "" || dialing {
+ // Ping 期间 context 可能被重新占用或进入建连流程,不应由后台探测路径误杀活跃连接。
+ c.mu.Unlock()
+ logOpenAIWSModeInfo(
+ "ctx_pool_bg_ping_skip_busy account_id=%d ctx_id=%s conn_id=%s owner_id=%s dialing=%v",
+ c.accountID,
+ c.id,
+ actualConnID,
+ truncateOpenAIWSLogValue(ownerID, openAIWSIDValueMaxLen),
+ dialing,
+ )
+ return
+ }
+ upstream := c.upstream
+ c.upstream = nil
+ c.upstreamConnCreatedAt.Store(0)
+ c.handshakeHeaders = nil
+ c.upstreamConnID = ""
+ c.prewarmed.Store(false)
+ c.broken = true
+ c.failureStreak++
+ c.lastFailureAt = time.Now()
+ failureStreak := c.failureStreak
+ c.mu.Unlock()
+ logOpenAIWSModeInfo(
+ "ctx_pool_mark_broken account_id=%d ctx_id=%s failure_streak=%d",
+ c.accountID, c.id, failureStreak,
+ )
+ if upstream != nil {
+ _ = upstream.Close()
+ }
+}
+
+func (p *openAIWSIngressContextPool) getOrCreateAccountPoolLocked(accountID int64) *openAIWSIngressAccountPool {
+ if ap, ok := p.accounts[accountID]; ok && ap != nil {
+ return ap
+ }
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ap.dynamicCap.Store(1)
+ p.accounts[accountID] = ap
+ return ap
+}
+
+// effectiveDynamicCapacity 返回 min(dynamicCap, hardCap)。
+// dynamicCap 从 1 开始,按需增长,空闲时缩减;hardCap 由账户并发度和全局上限决定。
+func (p *openAIWSIngressContextPool) effectiveDynamicCapacity(ap *openAIWSIngressAccountPool, hardCap int) int {
+ if ap == nil || hardCap <= 0 {
+ return hardCap
+ }
+ dynCap := int(ap.dynamicCap.Load())
+ if dynCap < 1 {
+ dynCap = 1
+ ap.dynamicCap.Store(1)
+ }
+ if dynCap > hardCap {
+ return hardCap
+ }
+ return dynCap
+}
+
+func (p *openAIWSIngressContextPool) evictExpiredIdleLocked(
+ ap *openAIWSIngressAccountPool,
+ now time.Time,
+) []openAIWSClientConn {
+ if ap == nil {
+ return nil
+ }
+ var toClose []openAIWSClientConn
+ for id, ctx := range ap.contexts {
+ if ctx == nil {
+ delete(ap.contexts, id)
+ continue
+ }
+ ctx.mu.Lock()
+ expiresAt := ctx.expiresAt()
+ expired := ctx.ownerID == "" && !expiresAt.IsZero() && now.After(expiresAt)
+ upstream := ctx.upstream
+ if expired {
+ ctx.upstream = nil
+ ctx.upstreamConnCreatedAt.Store(0)
+ ctx.handshakeHeaders = nil
+ ctx.upstreamConnID = ""
+ }
+ ctx.mu.Unlock()
+ if !expired {
+ continue
+ }
+ delete(ap.contexts, id)
+ if mappedID, ok := ap.bySession[ctx.sessionKey]; ok && mappedID == id {
+ delete(ap.bySession, ctx.sessionKey)
+ }
+ if upstream != nil {
+ toClose = append(toClose, upstream)
+ }
+ }
+ return toClose
+}
+
+func (p *openAIWSIngressContextPool) pickOldestIdleContextLocked(ap *openAIWSIngressAccountPool) *openAIWSIngressContext {
+ if ap == nil {
+ return nil
+ }
+ var (
+ selected *openAIWSIngressContext
+ selectedAt time.Time
+ )
+ for _, ctx := range ap.contexts {
+ if ctx == nil {
+ continue
+ }
+ ctx.mu.Lock()
+ idle := strings.TrimSpace(ctx.ownerID) == ""
+ lastUsed := ctx.lastUsedAt()
+ ctx.mu.Unlock()
+ if !idle {
+ continue
+ }
+ if selected == nil || lastUsed.Before(selectedAt) {
+ selected = ctx
+ selectedAt = lastUsed
+ }
+ }
+ return selected
+}
+
+// closeAgedIdleUpstreamsLocked 关闭空闲且超龄的上游连接。
+// 只清理 upstream,保留 context 槽位(不删 context、不清 bySession)。
+// 不设 broken、不增 failureStreak。
+// 调用方必须持有 ap.mu。
+func (p *openAIWSIngressContextPool) closeAgedIdleUpstreamsLocked(
+ ap *openAIWSIngressAccountPool,
+ now time.Time,
+) []openAIWSClientConn {
+ if ap == nil || p.upstreamMaxAge <= 0 {
+ return nil
+ }
+ var toClose []openAIWSClientConn
+ for _, ctx := range ap.contexts {
+ if ctx == nil {
+ continue
+ }
+ ctx.mu.Lock()
+ idle := ctx.ownerID == ""
+ hasUpstream := ctx.upstream != nil
+ connAge := ctx.upstreamConnAge(now)
+ aged := connAge > 0 && connAge >= p.upstreamMaxAge
+ if idle && hasUpstream && aged {
+ toClose = append(toClose, ctx.upstream)
+ ctx.upstream = nil
+ ctx.upstreamConnCreatedAt.Store(0)
+ ctx.upstreamConnID = ""
+ ctx.handshakeHeaders = nil
+ ctx.prewarmed.Store(false)
+ }
+ ctx.mu.Unlock()
+ }
+ return toClose
+}
+
+// pingContextUpstream 对空闲 context 的上游连接发送 Ping 探测。
+// 若 Ping 失败则标记 context 为 broken,让后续 Acquire 走重建路径。
+// 调用方不需要持有任何锁。
+//
+// 使用 connID 代次守卫:Ping 期间连接可能被重建,仅当 connID 未变时才标记 broken,
+// 避免旧连接 Ping 失败误杀新连接。
+func (p *openAIWSIngressContextPool) pingContextUpstream(c *openAIWSIngressContext) {
+ if p == nil || c == nil {
+ return
+ }
+ c.mu.Lock()
+ idle := c.ownerID == ""
+ hasUpstream := c.upstream != nil
+ broken := c.broken
+ dialing := c.dialing
+ upstream := c.upstream
+ connID := c.upstreamConnID // 快照连接代次
+ c.mu.Unlock()
+ if !idle || !hasUpstream || broken || dialing || upstream == nil {
+ return
+ }
+
+ pingCtx, cancel := context.WithTimeout(context.Background(), openAIWSIngressPingTimeout)
+ defer cancel()
+ if err := upstream.Ping(pingCtx); err != nil {
+ p.markContextBrokenIfConnMatch(c, connID)
+ logOpenAIWSModeInfo(
+ "ctx_pool_bg_ping_fail account_id=%d ctx_id=%s conn_id=%s cause=%s",
+ c.accountID, c.id, connID, truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+}
+
+// pingIdleUpstreams 对账户池内所有空闲且有上游连接的 context 发起 Ping 探测。
+// 先在锁内收集候选列表,再在锁外逐个 Ping,避免阻塞其他操作。
+func (p *openAIWSIngressContextPool) pingIdleUpstreams(ap *openAIWSIngressAccountPool) {
+ if p == nil || ap == nil {
+ return
+ }
+ ap.mu.Lock()
+ targets := make([]*openAIWSIngressContext, 0, len(ap.contexts))
+ for _, ctx := range ap.contexts {
+ if ctx == nil {
+ continue
+ }
+ ctx.mu.Lock()
+ eligible := ctx.ownerID == "" && ctx.upstream != nil && !ctx.broken && !ctx.dialing
+ ctx.mu.Unlock()
+ if eligible {
+ targets = append(targets, ctx)
+ }
+ }
+ ap.mu.Unlock()
+
+ for _, ctx := range targets {
+ p.pingContextUpstream(ctx)
+ }
+}
+
+// scheduleDelayedPing 在 yield 后延迟一段时间对 context 发送 Ping 探测。
+// 通过 pendingPingTimer 去重:同一 context 同时只保留一个延迟 Ping,
+// 连续 yield 只 Reset timer 而不创建新 goroutine,避免高并发下 goroutine 堆积。
+func (p *openAIWSIngressContextPool) scheduleDelayedPing(c *openAIWSIngressContext, delay time.Duration) {
+ if p == nil || c == nil || delay <= 0 {
+ return
+ }
+ c.mu.Lock()
+ if c.pendingPingTimer != nil {
+ // 已有 pending ping,只需 Reset timer 延迟窗口
+ c.pendingPingTimer.Reset(delay)
+ c.mu.Unlock()
+ return
+ }
+ timer := time.NewTimer(delay)
+ c.pendingPingTimer = timer
+ c.mu.Unlock()
+
+ go func() {
+ select {
+ case <-p.stopCh:
+ timer.Stop()
+ case <-timer.C:
+ p.pingContextUpstream(c)
+ }
+ c.mu.Lock()
+ if c.pendingPingTimer == timer {
+ c.pendingPingTimer = nil
+ }
+ c.mu.Unlock()
+ }()
+}
+
+func (p *openAIWSIngressContextPool) sweepExpiredIdleContexts() {
+ if p == nil {
+ return
+ }
+ now := time.Now()
+
+ type accountSnapshot struct {
+ accountID int64
+ ap *openAIWSIngressAccountPool
+ }
+
+ snapshots := make([]accountSnapshot, 0, len(p.accounts))
+ p.mu.Lock()
+ for accountID, ap := range p.accounts {
+ if ap == nil {
+ delete(p.accounts, accountID)
+ continue
+ }
+ snapshots = append(snapshots, accountSnapshot{accountID: accountID, ap: ap})
+ }
+ p.mu.Unlock()
+
+ removable := make([]accountSnapshot, 0)
+ for _, item := range snapshots {
+ ap := item.ap
+ ap.mu.Lock()
+ toClose := p.evictExpiredIdleLocked(ap, now)
+ agedClose := p.closeAgedIdleUpstreamsLocked(ap, now)
+ empty := len(ap.contexts) == 0
+ // 动态缩容:将 dynamicCap 收缩到 max(1, 当前 context 数量)
+ shrinkTarget := int32(len(ap.contexts))
+ if shrinkTarget < 1 {
+ shrinkTarget = 1
+ }
+ if ap.dynamicCap.Load() > shrinkTarget {
+ ap.dynamicCap.Store(shrinkTarget)
+ }
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(toClose)
+ closeOpenAIWSClientConns(agedClose)
+ // 后台 Ping 探测:对剩余空闲连接发送 Ping,及时剔除死连接
+ if !empty {
+ p.pingIdleUpstreams(ap)
+ }
+ if empty && ap.refs.Load() == 0 {
+ removable = append(removable, item)
+ }
+ }
+
+ if len(removable) == 0 {
+ return
+ }
+
+ p.mu.Lock()
+ for _, item := range removable {
+ existing := p.accounts[item.accountID]
+ if existing != item.ap || existing == nil {
+ continue
+ }
+ if existing.refs.Load() != 0 {
+ continue
+ }
+ delete(p.accounts, item.accountID)
+ }
+ p.mu.Unlock()
+}
+
+func openAIWSIngressContextSessionKey(groupID int64, sessionHash string) string {
+ hash := strings.TrimSpace(sessionHash)
+ if hash == "" {
+ return ""
+ }
+ return strconv.FormatInt(groupID, 10) + ":" + hash
+}
+
+func (l *openAIWSIngressContextLease) ConnID() string {
+ if l == nil || l.context == nil {
+ return ""
+ }
+ l.context.mu.Lock()
+ defer l.context.mu.Unlock()
+ return strings.TrimSpace(l.context.upstreamConnID)
+}
+
+func (l *openAIWSIngressContextLease) QueueWaitDuration() time.Duration {
+ if l == nil {
+ return 0
+ }
+ return l.queueWait
+}
+
+func (l *openAIWSIngressContextLease) ConnPickDuration() time.Duration {
+ if l == nil {
+ return 0
+ }
+ return l.connPick
+}
+
+func (l *openAIWSIngressContextLease) Reused() bool {
+ if l == nil {
+ return false
+ }
+ return l.reused
+}
+
+func (l *openAIWSIngressContextLease) ScheduleLayer() string {
+ if l == nil {
+ return ""
+ }
+ return strings.TrimSpace(l.scheduleLayer)
+}
+
+func (l *openAIWSIngressContextLease) StickinessLevel() string {
+ if l == nil {
+ return ""
+ }
+ return strings.TrimSpace(l.stickiness)
+}
+
+func (l *openAIWSIngressContextLease) MigrationUsed() bool {
+ if l == nil {
+ return false
+ }
+ return l.migrationUsed
+}
+
+func (l *openAIWSIngressContextLease) HandshakeHeader(name string) string {
+ if l == nil || l.context == nil {
+ return ""
+ }
+ l.context.mu.Lock()
+ defer l.context.mu.Unlock()
+ if l.context.handshakeHeaders == nil {
+ return ""
+ }
+ return strings.TrimSpace(l.context.handshakeHeaders.Get(strings.TrimSpace(name)))
+}
+
+func (l *openAIWSIngressContextLease) IsPrewarmed() bool {
+ if l == nil || l.context == nil {
+ return false
+ }
+ return l.context.prewarmed.Load()
+}
+
+func (l *openAIWSIngressContextLease) MarkPrewarmed() {
+ if l == nil || l.context == nil {
+ return
+ }
+ l.context.prewarmed.Store(true)
+}
+
+func (l *openAIWSIngressContextLease) activeConn() (openAIWSClientConn, error) {
+ if l == nil || l.context == nil || l.released.Load() {
+ return nil, errOpenAIWSConnClosed
+ }
+ // Fast path: return cached conn without mutex if lease is still valid.
+ l.cachedConnMu.RLock()
+ cc := l.cachedConn
+ l.cachedConnMu.RUnlock()
+ if cc != nil {
+ return cc, nil
+ }
+ // Slow path: acquire mutex, validate ownership, cache result.
+ l.context.mu.Lock()
+ defer l.context.mu.Unlock()
+ if l.context.ownerID != l.ownerID {
+ return nil, errOpenAIWSConnClosed
+ }
+ if l.context.upstream == nil {
+ return nil, errOpenAIWSConnClosed
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = l.context.upstream
+ l.cachedConnMu.Unlock()
+ return l.context.upstream, nil
+}
+
+func (l *openAIWSIngressContextLease) invalidateCachedConnOnIOError(err error) {
+ if l == nil || err == nil {
+ return
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ if l.pool != nil && l.context != nil && isOpenAIWSClientDisconnectError(err) {
+ l.pool.markContextBroken(l.context)
+ }
+}
+
+func (l *openAIWSIngressContextLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ writeCtx := ctx
+ if writeCtx == nil {
+ writeCtx = context.Background()
+ }
+ if timeout > 0 {
+ var cancel context.CancelFunc
+ writeCtx, cancel = context.WithTimeout(writeCtx, timeout)
+ defer cancel()
+ }
+ if err := conn.WriteJSON(writeCtx, value); err != nil {
+ l.invalidateCachedConnOnIOError(err)
+ return err
+ }
+ l.context.maybeTouchLease(l.pool.idleTTL)
+ return nil
+}
+
+func (l *openAIWSIngressContextLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
+ conn, err := l.activeConn()
+ if err != nil {
+ return nil, err
+ }
+ readCtx := ctx
+ if readCtx == nil {
+ readCtx = context.Background()
+ }
+ if timeout > 0 {
+ var cancel context.CancelFunc
+ readCtx, cancel = context.WithTimeout(readCtx, timeout)
+ defer cancel()
+ }
+ payload, err := conn.ReadMessage(readCtx)
+ if err != nil {
+ l.invalidateCachedConnOnIOError(err)
+ return nil, err
+ }
+ l.context.maybeTouchLease(l.pool.idleTTL)
+ return payload, nil
+}
+
+func (l *openAIWSIngressContextLease) PingWithTimeout(timeout time.Duration) error {
+ conn, err := l.activeConn()
+ if err != nil {
+ return err
+ }
+ pingTimeout := timeout
+ if pingTimeout <= 0 {
+ pingTimeout = openAIWSConnHealthCheckTO
+ }
+ pingCtx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ defer cancel()
+ if err := conn.Ping(pingCtx); err != nil {
+ l.invalidateCachedConnOnIOError(err)
+ return err
+ }
+ l.context.maybeTouchLease(l.pool.idleTTL)
+ return nil
+}
+
+func (l *openAIWSIngressContextLease) MarkBroken() {
+ if l == nil || l.pool == nil || l.context == nil || l.released.Load() {
+ return
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ l.pool.markContextBroken(l.context)
+}
+
+func (l *openAIWSIngressContextLease) Release() {
+ if l == nil || l.context == nil || l.pool == nil {
+ return
+ }
+ if !l.released.CompareAndSwap(false, true) {
+ return
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ l.pool.releaseContext(l.context, l.ownerID)
+}
+
+func (l *openAIWSIngressContextLease) Yield() {
+ if l == nil || l.context == nil || l.pool == nil {
+ return
+ }
+ if !l.released.CompareAndSwap(false, true) {
+ return
+ }
+ l.cachedConnMu.Lock()
+ l.cachedConn = nil
+ l.cachedConnMu.Unlock()
+ l.pool.yieldContext(l.context, l.ownerID)
+}
diff --git a/backend/internal/service/openai_ws_ingress_context_pool_test.go b/backend/internal/service/openai_ws_ingress_context_pool_test.go
new file mode 100644
index 000000000..f5a6802c5
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_context_pool_test.go
@@ -0,0 +1,2496 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "net"
+ "net/http"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+func TestOpenAIWSIngressContextPool_Acquire_HardCapacityEqualsConcurrency(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 801, Concurrency: 1}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 2,
+ SessionHash: "session_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 2,
+ SessionHash: "session_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "并发=1 时第二个并发 owner 不应获取到 context")
+
+ lease1.Release()
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 2,
+ SessionHash: "session_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, openAIWSIngressScheduleLayerMigration, lease2.ScheduleLayer())
+ require.Equal(t, openAIWSIngressStickinessWeak, lease2.StickinessLevel())
+ require.True(t, lease2.MigrationUsed())
+ lease2.Release()
+
+ require.Equal(t, 2, dialer.DialCount(), "会话回收复用 context 后应重建上游连接,避免跨会话污染")
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_RespectsGlobalHardCap(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+ cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 802, Concurrency: 10}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_c",
+ OwnerID: "owner_c",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "账号并发高于全局硬上限时,context pool 仍应被硬上限约束")
+
+ lease1.Release()
+ lease2.Release()
+ require.Equal(t, 2, dialer.DialCount())
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_DoesNotCrossAccount(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ accountA := &Account{ID: 901, Concurrency: 1}
+ accountB := &Account{ID: 902, Concurrency: 1}
+
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ leaseA, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: accountA,
+ GroupID: 5,
+ SessionHash: "same_session_hash",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, leaseA)
+ leaseA.Release()
+
+ leaseB, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: accountB,
+ GroupID: 5,
+ SessionHash: "same_session_hash",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, leaseB)
+ leaseB.Release()
+
+ require.Equal(t, 2, dialer.DialCount(), "相同 session_hash 在不同账号下必须使用不同 context,不允许跨账号复用")
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_StrongStickinessDisablesMigration(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1001, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 9,
+ SessionHash: "session_keep_strong_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 9,
+ SessionHash: "session_keep_strong_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "strong 粘连不应迁移其它 session 的 context")
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_AdaptiveStickinessDowngradesAfterFailure(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1002, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 11,
+ SessionHash: "session_adaptive",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.MarkBroken()
+ lease1.Release()
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 11,
+ SessionHash: "session_adaptive",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ HasPreviousResponseID: true,
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, openAIWSIngressScheduleLayerExact, lease2.ScheduleLayer())
+ require.Equal(t, openAIWSIngressStickinessBalanced, lease2.StickinessLevel(), "失败后应从 strong 自适应降级到 balanced")
+ lease2.Release()
+ require.Equal(t, 2, dialer.DialCount(), "故障后重连同一 context 应重新建立上游连接")
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_EnsureFailureReleasesOwner(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ initialDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(initialDialer)
+
+ account := &Account{ID: 1101, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 12,
+ SessionHash: "session_owner_release",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ failDialer := &openAIWSAlwaysFailDialer{}
+ pool.setClientDialerForTest(failDialer)
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 12,
+ SessionHash: "session_owner_release",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.Error(t, err)
+ require.NotErrorIs(t, err, errOpenAIWSIngressContextBusy, "ensure 上游失败后不应遗留 owner 导致 context 长时间 busy")
+
+ recoverDialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(recoverDialer)
+
+ lease3, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 12,
+ SessionHash: "session_owner_release",
+ OwnerID: "owner_c",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err, "owner 回滚后应允许后续会话重新获取同一 context")
+ require.NotNil(t, lease3)
+ lease3.Release()
+ require.Equal(t, 1, failDialer.DialCount())
+ require.Equal(t, 1, recoverDialer.DialCount())
+}
+
+func TestOpenAIWSIngressContextPool_Release_ClosesUpstreamAndForcesRedial(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstreamConn1 := &openAIWSCaptureConn{}
+ upstreamConn2 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ upstreamConn1,
+ upstreamConn2,
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1102, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 13,
+ SessionHash: "session_same",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ connID1 := lease1.ConnID()
+ require.NotEmpty(t, connID1)
+ lease1.Release()
+
+ upstreamConn1.mu.Lock()
+ closed1 := upstreamConn1.closed
+ upstreamConn1.mu.Unlock()
+ require.True(t, closed1, "客户端会话结束后应关闭对应上游连接,防止复用污染")
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 13,
+ SessionHash: "session_same",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ connID2 := lease2.ConnID()
+ require.NotEmpty(t, connID2)
+ require.NotEqual(t, connID1, connID2, "下一次会话必须重新建立上游连接")
+ lease2.Release()
+
+ upstreamConn2.mu.Lock()
+ closed2 := upstreamConn2.closed
+ upstreamConn2.mu.Unlock()
+ require.True(t, closed2)
+ require.Equal(t, 2, dialer.DialCount())
+}
+
+func TestOpenAIWSIngressContextPool_Yield_ReleasesOwnerKeepsUpstream(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstreamConn := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{upstreamConn},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1103, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 14,
+ SessionHash: "session_yield",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ connID1 := lease1.ConnID()
+ require.NotEmpty(t, connID1)
+
+ lease1.Yield()
+ upstreamConn.mu.Lock()
+ closedAfterYield := upstreamConn.closed
+ upstreamConn.mu.Unlock()
+ require.False(t, closedAfterYield, "yield 只应释放 owner,不应关闭上游连接")
+
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 14,
+ SessionHash: "session_yield",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, connID1, lease2.ConnID(), "yield 后应复用同一上游连接")
+ require.Equal(t, 1, dialer.DialCount(), "yield 后重新获取不应触发重拨号")
+
+ lease2.Release()
+ upstreamConn.mu.Lock()
+ closedAfterRelease := upstreamConn.closed
+ upstreamConn.mu.Unlock()
+ require.True(t, closedAfterRelease, "release 仍需关闭上游连接")
+}
+
+func TestOpenAIWSIngressContextPool_EvictExpiredIdleLocked_ClosesUpstream(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstreamConn := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ expiredCtx := &openAIWSIngressContext{
+ id: "ctx_expired_1",
+ groupID: 21,
+ accountID: 1201,
+ sessionHash: "session_expired",
+ sessionKey: openAIWSIngressContextSessionKey(21, "session_expired"),
+ upstream: upstreamConn,
+ upstreamConnID: "ctxws_1201_1",
+ handshakeHeaders: map[string][]string{"x-test": {"ok"}},
+ }
+ expiredCtx.setExpiresAt(time.Now().Add(-2 * time.Second))
+ ap.contexts[expiredCtx.id] = expiredCtx
+ ap.bySession[expiredCtx.sessionKey] = expiredCtx.id
+
+ ap.mu.Lock()
+ toClose := pool.evictExpiredIdleLocked(ap, time.Now())
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(toClose)
+
+ require.Empty(t, ap.contexts, "过期 idle context 应被清理")
+ require.Empty(t, ap.bySession, "过期 context 的 session 索引应同步清理")
+ upstreamConn.mu.Lock()
+ closed := upstreamConn.closed
+ upstreamConn.mu.Unlock()
+ require.True(t, closed, "清理过期 context 时应关闭残留上游连接,避免泄漏")
+}
+
+func TestOpenAIWSIngressContextPool_ScoreAndStickinessHelpers(t *testing.T) {
+ now := time.Now()
+
+ require.Equal(t, 1, minInt(1, 2))
+ require.Equal(t, 2, minInt(3, 2))
+
+ require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessStrong))
+ require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade(openAIWSIngressStickinessBalanced))
+ require.Equal(t, openAIWSIngressStickinessWeak, openAIWSIngressStickinessDowngrade("unknown"))
+
+ require.Equal(t, openAIWSIngressStickinessBalanced, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessWeak))
+ require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade(openAIWSIngressStickinessBalanced))
+ require.Equal(t, openAIWSIngressStickinessStrong, openAIWSIngressStickinessUpgrade("unknown"))
+
+ allowStrong, scoreStrong := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessStrong)
+ require.False(t, allowStrong)
+ require.Equal(t, 80.0, scoreStrong)
+ allowBalanced, scoreBalanced := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessBalanced)
+ require.True(t, allowBalanced)
+ require.Equal(t, 65.0, scoreBalanced)
+ allowWeak, scoreWeak := openAIWSIngressMigrationPolicyByStickiness("weak_or_unknown")
+ require.True(t, allowWeak)
+ require.Equal(t, 40.0, scoreWeak)
+
+ busyCtx := &openAIWSIngressContext{ownerID: "owner_busy"}
+ _, _, ok := scoreOpenAIWSIngressMigrationCandidate(busyCtx, now, nil)
+ require.False(t, ok, "owner 占用中的 context 不应作为迁移候选")
+
+ oldIdle := &openAIWSIngressContext{}
+ oldIdle.setLastUsedAt(now.Add(-5 * time.Minute))
+ recentIdle := &openAIWSIngressContext{}
+ recentIdle.setLastUsedAt(now.Add(-10 * time.Second))
+ scoreOld, _, okOld := scoreOpenAIWSIngressMigrationCandidate(oldIdle, now, nil)
+ scoreRecent, _, okRecent := scoreOpenAIWSIngressMigrationCandidate(recentIdle, now, nil)
+ require.True(t, okOld)
+ require.True(t, okRecent)
+ require.Greater(t, scoreOld, scoreRecent, "更久未使用的空闲 context 应该更易被迁移")
+
+ penalized := &openAIWSIngressContext{
+ broken: true,
+ failureStreak: 2,
+ lastFailureAt: now.Add(-30 * time.Second),
+ migrationCount: 2,
+ lastMigrationAt: now.Add(-10 * time.Second),
+ }
+ penalized.setLastUsedAt(now.Add(-5 * time.Minute))
+ scorePenalized, _, okPenalized := scoreOpenAIWSIngressMigrationCandidate(penalized, now, nil)
+ require.True(t, okPenalized)
+ require.Less(t, scorePenalized, scoreOld, "近期失败和频繁迁移应降低迁移分数")
+}
+
+func TestOpenAIWSIngressContextPool_EvictPickAndSweep(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ now := time.Now()
+ expiredConn := &openAIWSCaptureConn{}
+ expiredCtx := &openAIWSIngressContext{
+ id: "ctx_expired",
+ sessionKey: "1:expired",
+ upstream: expiredConn,
+ upstreamConnID: "ctxws_expired",
+ }
+ expiredCtx.setLastUsedAt(now.Add(-20 * time.Minute))
+ expiredCtx.setExpiresAt(now.Add(-time.Minute))
+
+ idleNewCtx := &openAIWSIngressContext{
+ id: "ctx_idle_new",
+ sessionKey: "1:idle_new",
+ }
+ idleNewCtx.setLastUsedAt(now.Add(-30 * time.Second))
+ idleNewCtx.setExpiresAt(now.Add(time.Minute))
+
+ busyCtx := &openAIWSIngressContext{
+ id: "ctx_busy",
+ sessionKey: "1:busy",
+ ownerID: "active_owner",
+ }
+ busyCtx.setLastUsedAt(now.Add(-40 * time.Minute))
+ busyCtx.setExpiresAt(now.Add(-time.Minute))
+
+ ap := &openAIWSIngressAccountPool{
+ contexts: map[string]*openAIWSIngressContext{
+ "ctx_expired": expiredCtx,
+ "ctx_idle_new": idleNewCtx,
+ "ctx_busy": busyCtx,
+ },
+ bySession: map[string]string{
+ "1:expired": "ctx_expired",
+ "1:idle_new": "ctx_idle_new",
+ "1:busy": "ctx_busy",
+ },
+ }
+
+ ap.mu.Lock()
+ oldestIdle := pool.pickOldestIdleContextLocked(ap)
+ ap.mu.Unlock()
+ require.NotNil(t, oldestIdle)
+ require.Equal(t, "ctx_expired", oldestIdle.id, "应选择最旧的空闲 context")
+
+ ap.mu.Lock()
+ toClose := pool.evictExpiredIdleLocked(ap, now)
+ ap.mu.Unlock()
+ closeOpenAIWSClientConns(toClose)
+ require.NotContains(t, ap.contexts, "ctx_expired")
+ require.NotContains(t, ap.bySession, "1:expired")
+ require.Contains(t, ap.contexts, "ctx_idle_new", "未过期空闲 context 应保留")
+ require.Contains(t, ap.contexts, "ctx_busy", "有 owner 的 context 不应被 idle 过期清理")
+ expiredConn.mu.Lock()
+ expiredClosed := expiredConn.closed
+ expiredConn.mu.Unlock()
+ require.True(t, expiredClosed, "清理过期 idle context 时应关闭上游连接")
+
+ expiredInPoolConn := &openAIWSCaptureConn{}
+ pool.mu.Lock()
+ pool.accounts[5001] = ap
+ poolExpiredCtx := &openAIWSIngressContext{
+ id: "ctx_pool_expired",
+ sessionKey: "2:expired",
+ upstream: expiredInPoolConn,
+ }
+ poolExpiredCtx.setExpiresAt(now.Add(-time.Minute))
+ pool.accounts[5002] = &openAIWSIngressAccountPool{
+ contexts: map[string]*openAIWSIngressContext{
+ "ctx_pool_expired": poolExpiredCtx,
+ },
+ bySession: map[string]string{
+ "2:expired": "ctx_pool_expired",
+ },
+ }
+ pool.mu.Unlock()
+
+ pool.sweepExpiredIdleContexts()
+
+ pool.mu.Lock()
+ _, account2Exists := pool.accounts[5002]
+ account1 := pool.accounts[5001]
+ pool.mu.Unlock()
+ require.False(t, account2Exists, "sweep 后空账号应被移除")
+ require.NotNil(t, account1, "非空账号应保留")
+ expiredInPoolConn.mu.Lock()
+ sweptClosed := expiredInPoolConn.closed
+ expiredInPoolConn.mu.Unlock()
+ require.True(t, sweptClosed)
+}
+
+func TestOpenAIWSIngressContextLease_AccessorsAndPingGuards(t *testing.T) {
+ var nilLease *openAIWSIngressContextLease
+ require.Equal(t, "", nilLease.ConnID())
+ require.Zero(t, nilLease.QueueWaitDuration())
+ require.Zero(t, nilLease.ConnPickDuration())
+ require.False(t, nilLease.Reused())
+ require.Equal(t, "", nilLease.ScheduleLayer())
+ require.Equal(t, "", nilLease.StickinessLevel())
+ require.False(t, nilLease.MigrationUsed())
+ require.Equal(t, "", nilLease.HandshakeHeader("x-test"))
+ require.ErrorIs(t, nilLease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed)
+
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_lease",
+ ownerID: "owner_ok",
+ upstream: &openAIWSFakeConn{},
+ handshakeHeaders: http.Header{"X-Test": []string{"ok"}},
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctxItem,
+ ownerID: "owner_ok",
+ queueWait: 5 * time.Millisecond,
+ connPick: 8 * time.Millisecond,
+ reused: true,
+ scheduleLayer: openAIWSIngressScheduleLayerExact,
+ stickiness: openAIWSIngressStickinessBalanced,
+ migrationUsed: true,
+ }
+
+ require.Equal(t, "ok", lease.HandshakeHeader("x-test"))
+ require.Equal(t, 5*time.Millisecond, lease.QueueWaitDuration())
+ require.Equal(t, 8*time.Millisecond, lease.ConnPickDuration())
+ require.True(t, lease.Reused())
+ require.Equal(t, openAIWSIngressScheduleLayerExact, lease.ScheduleLayer())
+ require.Equal(t, openAIWSIngressStickinessBalanced, lease.StickinessLevel())
+ require.True(t, lease.MigrationUsed())
+ require.NoError(t, lease.PingWithTimeout(0), "timeout=0 应回退默认 ping 超时")
+
+ lease.released.Store(true)
+ require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed)
+ lease.released.Store(false)
+
+ ctxItem.mu.Lock()
+ ctxItem.ownerID = "other_owner"
+ ctxItem.mu.Unlock()
+ lease.cachedConn = nil // clear cache to force re-validation (simulates migration)
+ require.ErrorIs(t, lease.PingWithTimeout(time.Millisecond), errOpenAIWSConnClosed, "owner 不匹配时应拒绝访问")
+
+ ctxItem.mu.Lock()
+ ctxItem.ownerID = "owner_ok"
+ ctxItem.upstream = &openAIWSPingFailConn{}
+ ctxItem.mu.Unlock()
+ lease.cachedConn = nil // clear cache to pick up new upstream
+ require.Error(t, lease.PingWithTimeout(time.Millisecond), "上游 ping 失败应透传错误")
+
+ lease.Release()
+ lease.Release()
+ require.Equal(t, "", lease.context.ownerID, "重复 Release 应幂等且不会 panic")
+}
+
+func TestOpenAIWSIngressContextPool_EnsureContextUpstreamBranches(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_ensure",
+ accountID: 1,
+ ownerID: "owner",
+ upstream: &openAIWSFakeConn{},
+ }
+
+ reused, err := pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{
+ WSURL: "ws://test",
+ })
+ require.NoError(t, err)
+ require.True(t, reused, "已有可用 upstream 时应直接复用")
+
+ pool.dialer = nil
+ ctxItem.mu.Lock()
+ ctxItem.broken = true
+ ctxItem.mu.Unlock()
+ _, err = pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{
+ WSURL: "ws://test",
+ })
+ require.ErrorContains(t, err, "dialer is nil")
+
+ failDialer := &openAIWSAlwaysFailDialer{}
+ pool.setClientDialerForTest(failDialer)
+ _, err = pool.ensureContextUpstream(context.Background(), ctxItem, openAIWSIngressContextAcquireRequest{
+ WSURL: "ws://test",
+ })
+ require.Error(t, err)
+ var dialErr *openAIWSDialError
+ require.ErrorAs(t, err, &dialErr, "dial 失败应包装为 openAIWSDialError")
+ require.Equal(t, 503, dialErr.StatusCode)
+ ctxItem.mu.Lock()
+ broken := ctxItem.broken
+ failureStreak := ctxItem.failureStreak
+ ctxItem.mu.Unlock()
+ require.True(t, broken)
+ require.GreaterOrEqual(t, failureStreak, 1, "dial 失败后应累计 failure_streak")
+}
+
+func TestOpenAIWSIngressContextPool_MarkBrokenDoesNotSignalWaiterBeforeRelease(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstream := &openAIWSCaptureConn{}
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_mark_broken",
+ ownerID: "owner_broken",
+ releaseDone: make(chan struct{}, 1),
+ upstream: upstream,
+ }
+
+ pool.markContextBroken(ctxItem)
+
+ select {
+ case <-ctxItem.releaseDone:
+ t.Fatal("markContextBroken should not wake waiters before owner is released")
+ default:
+ }
+
+ ctxItem.mu.Lock()
+ require.True(t, ctxItem.broken)
+ require.Equal(t, "owner_broken", ctxItem.ownerID)
+ require.Nil(t, ctxItem.upstream)
+ ctxItem.mu.Unlock()
+
+ upstream.mu.Lock()
+ require.True(t, upstream.closed, "mark broken should close current upstream connection")
+ upstream.mu.Unlock()
+
+ pool.releaseContext(ctxItem, "owner_broken")
+
+ select {
+ case <-ctxItem.releaseDone:
+ case <-time.After(200 * time.Millisecond):
+ t.Fatal("releaseContext should signal one waiting acquire after owner is released")
+ }
+
+ ctxItem.mu.Lock()
+ require.Equal(t, "", ctxItem.ownerID)
+ require.False(t, ctxItem.broken)
+ ctxItem.mu.Unlock()
+}
+
+type openAIWSWriteDisconnectConn struct{}
+
+func (c *openAIWSWriteDisconnectConn) WriteJSON(context.Context, any) error {
+ return net.ErrClosed
+}
+
+func (c *openAIWSWriteDisconnectConn) ReadMessage(context.Context) ([]byte, error) {
+ return nil, net.ErrClosed
+}
+
+func (c *openAIWSWriteDisconnectConn) Ping(context.Context) error {
+ return net.ErrClosed
+}
+
+func (c *openAIWSWriteDisconnectConn) Close() error {
+ return nil
+}
+
+type openAIWSWriteGenericFailConn struct{}
+
+func (c *openAIWSWriteGenericFailConn) WriteJSON(context.Context, any) error {
+ return errors.New("writer failed")
+}
+
+func (c *openAIWSWriteGenericFailConn) ReadMessage(context.Context) ([]byte, error) {
+ return nil, errors.New("reader failed")
+}
+
+func (c *openAIWSWriteGenericFailConn) Ping(context.Context) error {
+ return errors.New("ping failed")
+}
+
+func (c *openAIWSWriteGenericFailConn) Close() error {
+ return nil
+}
+
+func TestOpenAIWSIngressContextLease_IOErrorInvalidatesCachedConn(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstream := &openAIWSWriteDisconnectConn{}
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_io_err",
+ accountID: 7,
+ ownerID: "owner_io",
+ upstream: upstream,
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctxItem,
+ ownerID: "owner_io",
+ }
+ lease.cachedConn = upstream
+
+ err := lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
+ require.Error(t, err)
+ require.ErrorIs(t, err, net.ErrClosed)
+
+ lease.cachedConnMu.RLock()
+ cached := lease.cachedConn
+ lease.cachedConnMu.RUnlock()
+ require.Nil(t, cached, "write failure should invalidate cachedConn")
+
+ ctxItem.mu.Lock()
+ require.True(t, ctxItem.broken, "disconnect-style IO failure should mark context broken")
+ require.Nil(t, ctxItem.upstream, "broken context should drop upstream reference")
+ ctxItem.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextLease_GenericIOErrorKeepsContextButInvalidatesCache(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ upstream := &openAIWSWriteGenericFailConn{}
+ ctxItem := &openAIWSIngressContext{
+ id: "ctx_generic_err",
+ accountID: 8,
+ ownerID: "owner_generic",
+ upstream: upstream,
+ }
+ lease := &openAIWSIngressContextLease{
+ pool: pool,
+ context: ctxItem,
+ ownerID: "owner_generic",
+ }
+ lease.cachedConn = upstream
+
+ err := lease.PingWithTimeout(time.Second)
+ require.Error(t, err)
+
+ lease.cachedConnMu.RLock()
+ cached := lease.cachedConn
+ lease.cachedConnMu.RUnlock()
+ require.Nil(t, cached, "generic IO failure should still invalidate cachedConn")
+
+ ctxItem.mu.Lock()
+ require.False(t, ctxItem.broken, "non-disconnect IO failure should not force-broken context")
+ require.Equal(t, upstream, ctxItem.upstream, "upstream should remain for non-disconnect errors")
+ ctxItem.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_EnsureContextUpstream_SerializesConcurrentDial(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ releaseDial := make(chan struct{})
+ blockingDialer := &openAIWSBlockingDialer{
+ release: releaseDial,
+ dialStarted: make(chan struct{}, 4),
+ }
+ pool.setClientDialerForTest(blockingDialer)
+
+ account := &Account{ID: 1301, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ type acquireResult struct {
+ lease *openAIWSIngressContextLease
+ err error
+ }
+ resultCh := make(chan acquireResult, 2)
+ acquireOnce := func() {
+ lease, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 23,
+ SessionHash: "session_same_owner",
+ OwnerID: "owner_same",
+ WSURL: "ws://test-upstream",
+ })
+ resultCh <- acquireResult{lease: lease, err: err}
+ }
+
+ go acquireOnce()
+ select {
+ case <-blockingDialer.dialStarted:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("首个 dial 未按预期启动")
+ }
+ go acquireOnce()
+
+ select {
+ case <-blockingDialer.dialStarted:
+ t.Fatal("同一 context 并发 acquire 不应触发第二次 dial")
+ case <-time.After(120 * time.Millisecond):
+ }
+
+ close(releaseDial)
+
+ results := make([]acquireResult, 0, 2)
+ for i := 0; i < 2; i++ {
+ select {
+ case result := <-resultCh:
+ require.NoError(t, result.err)
+ require.NotNil(t, result.lease)
+ results = append(results, result)
+ case <-time.After(2 * time.Second):
+ t.Fatal("等待并发 acquire 结果超时")
+ }
+ }
+
+ for _, result := range results {
+ result.lease.Release()
+ }
+ require.Equal(t, 1, blockingDialer.DialCount(), "同一 context 并发获取应只发生一次上游拨号")
+}
+
+func TestOpenAIWSIngressContextPool_EnsureContextUpstream_WaiterTimeoutDoesNotReleaseOwner(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ releaseDial := make(chan struct{})
+ blockingDialer := &openAIWSBlockingDialer{
+ release: releaseDial,
+ dialStarted: make(chan struct{}, 4),
+ }
+ pool.setClientDialerForTest(blockingDialer)
+
+ account := &Account{ID: 1302, Concurrency: 1}
+ baseReq := openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 24,
+ SessionHash: "session_waiter_timeout",
+ OwnerID: "owner_same",
+ WSURL: "ws://test-upstream",
+ }
+
+ longCtx, longCancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer longCancel()
+ type acquireResult struct {
+ lease *openAIWSIngressContextLease
+ err error
+ }
+ firstResultCh := make(chan acquireResult, 1)
+ go func() {
+ lease, err := pool.Acquire(longCtx, baseReq)
+ firstResultCh <- acquireResult{lease: lease, err: err}
+ }()
+
+ select {
+ case <-blockingDialer.dialStarted:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("首个 dial 未按预期启动")
+ }
+
+ shortCtx, shortCancel := context.WithTimeout(context.Background(), 60*time.Millisecond)
+ defer shortCancel()
+ _, waiterErr := pool.Acquire(shortCtx, baseReq)
+ require.ErrorIs(t, waiterErr, context.DeadlineExceeded, "等待中的 acquire 超时应返回 context deadline exceeded")
+
+ close(releaseDial)
+
+ select {
+ case first := <-firstResultCh:
+ require.NoError(t, first.err)
+ require.NotNil(t, first.lease)
+ require.NoError(t, first.lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "ping"}, time.Second), "等待方超时不应释放已建连 owner")
+ first.lease.Release()
+ case <-time.After(2 * time.Second):
+ t.Fatal("等待首个 acquire 结果超时")
+ }
+
+ require.Equal(t, 1, blockingDialer.DialCount())
+}
+
+func TestOpenAIWSIngressContextPool_Acquire_QueueWaitDurationRecorded(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 1303, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 25,
+ SessionHash: "session_queue_wait",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+
+ type acquireResult struct {
+ lease *openAIWSIngressContextLease
+ err error
+ }
+ waiterCh := make(chan acquireResult, 1)
+ go func() {
+ lease, acquireErr := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 25,
+ SessionHash: "session_queue_wait",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ waiterCh <- acquireResult{lease: lease, err: acquireErr}
+ }()
+
+ time.Sleep(120 * time.Millisecond)
+ lease1.Release()
+
+ select {
+ case waiter := <-waiterCh:
+ require.NoError(t, waiter.err)
+ require.NotNil(t, waiter.lease)
+ require.GreaterOrEqual(t, waiter.lease.QueueWaitDuration(), 100*time.Millisecond)
+ waiter.lease.Release()
+ case <-time.After(2 * time.Second):
+ t.Fatal("等待第二个 acquire 结果超时")
+ }
+}
+
+type openAIWSBlockingDialer struct {
+ mu sync.Mutex
+ release <-chan struct{}
+ dialStarted chan struct{}
+ dialCount int
+}
+
+func (d *openAIWSBlockingDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ d.mu.Lock()
+ d.dialCount++
+ d.mu.Unlock()
+ select {
+ case d.dialStarted <- struct{}{}:
+ default:
+ }
+ if d.release != nil {
+ select {
+ case <-d.release:
+ case <-ctx.Done():
+ return nil, 0, nil, ctx.Err()
+ }
+ }
+ return &openAIWSCaptureConn{}, 0, nil, nil
+}
+
+func (d *openAIWSBlockingDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+// ---------------------------------------------------------------------------
+// Load-aware migration scoring tests
+// ---------------------------------------------------------------------------
+
+func TestScoreOpenAIWSIngressMigrationCandidate_HighErrorRatePenalty(t *testing.T) {
+ now := time.Now()
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(9001)
+
+ // Report a pattern of failures interspersed with occasional successes.
+ // This pushes the error rate high without tripping the circuit breaker
+ // (consecutive failures stay below the default threshold of 5).
+ for i := 0; i < 6; i++ {
+ stats.report(accountID, false, nil, "", 0)
+ stats.report(accountID, false, nil, "", 0)
+ stats.report(accountID, false, nil, "", 0)
+ stats.report(accountID, true, nil, "", 0) // reset consecutive fail counter
+ }
+ require.False(t, stats.isCircuitOpen(accountID), "circuit breaker should remain closed for this test")
+
+ ctx := &openAIWSIngressContext{accountID: accountID}
+ ctx.setLastUsedAt(now.Add(-5 * time.Minute))
+
+ scoreWithStats, _, okStats := scoreOpenAIWSIngressMigrationCandidate(ctx, now, stats)
+ require.True(t, okStats)
+
+ // Score the same context without stats (nil) for comparison.
+ scoreWithout, _, okWithout := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil)
+ require.True(t, okWithout)
+
+ require.Less(t, scoreWithStats, scoreWithout,
+ "a context on a high-error-rate account should receive a lower migration score")
+
+ // The error rate penalty should be approximately errorRate * 30.
+ // Since the circuit breaker is not open, the only load-aware penalty is
+ // errorRate * 30.
+ errorRate, _, _ := stats.snapshot(accountID)
+ expectedPenalty := errorRate * 30
+ require.InDelta(t, expectedPenalty, scoreWithout-scoreWithStats, 1.0,
+ "penalty should be approximately errorRate * 30")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_CircuitOpenPenalty(t *testing.T) {
+ now := time.Now()
+ stats := newOpenAIAccountRuntimeStats()
+ accountID := int64(9002)
+
+ // Trip the circuit breaker by reporting consecutive failures beyond the
+ // default threshold (5).
+ for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ {
+ stats.report(accountID, false, nil, "", 0)
+ }
+ require.True(t, stats.isCircuitOpen(accountID), "circuit breaker should be open after many failures")
+
+ ctx := &openAIWSIngressContext{accountID: accountID}
+ ctx.setLastUsedAt(now.Add(-5 * time.Minute))
+
+ scoreCircuitOpen, _, ok := scoreOpenAIWSIngressMigrationCandidate(ctx, now, stats)
+ require.True(t, ok)
+
+ // Score without stats for comparison.
+ scoreBaseline, _, okBase := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil)
+ require.True(t, okBase)
+
+ // The circuit-open penalty is -50, plus errorRate*30, so the score should
+ // be substantially lower.
+ require.Less(t, scoreCircuitOpen, scoreBaseline-45,
+ "a context on a circuit-open account should have a very low migration score")
+
+ // In practice, the combined penalty should bring the score below any
+ // reasonable minimum migration threshold (weakest = 40).
+ _, weakMin := openAIWSIngressMigrationPolicyByStickiness(openAIWSIngressStickinessWeak)
+ require.Less(t, scoreCircuitOpen, weakMin,
+ "circuit-open accounts should score below even the weakest migration threshold")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_NilStatsFallback(t *testing.T) {
+ now := time.Now()
+
+ ctx := &openAIWSIngressContext{accountID: 9003}
+ ctx.setLastUsedAt(now.Add(-5 * time.Minute))
+
+ scoreNil, _, okNil := scoreOpenAIWSIngressMigrationCandidate(ctx, now, nil)
+ require.True(t, okNil)
+
+ // Create stats but report nothing for this account -- snapshot returns 0.
+ emptyStats := newOpenAIAccountRuntimeStats()
+ scoreEmpty, _, okEmpty := scoreOpenAIWSIngressMigrationCandidate(ctx, now, emptyStats)
+ require.True(t, okEmpty)
+
+ // With no data for the account, the load-aware path should add zero
+ // penalty, yielding the same score as nil stats.
+ require.InDelta(t, scoreNil, scoreEmpty, 0.001,
+ "when scheduler stats have no data for the account, score should match nil-stats baseline")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_NilContext(t *testing.T) {
+ now := time.Now()
+ score, _, ok := scoreOpenAIWSIngressMigrationCandidate(nil, now, nil)
+ require.False(t, ok)
+ require.Equal(t, 0.0, score)
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_IdleDurationBranches(t *testing.T) {
+ now := time.Now()
+
+ // Very recently used (≤15s): penalty of -15
+ recentCtx := &openAIWSIngressContext{}
+ recentCtx.setLastUsedAt(now.Add(-5 * time.Second))
+ scoreRecent, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentCtx, now, nil)
+ require.True(t, ok)
+ require.InDelta(t, 100.0-15.0, scoreRecent, 0.5, "very recently used should get -15 penalty")
+
+ // Medium idle (between 15s and 3min): bonus = idleDuration.Seconds() / 12
+ mediumCtx := &openAIWSIngressContext{}
+ mediumCtx.setLastUsedAt(now.Add(-90 * time.Second)) // 90s idle
+ scoreMedium, _, ok := scoreOpenAIWSIngressMigrationCandidate(mediumCtx, now, nil)
+ require.True(t, ok)
+ expectedBonus := 90.0 / 12.0 // 7.5
+ require.InDelta(t, 100.0+expectedBonus, scoreMedium, 0.5, "medium idle should get proportional bonus")
+
+ // Long idle (≥3min): bonus of +16
+ longCtx := &openAIWSIngressContext{}
+ longCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreLong, _, ok := scoreOpenAIWSIngressMigrationCandidate(longCtx, now, nil)
+ require.True(t, ok)
+ require.InDelta(t, 100.0+16.0, scoreLong, 0.5, "long idle should get +16 bonus")
+
+ // Verify ordering: long > medium > recent
+ require.Greater(t, scoreLong, scoreMedium, "longer idle should score higher than medium")
+ require.Greater(t, scoreMedium, scoreRecent, "medium idle should score higher than very recent")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_BrokenAndFailures(t *testing.T) {
+ now := time.Now()
+
+ // Broken context: -30
+ brokenCtx := &openAIWSIngressContext{broken: true}
+ brokenCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreBroken, _, ok := scoreOpenAIWSIngressMigrationCandidate(brokenCtx, now, nil)
+ require.True(t, ok)
+
+ cleanCtx := &openAIWSIngressContext{}
+ cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreClean, _, ok := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil)
+ require.True(t, ok)
+ require.InDelta(t, 30.0, scoreClean-scoreBroken, 0.5, "broken should subtract 30")
+
+ // High failure streak (capped at 40)
+ highFailCtx := &openAIWSIngressContext{failureStreak: 5}
+ highFailCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreHighFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(highFailCtx, now, nil)
+ require.True(t, ok)
+ // 5*12=60 but capped at 40
+ require.InDelta(t, 40.0, scoreClean-scoreHighFail, 0.5, "failure streak penalty should cap at 40")
+
+ // Recent failure (within 2 min): -18
+ recentFailCtx := &openAIWSIngressContext{lastFailureAt: now.Add(-30 * time.Second)}
+ recentFailCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreRecentFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentFailCtx, now, nil)
+ require.True(t, ok)
+ require.InDelta(t, 18.0, scoreClean-scoreRecentFail, 0.5, "recent failure should subtract 18")
+
+ // Old failure (>2 min): no penalty
+ oldFailCtx := &openAIWSIngressContext{lastFailureAt: now.Add(-5 * time.Minute)}
+ oldFailCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreOldFail, _, ok := scoreOpenAIWSIngressMigrationCandidate(oldFailCtx, now, nil)
+ require.True(t, ok)
+ require.InDelta(t, scoreClean, scoreOldFail, 0.5, "old failure should have no penalty")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_MigrationPenalties(t *testing.T) {
+ now := time.Now()
+
+ cleanCtx := &openAIWSIngressContext{}
+ cleanCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreClean, _, _ := scoreOpenAIWSIngressMigrationCandidate(cleanCtx, now, nil)
+
+ // Recent migration (within 1 min): -10
+ recentMigCtx := &openAIWSIngressContext{lastMigrationAt: now.Add(-30 * time.Second)}
+ recentMigCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreRecentMig, _, ok := scoreOpenAIWSIngressMigrationCandidate(recentMigCtx, now, nil)
+ require.True(t, ok)
+ require.InDelta(t, 10.0, scoreClean-scoreRecentMig, 0.5, "recent migration should subtract 10")
+
+ // High migration count (capped at 20)
+ highMigCtx := &openAIWSIngressContext{migrationCount: 6}
+ highMigCtx.setLastUsedAt(now.Add(-5 * time.Minute))
+ scoreHighMig, _, ok := scoreOpenAIWSIngressMigrationCandidate(highMigCtx, now, nil)
+ require.True(t, ok)
+ // 6*4=24 but capped at 20
+ require.InDelta(t, 20.0, scoreClean-scoreHighMig, 0.5, "migration count penalty should cap at 20")
+}
+
+func TestScoreOpenAIWSIngressMigrationCandidate_CombinedPenalties(t *testing.T) {
+ now := time.Now()
+ // All penalties combined: broken(-30) + failStreak 1*12(-12) + recentFail(-18) + recentMig(-10) + migCount 1*4(-4) + recentIdle(-15)
+ worstCtx := &openAIWSIngressContext{
+ broken: true,
+ failureStreak: 1,
+ lastFailureAt: now.Add(-30 * time.Second),
+ migrationCount: 1,
+ lastMigrationAt: now.Add(-30 * time.Second),
+ }
+ worstCtx.setLastUsedAt(now.Add(-5 * time.Second))
+ score, _, ok := scoreOpenAIWSIngressMigrationCandidate(worstCtx, now, nil)
+ require.True(t, ok)
+ expected := 100.0 - 30.0 - 12.0 - 18.0 - 10.0 - 4.0 - 15.0 // = 11.0
+ require.InDelta(t, expected, score, 0.5, "all penalties should stack correctly")
+}
+
+func TestOpenAIWSIngressContextPool_MigrationBlockedByCircuitBreaker(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ stats := newOpenAIAccountRuntimeStats()
+ pool.schedulerStats = stats
+
+ accountID := int64(9004)
+
+ // Trip circuit breaker for this account.
+ for i := 0; i < defaultCircuitBreakerFailThreshold+1; i++ {
+ stats.report(accountID, false, nil, "", 0)
+ }
+ require.True(t, stats.isCircuitOpen(accountID))
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: accountID, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ // Acquire the only slot.
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 30,
+ SessionHash: "session_cb_a",
+ OwnerID: "owner_a",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ // Now try a different session -- migration should fail because the only
+ // candidate context is on a circuit-open account, whose score will be
+ // below the minimum threshold.
+ _, err = pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 30,
+ SessionHash: "session_cb_b",
+ OwnerID: "owner_b",
+ WSURL: "ws://test-upstream",
+ })
+ require.ErrorIs(t, err, errOpenAIWSConnQueueFull,
+ "migration to a circuit-open account should be blocked")
+}
+
+// ---------- 连接生命周期管理(超龄轮换)测试 ----------
+
+func TestOpenAIWSIngressContext_UpstreamConnAge_ZeroValue(t *testing.T) {
+ ctx := &openAIWSIngressContext{}
+ // 未设置 createdAt 时,connAge 应返回 0
+ require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now()))
+}
+
+func TestOpenAIWSIngressContext_UpstreamConnAge_Normal(t *testing.T) {
+ ctx := &openAIWSIngressContext{}
+ past := time.Now().Add(-10 * time.Minute)
+ ctx.upstreamConnCreatedAt.Store(past.UnixNano())
+ age := ctx.upstreamConnAge(time.Now())
+ require.True(t, age >= 10*time.Minute-time.Second, "connAge 应约为 10 分钟,实际=%v", age)
+ require.True(t, age < 11*time.Minute, "connAge 不应过大,实际=%v", age)
+}
+
+func TestOpenAIWSIngressContext_UpstreamConnAge_NilSafe(t *testing.T) {
+ var ctx *openAIWSIngressContext
+ require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now()))
+}
+
+func TestOpenAIWSIngressContext_UpstreamConnAge_ClockSkew(t *testing.T) {
+ ctx := &openAIWSIngressContext{}
+ future := time.Now().Add(10 * time.Minute)
+ ctx.upstreamConnCreatedAt.Store(future.UnixNano())
+ // now 早于 createdAt(时钟回拨),应返回 0
+ require.Equal(t, time.Duration(0), ctx.upstreamConnAge(time.Now()))
+}
+
+func TestNewOpenAIWSIngressContextPool_UpstreamMaxAge_ZeroDisablesRotation(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.UpstreamConnMaxAgeSeconds = 0
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ require.Equal(t, time.Duration(0), pool.upstreamMaxAge)
+}
+
+func TestOpenAIWSIngressContextPool_EnsureUpstream_MaxAgeRotate(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ pool.upstreamMaxAge = 1 * time.Second // 设置极短的 maxAge 以便测试
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ conn2 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{conn1, conn2},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 901, Concurrency: 2}
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // 第一次 Acquire:建连
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_age",
+ OwnerID: "owner_age_1",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ require.Equal(t, 1, dialer.DialCount(), "首次 Acquire 应 dial 一次")
+
+ // Yield 保留连接
+ lease1.Yield()
+
+ // 等待超过 maxAge
+ time.Sleep(1200 * time.Millisecond)
+
+ // 第二次 Acquire:应触发超龄轮换
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_age",
+ OwnerID: "owner_age_2",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, 2, dialer.DialCount(), "超龄轮换应触发重新 dial")
+ require.True(t, conn1.closed, "旧连接应被关闭")
+ lease2.Release()
+}
+
+func TestOpenAIWSIngressContextPool_EnsureUpstream_YoungConnNotRotated(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ pool.upstreamMaxAge = 10 * time.Minute // 远大于测试时间
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{conn1},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 902, Concurrency: 2}
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_young",
+ OwnerID: "owner_young_1",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ lease1.Yield()
+
+ // 立即重新 Acquire:连接年轻,不应轮换
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_young",
+ OwnerID: "owner_young_2",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, 1, dialer.DialCount(), "年轻连接不应触发重新 dial")
+ require.True(t, lease2.Reused(), "年轻连接应复用")
+ require.False(t, conn1.closed, "年轻连接不应被关闭")
+ lease2.Release()
+}
+
+func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_ClosesAgedIdle(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ pool.upstreamMaxAge = 1 * time.Second
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_aged_1",
+ accountID: 903,
+ upstream: conn1,
+ }
+ // 设 createdAt 为 2 秒前
+ ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano())
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_aged_1"] = ctx
+
+ now := time.Now()
+ ap.mu.Lock()
+ toClose := pool.closeAgedIdleUpstreamsLocked(ap, now)
+ ap.mu.Unlock()
+
+ require.Len(t, toClose, 1, "应关闭超龄空闲连接")
+ closeOpenAIWSClientConns(toClose)
+ require.True(t, conn1.closed)
+
+ // upstream 应已清空
+ ctx.mu.Lock()
+ require.Nil(t, ctx.upstream)
+ require.Equal(t, int64(0), ctx.upstreamConnCreatedAt.Load())
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsOwnedContext(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ pool.upstreamMaxAge = 1 * time.Second
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_owned_1",
+ accountID: 904,
+ ownerID: "active_owner",
+ upstream: conn1,
+ }
+ ctx.upstreamConnCreatedAt.Store(time.Now().Add(-2 * time.Second).UnixNano())
+ ap.contexts["ctx_owned_1"] = ctx
+
+ now := time.Now()
+ ap.mu.Lock()
+ toClose := pool.closeAgedIdleUpstreamsLocked(ap, now)
+ ap.mu.Unlock()
+
+ require.Len(t, toClose, 0, "有 owner 的超龄连接不应被关闭")
+ require.False(t, conn1.closed)
+}
+
+func TestOpenAIWSIngressContextPool_CloseAgedIdleUpstreams_SkipsYoungConn(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ pool.upstreamMaxAge = 10 * time.Minute
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_young_1",
+ accountID: 905,
+ upstream: conn1,
+ }
+ ctx.upstreamConnCreatedAt.Store(time.Now().UnixNano())
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_young_1"] = ctx
+
+ now := time.Now()
+ ap.mu.Lock()
+ toClose := pool.closeAgedIdleUpstreamsLocked(ap, now)
+ ap.mu.Unlock()
+
+ require.Len(t, toClose, 0, "年轻连接不应被关闭")
+ require.False(t, conn1.closed)
+}
+
+func TestOpenAIWSIngressContextPool_E2E_AcquireYieldAgedReconnect(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ pool.upstreamMaxAge = 55 * time.Minute // 使用实际默认值
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ conn2 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{conn1, conn2},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 906, Concurrency: 2}
+ ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+ defer cancel()
+
+ // Acquire → Yield
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_e2e",
+ OwnerID: "owner_e2e_1",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Yield()
+
+ // 手动设置 createdAt 为 56 分钟前以模拟超龄
+ pool.mu.Lock()
+ ap := pool.accounts[account.ID]
+ pool.mu.Unlock()
+ require.NotNil(t, ap)
+
+ ap.mu.Lock()
+ for _, c := range ap.contexts {
+ c.upstreamConnCreatedAt.Store(time.Now().Add(-56 * time.Minute).UnixNano())
+ }
+ ap.mu.Unlock()
+
+ // 重新 Acquire:应检测到超龄并重连
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_e2e",
+ OwnerID: "owner_e2e_2",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease2)
+ require.Equal(t, 2, dialer.DialCount(), "超龄 56 分钟的连接应触发重连")
+ require.True(t, conn1.closed, "旧的超龄连接应被关闭")
+ lease2.Release()
+}
+
+// 回归测试:容量满 + 存在过期 context 时,Acquire 仍能通过 evict 腾出空间后正常分配。
+func TestOpenAIWSIngressContextPool_Acquire_EvictsExpiredWhenFull(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ conn2 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{conn1, conn2},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 901, Concurrency: 1}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ // 获取一个 lease 占满容量
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_old",
+ OwnerID: "owner_old",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ // 手动令该 context 过期
+ pool.mu.Lock()
+ ap := pool.accounts[account.ID]
+ pool.mu.Unlock()
+ ap.mu.Lock()
+ for _, c := range ap.contexts {
+ c.setExpiresAt(time.Now().Add(-2 * time.Second))
+ }
+ ap.mu.Unlock()
+
+ // 容量满(1个过期 context),新 session 的 Acquire 应通过 evict 成功
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 3,
+ SessionHash: "session_new",
+ OwnerID: "owner_new",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err, "容量满但有过期 context 时 Acquire 应成功")
+ require.NotNil(t, lease2)
+ lease2.Release()
+}
+
+// 回归测试:Acquire 找到已过期但仍在 bySession 映射中的 context 时,能正确取得所有权并刷新租约。
+func TestOpenAIWSIngressContextPool_Acquire_ReusesExpiredContextBySession(t *testing.T) {
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 60
+
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ conn1 := &openAIWSCaptureConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{conn1},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 902, Concurrency: 2}
+ ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+ defer cancel()
+
+ // 第一次获取 context
+ lease1, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 4,
+ SessionHash: "session_reuse",
+ OwnerID: "owner_1",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+ require.NotNil(t, lease1)
+ lease1.Release()
+
+ // 令 context 过期(但不清理,模拟热路径不再清理的行为)
+ pool.mu.Lock()
+ ap := pool.accounts[account.ID]
+ pool.mu.Unlock()
+ ap.mu.Lock()
+ for _, c := range ap.contexts {
+ c.setExpiresAt(time.Now().Add(-1 * time.Second))
+ }
+ ap.mu.Unlock()
+
+ // 同 session 再次 Acquire,应能复用过期但未清理的 context
+ lease2, err := pool.Acquire(ctx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 4,
+ SessionHash: "session_reuse",
+ OwnerID: "owner_2",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err, "过期但未清理的 context 应被同 session 的 Acquire 复用")
+ require.NotNil(t, lease2)
+
+ // 验证租约已刷新(expiresAt 应在未来)
+ pool.mu.Lock()
+ ap2 := pool.accounts[account.ID]
+ pool.mu.Unlock()
+ ap2.mu.Lock()
+ for _, c := range ap2.contexts {
+ c.mu.Lock()
+ ea := c.expiresAt()
+ c.mu.Unlock()
+ require.True(t, ea.After(time.Now()), "复用后租约应被刷新到未来")
+ }
+ ap2.mu.Unlock()
+
+ lease2.Release()
+}
+
+// === P3: 后台主动 Ping 检测测试 ===
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_MarksBrokenOnPingFailure(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_fail_1",
+ accountID: 2001,
+ upstream: failConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_fail_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ ctx.mu.Unlock()
+ require.True(t, broken, "Ping 失败应标记 context 为 broken")
+ require.Equal(t, 1, streak, "Ping 失败应增加 failureStreak")
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsOwnedContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_owned_1",
+ accountID: 2002,
+ ownerID: "active_owner",
+ upstream: failConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_owned_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ ctx.mu.Unlock()
+ require.False(t, broken, "有 owner 的 context 不应被 Ping 探测")
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsBrokenContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_broken_1",
+ accountID: 2003,
+ upstream: failConn,
+ broken: true,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_broken_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ streak := ctx.failureStreak
+ ctx.mu.Unlock()
+ require.Equal(t, 0, streak, "已 broken 的 context 不应被再次 Ping")
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_HealthyConnStaysHealthy(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ healthyConn := &openAIWSCaptureConn{}
+ ap := &openAIWSIngressAccountPool{
+ contexts: make(map[string]*openAIWSIngressContext),
+ bySession: make(map[string]string),
+ }
+ ctx := &openAIWSIngressContext{
+ id: "ctx_ping_healthy_1",
+ accountID: 2004,
+ upstream: healthyConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_ping_healthy_1"] = ctx
+
+ pool.pingIdleUpstreams(ap)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ hasUpstream := ctx.upstream != nil
+ ctx.mu.Unlock()
+ require.False(t, broken, "Ping 成功的 context 不应被标记 broken")
+ require.True(t, hasUpstream, "Ping 成功的 upstream 应保持")
+}
+
+func TestOpenAIWSIngressContextPool_SweepTriggersPingOnIdleContexts(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ healthyConn := &openAIWSCaptureConn{}
+
+ pool.mu.Lock()
+ ap := pool.getOrCreateAccountPoolLocked(3001)
+ pool.mu.Unlock()
+
+ ap.mu.Lock()
+ ctxFail := &openAIWSIngressContext{
+ id: "ctx_sweep_ping_fail",
+ accountID: 3001,
+ upstream: failConn,
+ }
+ ctxFail.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_sweep_ping_fail"] = ctxFail
+
+ ctxOk := &openAIWSIngressContext{
+ id: "ctx_sweep_ping_ok",
+ accountID: 3001,
+ upstream: healthyConn,
+ }
+ ctxOk.touchLease(time.Now(), pool.idleTTL)
+ ap.contexts["ctx_sweep_ping_ok"] = ctxOk
+ ap.dynamicCap.Store(2)
+ ap.mu.Unlock()
+
+ // 手动触发 sweep
+ pool.sweepExpiredIdleContexts()
+
+ ctxFail.mu.Lock()
+ failBroken := ctxFail.broken
+ ctxFail.mu.Unlock()
+ require.True(t, failBroken, "sweep 后 Ping 失败的空闲 context 应被标记 broken")
+
+ ctxOk.mu.Lock()
+ okBroken := ctxOk.broken
+ ctxOk.mu.Unlock()
+ require.False(t, okBroken, "sweep 后 Ping 成功的空闲 context 应保持健康")
+}
+
+func TestOpenAIWSIngressContextPool_YieldSchedulesDelayedPing(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{failConn},
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 4001, Concurrency: 2}
+ bCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+ defer cancel()
+
+ lease, err := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+ Account: account,
+ GroupID: 1,
+ SessionHash: "session_yield_ping",
+ OwnerID: "owner_yield_ping",
+ WSURL: "ws://test-upstream",
+ })
+ require.NoError(t, err)
+
+ ingressCtx := lease.context
+ // Yield 触发延迟 Ping
+ lease.Yield()
+
+ // 等待延迟 Ping 执行完毕(默认 5s + 余量)
+ time.Sleep(6 * time.Second)
+
+ ingressCtx.mu.Lock()
+ broken := ingressCtx.broken
+ ingressCtx.mu.Unlock()
+ require.True(t, broken, "Yield 后延迟 Ping 应检测到 PingFailConn 并标记 broken")
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_CancelledOnPoolClose(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_delayed_cancel_1",
+ accountID: 5001,
+ upstream: failConn,
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 安排 10s 延迟(远大于测试等待时间)
+ pool.scheduleDelayedPing(ctx, 10*time.Second)
+
+ // 立刻关闭 pool,应取消延迟 Ping
+ pool.Close()
+ time.Sleep(200 * time.Millisecond)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ ctx.mu.Unlock()
+ require.False(t, broken, "pool 关闭后延迟 Ping 不应执行")
+}
+
+// === effectiveDynamicCapacity 边界测试 ===
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NilAccountPool(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ require.Equal(t, 4, pool.effectiveDynamicCapacity(nil, 4), "ap==nil 时应返回 hardCap")
+ require.Equal(t, 0, pool.effectiveDynamicCapacity(nil, 0), "ap==nil && hardCap==0")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_ZeroHardCap(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(3)
+ require.Equal(t, 0, pool.effectiveDynamicCapacity(ap, 0), "hardCap<=0 应返回 hardCap")
+ require.Equal(t, -1, pool.effectiveDynamicCapacity(ap, -1), "hardCap<0 应返回 hardCap")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapBelowOne(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(0) // 异常值
+ result := pool.effectiveDynamicCapacity(ap, 4)
+ require.Equal(t, 1, result, "dynCap<1 应自动修复为 1")
+ require.Equal(t, int32(1), ap.dynamicCap.Load(), "dynCap 应被修复为 1")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapExceedsHardCap(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(10)
+ require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap>hardCap 应 clamp 到 hardCap")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_DynCapEqualsHardCap(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(4)
+ require.Equal(t, 4, pool.effectiveDynamicCapacity(ap, 4), "dynCap==hardCap 应返回 hardCap")
+}
+
+func TestOpenAIWSIngressContextPool_EffectiveDynamicCapacity_NormalPath(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{}
+ ap := &openAIWSIngressAccountPool{}
+ ap.dynamicCap.Store(2)
+ require.Equal(t, 2, pool.effectiveDynamicCapacity(ap, 8), "正常 dynCap= 2, "第二次 Acquire 应触发 dynamicCap 增长")
+
+ lease1.Release()
+ lease2.Release()
+}
+
+func TestOpenAIWSIngressContextPool_Sweeper_ShrinksDynamicCap(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1 // 1 秒 TTL 让 context 快速过期
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ dialer := &openAIWSQueueDialer{
+ conns: []openAIWSClientConn{
+ &openAIWSCaptureConn{}, &openAIWSCaptureConn{}, &openAIWSCaptureConn{},
+ },
+ }
+ pool.setClientDialerForTest(dialer)
+
+ account := &Account{ID: 6002, Concurrency: 4}
+ bCtx := context.Background()
+
+ // 创建两个 context
+ lease1, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+ Account: account, GroupID: 1, SessionHash: "s1", OwnerID: "o1", WSURL: "ws://t",
+ })
+ lease2, _ := pool.Acquire(bCtx, openAIWSIngressContextAcquireRequest{
+ Account: account, GroupID: 1, SessionHash: "s2", OwnerID: "o2", WSURL: "ws://t",
+ })
+ lease1.Release()
+ lease2.Release()
+
+ pool.mu.Lock()
+ ap := pool.accounts[account.ID]
+ pool.mu.Unlock()
+ require.True(t, ap.dynamicCap.Load() >= 2)
+
+ // 等待 context 过期
+ time.Sleep(2 * time.Second)
+
+ // 手动 sweep
+ pool.sweepExpiredIdleContexts()
+
+ // sweep 后 dynamicCap 应缩减
+ ap.mu.Lock()
+ ctxCount := len(ap.contexts)
+ ap.mu.Unlock()
+ dynCap := ap.dynamicCap.Load()
+ // 如果所有 context 都被 evict,dynamicCap 应缩到 1(min)
+ if ctxCount == 0 {
+ require.Equal(t, int32(1), dynCap, "context 全部 evict 后 dynamicCap 应缩至 1")
+ } else {
+ require.LessOrEqual(t, dynCap, int32(ctxCount), "dynamicCap 应缩至当前 context 数")
+ }
+}
+
+// === Ping 额外边界测试 ===
+
+func TestOpenAIWSIngressContextPool_PingContextUpstream_NilPoolOrContext(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})}
+ // nil context 不应 panic
+ pool.pingContextUpstream(nil)
+ // nil pool 不应 panic
+ var nilPool *openAIWSIngressContextPool
+ nilPool.pingContextUpstream(&openAIWSIngressContext{})
+}
+
+func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsDialingContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_dialing_1", accountID: 7001,
+ upstream: failConn, dialing: true,
+ }
+ pool.pingContextUpstream(ctx)
+ ctx.mu.Lock()
+ require.False(t, ctx.broken, "dialing 中的 context 不应被 Ping")
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_PingContextUpstream_SkipsNoUpstream(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_no_upstream", accountID: 7002, upstream: nil,
+ }
+ pool.pingContextUpstream(ctx)
+ ctx.mu.Lock()
+ require.False(t, ctx.broken, "无 upstream 的 context 不应被 Ping")
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_NilPoolOrAP(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})}
+ pool.pingIdleUpstreams(nil)
+ var nilPool *openAIWSIngressContextPool
+ nilPool.pingIdleUpstreams(&openAIWSIngressAccountPool{})
+}
+
+func TestOpenAIWSIngressContextPool_PingIdleUpstreams_SkipsNilContext(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ ap := &openAIWSIngressAccountPool{
+ contexts: map[string]*openAIWSIngressContext{"nil_ctx": nil},
+ bySession: make(map[string]string),
+ }
+ // 不应 panic
+ pool.pingIdleUpstreams(ap)
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ZeroDelay(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_zero_delay", accountID: 8001, upstream: failConn,
+ }
+ // delay <= 0 应为 no-op
+ pool.scheduleDelayedPing(ctx, 0)
+ pool.scheduleDelayedPing(ctx, -1*time.Second)
+ time.Sleep(100 * time.Millisecond)
+ ctx.mu.Lock()
+ require.False(t, ctx.broken, "delay<=0 不应触发 Ping")
+ ctx.mu.Unlock()
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_NilParams(t *testing.T) {
+ t.Parallel()
+ pool := &openAIWSIngressContextPool{stopCh: make(chan struct{})}
+ // nil context
+ pool.scheduleDelayedPing(nil, 5*time.Second)
+ // nil pool
+ var nilPool *openAIWSIngressContextPool
+ nilPool.scheduleDelayedPing(&openAIWSIngressContext{}, 5*time.Second)
+}
+
+// === P1 并发回归:旧连接 Ping 失败不应误杀新连接 ===
+
+func TestOpenAIWSIngressContextPool_PingFailDoesNotKillNewConn(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ // 旧连接:Ping 带 200ms 延迟后失败
+ oldConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond)
+ // 新连接:正常的 Ping
+ newConn := &openAIWSCaptureConn{}
+
+ ctx := &openAIWSIngressContext{
+ id: "ctx_race_test",
+ accountID: 9001,
+ upstream: oldConn,
+ upstreamConnID: "old_conn_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 在后台对旧连接发起 Ping 探测
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ pool.pingContextUpstream(ctx)
+ }()
+
+ // 等待 Ping 开始执行
+ <-oldConn.pingDone
+
+ // 模拟连接重建:在 Ping 执行期间将 upstream 替换为新连接
+ ctx.mu.Lock()
+ ctx.upstream = newConn
+ ctx.upstreamConnID = "new_conn_2"
+ ctx.broken = false
+ ctx.failureStreak = 0
+ ctx.mu.Unlock()
+
+ // 等待 Ping goroutine 完成
+ wg.Wait()
+
+ // 验证:新连接不应被标记为 broken
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ upstream := ctx.upstream
+ connID := ctx.upstreamConnID
+ ctx.mu.Unlock()
+
+ require.False(t, broken, "新连接不应被旧 Ping 失败标记为 broken")
+ require.Equal(t, 0, streak, "failureStreak 不应增加")
+ require.Equal(t, newConn, upstream, "upstream 应仍是新连接")
+ require.Equal(t, "new_conn_2", connID, "upstreamConnID 应仍是新连接的 ID")
+ require.False(t, newConn.Closed(), "新连接不应被关闭")
+}
+
+func TestOpenAIWSIngressContextPool_PingFailKillsSameConn(t *testing.T) {
+ // 对照测试:connID 未变时应正常标记 broken
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_same_conn",
+ accountID: 9002,
+ upstream: failConn,
+ upstreamConnID: "conn_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ pool.pingContextUpstream(ctx)
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ connID := ctx.upstreamConnID
+ ctx.mu.Unlock()
+
+ require.True(t, broken, "同一连接 Ping 失败应标记 broken")
+ require.Equal(t, 1, streak, "failureStreak 应为 1")
+ require.Empty(t, connID, "upstreamConnID 应被清空")
+}
+
+func TestOpenAIWSIngressContextPool_PingFailDoesNotKillBusyConn(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := newOpenAIWSDelayedPingFailConn(200 * time.Millisecond)
+ ctx := &openAIWSIngressContext{
+ id: "ctx_busy_conn",
+ accountID: 9005,
+ upstream: failConn,
+ upstreamConnID: "conn_busy_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ pool.pingContextUpstream(ctx)
+ }()
+
+ <-failConn.pingDone
+
+ ctx.mu.Lock()
+ ctx.ownerID = "active_owner"
+ ctx.mu.Unlock()
+
+ wg.Wait()
+
+ ctx.mu.Lock()
+ broken := ctx.broken
+ streak := ctx.failureStreak
+ upstream := ctx.upstream
+ connID := ctx.upstreamConnID
+ ownerID := ctx.ownerID
+ ctx.mu.Unlock()
+
+ require.False(t, broken, "busy context 不应被后台 Ping 标记 broken")
+ require.Equal(t, 0, streak, "failureStreak 不应增加")
+ require.Equal(t, failConn, upstream, "busy context 的 upstream 不应被替换")
+ require.Equal(t, "conn_busy_1", connID, "busy context 的 connID 不应变化")
+ require.Equal(t, "active_owner", ownerID, "owner 应保持不变")
+ require.False(t, failConn.Closed(), "busy context 的连接不应被后台 Ping 关闭")
+}
+
+// === P2a 去重:连续多次 Yield 仅触发一次延迟 Ping ===
+
+func TestOpenAIWSIngressContextPool_ConsecutiveYieldsOnlyOnePing(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 600
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ // 使用 Ping 失败连接以便观察是否被标记 broken
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_yield_dedup",
+ accountID: 9003,
+ upstream: failConn,
+ upstreamConnID: "conn_dedup_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 连续调用 5 次 scheduleDelayedPing
+ for i := 0; i < 5; i++ {
+ pool.scheduleDelayedPing(ctx, 100*time.Millisecond)
+ }
+
+ // 验证:只有一个 pendingPingTimer(不应堆积多个 goroutine)
+ ctx.mu.Lock()
+ hasPending := ctx.pendingPingTimer != nil
+ ctx.mu.Unlock()
+ require.True(t, hasPending, "应有一个 pendingPingTimer")
+
+ // 等待 timer 到期并执行 Ping
+ time.Sleep(300 * time.Millisecond)
+
+ // Ping 失败应标记 broken(证明延迟 Ping 确实执行了)
+ ctx.mu.Lock()
+ broken := ctx.broken
+ pendingTimer := ctx.pendingPingTimer
+ ctx.mu.Unlock()
+
+ require.True(t, broken, "延迟 Ping 应已执行并标记 broken")
+ require.Nil(t, pendingTimer, "Ping 执行后 pendingPingTimer 应被清理")
+}
+
+func TestOpenAIWSIngressContextPool_ScheduleDelayedPing_ResetExtendsDelay(t *testing.T) {
+ t.Parallel()
+ cfg := &config.Config{}
+ pool := newOpenAIWSIngressContextPool(cfg)
+ defer pool.Close()
+
+ failConn := &openAIWSPingFailConn{}
+ ctx := &openAIWSIngressContext{
+ id: "ctx_reset_delay",
+ accountID: 9004,
+ upstream: failConn,
+ upstreamConnID: "conn_reset_1",
+ }
+ ctx.touchLease(time.Now(), pool.idleTTL)
+
+ // 第一次调度 200ms 延迟
+ pool.scheduleDelayedPing(ctx, 200*time.Millisecond)
+
+ // 100ms 后再次调度 200ms(应 Reset timer,从此刻起再等 200ms)
+ time.Sleep(100 * time.Millisecond)
+ pool.scheduleDelayedPing(ctx, 200*time.Millisecond)
+
+ // 150ms 后(距第一次 250ms,距 Reset 150ms)应未执行
+ time.Sleep(150 * time.Millisecond)
+ ctx.mu.Lock()
+ broken := ctx.broken
+ ctx.mu.Unlock()
+ require.False(t, broken, "Reset 后 150ms 不应触发 Ping")
+
+ // 再等 100ms(距 Reset 250ms)应已执行
+ time.Sleep(100 * time.Millisecond)
+ ctx.mu.Lock()
+ broken = ctx.broken
+ ctx.mu.Unlock()
+ require.True(t, broken, "Reset 后 250ms 应已触发 Ping")
+}
diff --git a/backend/internal/service/openai_ws_ingress_normalizer.go b/backend/internal/service/openai_ws_ingress_normalizer.go
new file mode 100644
index 000000000..1815023cd
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_normalizer.go
@@ -0,0 +1,39 @@
+package service
+
+type openAIWSIngressPreSendNormalizeInput struct {
+ accountID int64
+ turn int
+ connID string
+
+ currentPayload []byte
+ currentPayloadBytes int
+ currentPreviousResponseID string
+ expectedPreviousResponse string
+ pendingExpectedCallIDs []string
+}
+
+type openAIWSIngressPreSendNormalizeOutput struct {
+ currentPayload []byte
+ currentPayloadBytes int
+ currentPreviousResponseID string
+ expectedPreviousResponseID string
+ pendingExpectedCallIDs []string
+ functionCallOutputCallIDs []string
+ hasFunctionCallOutputCallID bool
+}
+
+// normalizeOpenAIWSIngressPayloadBeforeSend 纯透传 + callID 提取。
+// proxy 只负责转发、认证替换、计费,所有边缘场景由 recoverIngressPrevResponseNotFound 兜底。
+func normalizeOpenAIWSIngressPayloadBeforeSend(input openAIWSIngressPreSendNormalizeInput) openAIWSIngressPreSendNormalizeOutput {
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(input.currentPayload)
+
+ return openAIWSIngressPreSendNormalizeOutput{
+ currentPayload: input.currentPayload,
+ currentPayloadBytes: input.currentPayloadBytes,
+ currentPreviousResponseID: input.currentPreviousResponseID,
+ expectedPreviousResponseID: input.expectedPreviousResponse,
+ pendingExpectedCallIDs: input.pendingExpectedCallIDs,
+ functionCallOutputCallIDs: callIDs,
+ hasFunctionCallOutputCallID: len(callIDs) > 0,
+ }
+}
diff --git a/backend/internal/service/openai_ws_ingress_normalizer_test.go b/backend/internal/service/openai_ws_ingress_normalizer_test.go
new file mode 100644
index 000000000..ff68f4b5b
--- /dev/null
+++ b/backend/internal/service/openai_ws_ingress_normalizer_test.go
@@ -0,0 +1,193 @@
+package service
+
+import (
+ "testing"
+
+ "github.com/stretchr/testify/require"
+)
+
+// === 纯透传 normalizer 行为测试 ===
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_BasicPassthrough(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_prev",
+ "input":[{"type":"input_text","text":"hello"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 1,
+ turn: 2,
+ connID: "conn_1",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_prev",
+ expectedPreviousResponse: "resp_expected",
+ pendingExpectedCallIDs: []string{"call_1"},
+ })
+
+ require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传")
+ require.Equal(t, len(payload), out.currentPayloadBytes)
+ require.Equal(t, "resp_prev", out.currentPreviousResponseID, "currentPreviousResponseID 应原样透传")
+ require.Equal(t, "resp_expected", out.expectedPreviousResponseID, "expectedPreviousResponseID 应原样透传")
+ require.Equal(t, []string{"call_1"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样透传")
+ require.False(t, out.hasFunctionCallOutputCallID, "无 FCO 时应为 false")
+ require.Empty(t, out.functionCallOutputCallIDs)
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ExtractsCallID(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_prev",
+ "input":[{"type":"function_call_output","call_id":"call_abc","output":"{}"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 2,
+ turn: 3,
+ connID: "conn_2",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_prev",
+ expectedPreviousResponse: "resp_prev",
+ })
+
+ require.True(t, out.hasFunctionCallOutputCallID, "有 FCO 时应为 true")
+ require.Equal(t, []string{"call_abc"}, out.functionCallOutputCallIDs, "应正确提取 call_id")
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_NoFCO(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"input_text","text":"hello"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 3,
+ turn: 1,
+ connID: "conn_3",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "",
+ })
+
+ require.False(t, out.hasFunctionCallOutputCallID)
+ require.Empty(t, out.functionCallOutputCallIDs)
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_MultipleFCO(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_ok",
+ "input":[
+ {"type":"function_call_output","call_id":"call_a","output":"{}"},
+ {"type":"function_call_output","call_id":"call_b","output":"{}"},
+ {"type":"function_call_output","call_id":"call_c","output":"{}"}
+ ]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 4,
+ turn: 2,
+ connID: "conn_4",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_ok",
+ expectedPreviousResponse: "resp_ok",
+ })
+
+ require.True(t, out.hasFunctionCallOutputCallID)
+ require.ElementsMatch(t, []string{"call_a", "call_b", "call_c"}, out.functionCallOutputCallIDs, "应提取所有 call_id")
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_EmptyInput(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 5,
+ turn: 1,
+ connID: "conn_5",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "",
+ })
+
+ require.JSONEq(t, string(payload), string(out.currentPayload), "空 input 不应 panic")
+ require.False(t, out.hasFunctionCallOutputCallID)
+ require.Empty(t, out.functionCallOutputCallIDs)
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_ESCInterruptPassthrough(t *testing.T) {
+ t.Parallel()
+
+ // 场景:ESC 中断后客户端有意不传 previous_response_id,有 pendingCallIDs。
+ // 透传不补 prev、不注入 aborted output。
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"input_text","text":"new task after ESC"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 6,
+ turn: 5,
+ connID: "conn_esc",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "",
+ expectedPreviousResponse: "resp_prev_turn4",
+ pendingExpectedCallIDs: []string{"call_pending_1", "call_pending_2"},
+ })
+
+ require.Empty(t, out.currentPreviousResponseID, "透传不应补 previous_response_id")
+ require.Equal(t, "resp_prev_turn4", out.expectedPreviousResponseID)
+ require.False(t, out.hasFunctionCallOutputCallID, "透传不应注入 function_call_output")
+ require.Empty(t, out.functionCallOutputCallIDs)
+ require.Equal(t, []string{"call_pending_1", "call_pending_2"}, out.pendingExpectedCallIDs, "pendingExpectedCallIDs 应原样传递")
+ require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传")
+}
+
+func TestNormalizeOpenAIWSIngressPayloadBeforeSend_StalePrevPassthrough(t *testing.T) {
+ t.Parallel()
+
+ // 场景:客户端传了过期 previous_response_id,透传不对齐。
+ // 由下游 recoverIngressPrevResponseNotFound 处理。
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"{}"}]
+ }`)
+
+ out := normalizeOpenAIWSIngressPayloadBeforeSend(openAIWSIngressPreSendNormalizeInput{
+ accountID: 7,
+ turn: 4,
+ connID: "conn_stale",
+ currentPayload: payload,
+ currentPayloadBytes: len(payload),
+ currentPreviousResponseID: "resp_stale",
+ expectedPreviousResponse: "resp_latest",
+ })
+
+ require.Equal(t, "resp_stale", out.currentPreviousResponseID, "透传不应对齐 previous_response_id")
+ require.Equal(t, "resp_latest", out.expectedPreviousResponseID)
+ require.JSONEq(t, string(payload), string(out.currentPayload), "payload 应原样透传")
+ require.True(t, out.hasFunctionCallOutputCallID)
+ require.Equal(t, []string{"call_1"}, out.functionCallOutputCallIDs)
+}
diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go
deleted file mode 100644
index db6a96a7d..000000000
--- a/backend/internal/service/openai_ws_pool.go
+++ /dev/null
@@ -1,1706 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "fmt"
- "math"
- "net/http"
- "sort"
- "strconv"
- "strings"
- "sync"
- "sync/atomic"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "golang.org/x/sync/errgroup"
-)
-
-const (
- openAIWSConnMaxAge = 60 * time.Minute
- openAIWSConnHealthCheckIdle = 90 * time.Second
- openAIWSConnHealthCheckTO = 2 * time.Second
- openAIWSConnPrewarmExtraDelay = 2 * time.Second
- openAIWSAcquireCleanupInterval = 3 * time.Second
- openAIWSBackgroundPingInterval = 30 * time.Second
- openAIWSBackgroundSweepTicker = 30 * time.Second
-
- openAIWSPrewarmFailureWindow = 30 * time.Second
- openAIWSPrewarmFailureSuppress = 2
-)
-
-var (
- errOpenAIWSConnClosed = errors.New("openai ws connection closed")
- errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full")
- errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable")
-)
-
-type openAIWSDialError struct {
- StatusCode int
- ResponseHeaders http.Header
- Err error
-}
-
-func (e *openAIWSDialError) Error() string {
- if e == nil {
- return ""
- }
- if e.StatusCode > 0 {
- return fmt.Sprintf("openai ws dial failed: status=%d err=%v", e.StatusCode, e.Err)
- }
- return fmt.Sprintf("openai ws dial failed: %v", e.Err)
-}
-
-func (e *openAIWSDialError) Unwrap() error {
- if e == nil {
- return nil
- }
- return e.Err
-}
-
-type openAIWSAcquireRequest struct {
- Account *Account
- WSURL string
- Headers http.Header
- ProxyURL string
- PreferredConnID string
- // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。
- ForceNewConn bool
- // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。
- ForcePreferredConn bool
-}
-
-type openAIWSConnLease struct {
- pool *openAIWSConnPool
- accountID int64
- conn *openAIWSConn
- queueWait time.Duration
- connPick time.Duration
- reused bool
- released atomic.Bool
-}
-
-func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) {
- if l == nil || l.conn == nil {
- return nil, errOpenAIWSConnClosed
- }
- if l.released.Load() {
- return nil, errOpenAIWSConnClosed
- }
- return l.conn, nil
-}
-
-func (l *openAIWSConnLease) ConnID() string {
- if l == nil || l.conn == nil {
- return ""
- }
- return l.conn.id
-}
-
-func (l *openAIWSConnLease) QueueWaitDuration() time.Duration {
- if l == nil {
- return 0
- }
- return l.queueWait
-}
-
-func (l *openAIWSConnLease) ConnPickDuration() time.Duration {
- if l == nil {
- return 0
- }
- return l.connPick
-}
-
-func (l *openAIWSConnLease) Reused() bool {
- if l == nil {
- return false
- }
- return l.reused
-}
-
-func (l *openAIWSConnLease) HandshakeHeader(name string) string {
- if l == nil || l.conn == nil {
- return ""
- }
- return l.conn.handshakeHeader(name)
-}
-
-func (l *openAIWSConnLease) IsPrewarmed() bool {
- if l == nil || l.conn == nil {
- return false
- }
- return l.conn.isPrewarmed()
-}
-
-func (l *openAIWSConnLease) MarkPrewarmed() {
- if l == nil || l.conn == nil {
- return
- }
- l.conn.markPrewarmed()
-}
-
-func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.writeJSONWithTimeout(context.Background(), value, timeout)
-}
-
-func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.writeJSONWithTimeout(ctx, value, timeout)
-}
-
-func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.writeJSON(value, ctx)
-}
-
-func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) {
- conn, err := l.activeConn()
- if err != nil {
- return nil, err
- }
- return conn.readMessageWithTimeout(timeout)
-}
-
-func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) {
- conn, err := l.activeConn()
- if err != nil {
- return nil, err
- }
- return conn.readMessage(ctx)
-}
-
-func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
- conn, err := l.activeConn()
- if err != nil {
- return nil, err
- }
- return conn.readMessageWithContextTimeout(ctx, timeout)
-}
-
-func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error {
- conn, err := l.activeConn()
- if err != nil {
- return err
- }
- return conn.pingWithTimeout(timeout)
-}
-
-func (l *openAIWSConnLease) MarkBroken() {
- if l == nil || l.pool == nil || l.conn == nil || l.released.Load() {
- return
- }
- l.pool.evictConn(l.accountID, l.conn.id)
-}
-
-func (l *openAIWSConnLease) Release() {
- if l == nil || l.conn == nil {
- return
- }
- if !l.released.CompareAndSwap(false, true) {
- return
- }
- l.conn.release()
-}
-
-type openAIWSConn struct {
- id string
- ws openAIWSClientConn
-
- handshakeHeaders http.Header
-
- leaseCh chan struct{}
- closedCh chan struct{}
- closeOnce sync.Once
-
- readMu sync.Mutex
- writeMu sync.Mutex
-
- waiters atomic.Int32
- createdAtNano atomic.Int64
- lastUsedNano atomic.Int64
- prewarmed atomic.Bool
-}
-
-func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn {
- now := time.Now()
- conn := &openAIWSConn{
- id: id,
- ws: ws,
- handshakeHeaders: cloneHeader(handshakeHeaders),
- leaseCh: make(chan struct{}, 1),
- closedCh: make(chan struct{}),
- }
- conn.leaseCh <- struct{}{}
- conn.createdAtNano.Store(now.UnixNano())
- conn.lastUsedNano.Store(now.UnixNano())
- return conn
-}
-
-func (c *openAIWSConn) tryAcquire() bool {
- if c == nil {
- return false
- }
- select {
- case <-c.closedCh:
- return false
- default:
- }
- select {
- case <-c.leaseCh:
- select {
- case <-c.closedCh:
- c.release()
- return false
- default:
- }
- return true
- default:
- return false
- }
-}
-
-func (c *openAIWSConn) acquire(ctx context.Context) error {
- if c == nil {
- return errOpenAIWSConnClosed
- }
- for {
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-c.closedCh:
- return errOpenAIWSConnClosed
- case <-c.leaseCh:
- select {
- case <-c.closedCh:
- c.release()
- return errOpenAIWSConnClosed
- default:
- }
- return nil
- }
- }
-}
-
-func (c *openAIWSConn) release() {
- if c == nil {
- return
- }
- select {
- case c.leaseCh <- struct{}{}:
- default:
- }
- c.touch()
-}
-
-func (c *openAIWSConn) close() {
- if c == nil {
- return
- }
- c.closeOnce.Do(func() {
- close(c.closedCh)
- if c.ws != nil {
- _ = c.ws.Close()
- }
- select {
- case c.leaseCh <- struct{}{}:
- default:
- }
- })
-}
-
-func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error {
- if c == nil {
- return errOpenAIWSConnClosed
- }
- select {
- case <-c.closedCh:
- return errOpenAIWSConnClosed
- default:
- }
-
- writeCtx := parent
- if writeCtx == nil {
- writeCtx = context.Background()
- }
- if timeout <= 0 {
- return c.writeJSON(value, writeCtx)
- }
- var cancel context.CancelFunc
- writeCtx, cancel = context.WithTimeout(writeCtx, timeout)
- defer cancel()
- return c.writeJSON(value, writeCtx)
-}
-
-func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error {
- c.writeMu.Lock()
- defer c.writeMu.Unlock()
- if c.ws == nil {
- return errOpenAIWSConnClosed
- }
- if writeCtx == nil {
- writeCtx = context.Background()
- }
- if err := c.ws.WriteJSON(writeCtx, value); err != nil {
- return err
- }
- c.touch()
- return nil
-}
-
-func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) {
- return c.readMessageWithContextTimeout(context.Background(), timeout)
-}
-
-func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) {
- if c == nil {
- return nil, errOpenAIWSConnClosed
- }
- select {
- case <-c.closedCh:
- return nil, errOpenAIWSConnClosed
- default:
- }
-
- if parent == nil {
- parent = context.Background()
- }
- if timeout <= 0 {
- return c.readMessage(parent)
- }
- readCtx, cancel := context.WithTimeout(parent, timeout)
- defer cancel()
- return c.readMessage(readCtx)
-}
-
-func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) {
- c.readMu.Lock()
- defer c.readMu.Unlock()
- if c.ws == nil {
- return nil, errOpenAIWSConnClosed
- }
- if readCtx == nil {
- readCtx = context.Background()
- }
- payload, err := c.ws.ReadMessage(readCtx)
- if err != nil {
- return nil, err
- }
- c.touch()
- return payload, nil
-}
-
-func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error {
- if c == nil {
- return errOpenAIWSConnClosed
- }
- select {
- case <-c.closedCh:
- return errOpenAIWSConnClosed
- default:
- }
-
- c.writeMu.Lock()
- defer c.writeMu.Unlock()
- if c.ws == nil {
- return errOpenAIWSConnClosed
- }
- if timeout <= 0 {
- timeout = openAIWSConnHealthCheckTO
- }
- pingCtx, cancel := context.WithTimeout(context.Background(), timeout)
- defer cancel()
- if err := c.ws.Ping(pingCtx); err != nil {
- return err
- }
- return nil
-}
-
-func (c *openAIWSConn) touch() {
- if c == nil {
- return
- }
- c.lastUsedNano.Store(time.Now().UnixNano())
-}
-
-func (c *openAIWSConn) createdAt() time.Time {
- if c == nil {
- return time.Time{}
- }
- nano := c.createdAtNano.Load()
- if nano <= 0 {
- return time.Time{}
- }
- return time.Unix(0, nano)
-}
-
-func (c *openAIWSConn) lastUsedAt() time.Time {
- if c == nil {
- return time.Time{}
- }
- nano := c.lastUsedNano.Load()
- if nano <= 0 {
- return time.Time{}
- }
- return time.Unix(0, nano)
-}
-
-func (c *openAIWSConn) idleDuration(now time.Time) time.Duration {
- if c == nil {
- return 0
- }
- last := c.lastUsedAt()
- if last.IsZero() {
- return 0
- }
- return now.Sub(last)
-}
-
-func (c *openAIWSConn) age(now time.Time) time.Duration {
- if c == nil {
- return 0
- }
- created := c.createdAt()
- if created.IsZero() {
- return 0
- }
- return now.Sub(created)
-}
-
-func (c *openAIWSConn) isLeased() bool {
- if c == nil {
- return false
- }
- return len(c.leaseCh) == 0
-}
-
-func (c *openAIWSConn) handshakeHeader(name string) string {
- if c == nil || c.handshakeHeaders == nil {
- return ""
- }
- return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name)))
-}
-
-func (c *openAIWSConn) isPrewarmed() bool {
- if c == nil {
- return false
- }
- return c.prewarmed.Load()
-}
-
-func (c *openAIWSConn) markPrewarmed() {
- if c == nil {
- return
- }
- c.prewarmed.Store(true)
-}
-
-type openAIWSAccountPool struct {
- mu sync.Mutex
- conns map[string]*openAIWSConn
- pinnedConns map[string]int
- creating int
- lastCleanupAt time.Time
- lastAcquire *openAIWSAcquireRequest
- prewarmActive bool
- prewarmUntil time.Time
- prewarmFails int
- prewarmFailAt time.Time
-}
-
-type OpenAIWSPoolMetricsSnapshot struct {
- AcquireTotal int64
- AcquireReuseTotal int64
- AcquireCreateTotal int64
- AcquireQueueWaitTotal int64
- AcquireQueueWaitMsTotal int64
- ConnPickTotal int64
- ConnPickMsTotal int64
- ScaleUpTotal int64
- ScaleDownTotal int64
-}
-
-type openAIWSPoolMetrics struct {
- acquireTotal atomic.Int64
- acquireReuseTotal atomic.Int64
- acquireCreateTotal atomic.Int64
- acquireQueueWaitTotal atomic.Int64
- acquireQueueWaitMs atomic.Int64
- connPickTotal atomic.Int64
- connPickMs atomic.Int64
- scaleUpTotal atomic.Int64
- scaleDownTotal atomic.Int64
-}
-
-type openAIWSConnPool struct {
- cfg *config.Config
- // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。
- clientDialer openAIWSClientDialer
-
- accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool
- seq atomic.Uint64
-
- metrics openAIWSPoolMetrics
-
- workerStopCh chan struct{}
- workerWg sync.WaitGroup
- closeOnce sync.Once
-}
-
-func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool {
- pool := &openAIWSConnPool{
- cfg: cfg,
- clientDialer: newDefaultOpenAIWSClientDialer(),
- workerStopCh: make(chan struct{}),
- }
- pool.startBackgroundWorkers()
- return pool
-}
-
-func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot {
- if p == nil {
- return OpenAIWSPoolMetricsSnapshot{}
- }
- return OpenAIWSPoolMetricsSnapshot{
- AcquireTotal: p.metrics.acquireTotal.Load(),
- AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(),
- AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(),
- AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(),
- AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(),
- ConnPickTotal: p.metrics.connPickTotal.Load(),
- ConnPickMsTotal: p.metrics.connPickMs.Load(),
- ScaleUpTotal: p.metrics.scaleUpTotal.Load(),
- ScaleDownTotal: p.metrics.scaleDownTotal.Load(),
- }
-}
-
-func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot {
- if p == nil {
- return OpenAIWSTransportMetricsSnapshot{}
- }
- if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok {
- return dialer.SnapshotTransportMetrics()
- }
- return OpenAIWSTransportMetricsSnapshot{}
-}
-
-func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) {
- if p == nil || dialer == nil {
- return
- }
- p.clientDialer = dialer
-}
-
-// Close 停止后台 worker 并关闭所有空闲连接,应在优雅关闭时调用。
-func (p *openAIWSConnPool) Close() {
- if p == nil {
- return
- }
- p.closeOnce.Do(func() {
- if p.workerStopCh != nil {
- close(p.workerStopCh)
- }
- p.workerWg.Wait()
- // 遍历所有账户池,关闭全部空闲连接。
- p.accounts.Range(func(key, value any) bool {
- ap, ok := value.(*openAIWSAccountPool)
- if !ok || ap == nil {
- return true
- }
- ap.mu.Lock()
- for _, conn := range ap.conns {
- if conn != nil && !conn.isLeased() {
- conn.close()
- }
- }
- ap.mu.Unlock()
- return true
- })
- })
-}
-
-func (p *openAIWSConnPool) startBackgroundWorkers() {
- if p == nil || p.workerStopCh == nil {
- return
- }
- p.workerWg.Add(2)
- go func() {
- defer p.workerWg.Done()
- p.runBackgroundPingWorker()
- }()
- go func() {
- defer p.workerWg.Done()
- p.runBackgroundCleanupWorker()
- }()
-}
-
-type openAIWSIdlePingCandidate struct {
- accountID int64
- conn *openAIWSConn
-}
-
-func (p *openAIWSConnPool) runBackgroundPingWorker() {
- if p == nil {
- return
- }
- ticker := time.NewTicker(openAIWSBackgroundPingInterval)
- defer ticker.Stop()
- for {
- select {
- case <-ticker.C:
- p.runBackgroundPingSweep()
- case <-p.workerStopCh:
- return
- }
- }
-}
-
-func (p *openAIWSConnPool) runBackgroundPingSweep() {
- if p == nil {
- return
- }
- candidates := p.snapshotIdleConnsForPing()
- var g errgroup.Group
- g.SetLimit(10)
- for _, item := range candidates {
- item := item
- if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 {
- continue
- }
- g.Go(func() error {
- if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- p.evictConn(item.accountID, item.conn.id)
- }
- return nil
- })
- }
- _ = g.Wait()
-}
-
-func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate {
- if p == nil {
- return nil
- }
- candidates := make([]openAIWSIdlePingCandidate, 0)
- p.accounts.Range(func(key, value any) bool {
- accountID, ok := key.(int64)
- if !ok || accountID <= 0 {
- return true
- }
- ap, ok := value.(*openAIWSAccountPool)
- if !ok || ap == nil {
- return true
- }
- ap.mu.Lock()
- for _, conn := range ap.conns {
- if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 {
- continue
- }
- candidates = append(candidates, openAIWSIdlePingCandidate{
- accountID: accountID,
- conn: conn,
- })
- }
- ap.mu.Unlock()
- return true
- })
- return candidates
-}
-
-func (p *openAIWSConnPool) runBackgroundCleanupWorker() {
- if p == nil {
- return
- }
- ticker := time.NewTicker(openAIWSBackgroundSweepTicker)
- defer ticker.Stop()
- for {
- select {
- case <-ticker.C:
- p.runBackgroundCleanupSweep(time.Now())
- case <-p.workerStopCh:
- return
- }
- }
-}
-
-func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) {
- if p == nil {
- return
- }
- type cleanupResult struct {
- evicted []*openAIWSConn
- }
- results := make([]cleanupResult, 0)
- p.accounts.Range(func(_ any, value any) bool {
- ap, ok := value.(*openAIWSAccountPool)
- if !ok || ap == nil {
- return true
- }
- maxConns := p.maxConnsHardCap()
- ap.mu.Lock()
- if ap.lastAcquire != nil && ap.lastAcquire.Account != nil {
- maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account)
- }
- evicted := p.cleanupAccountLocked(ap, now, maxConns)
- ap.lastCleanupAt = now
- ap.mu.Unlock()
- if len(evicted) > 0 {
- results = append(results, cleanupResult{evicted: evicted})
- }
- return true
- })
- for _, result := range results {
- closeOpenAIWSConns(result.evicted)
- }
-}
-
-func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) {
- if p != nil {
- p.metrics.acquireTotal.Add(1)
- }
- return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0)
-}
-
-func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) {
- if p == nil || req.Account == nil || req.Account.ID <= 0 {
- return nil, errors.New("invalid ws acquire request")
- }
- if stringsTrim(req.WSURL) == "" {
- return nil, errors.New("ws url is empty")
- }
-
- accountID := req.Account.ID
- effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account)
- if effectiveMaxConns <= 0 {
- return nil, errOpenAIWSConnQueueFull
- }
- var evicted []*openAIWSConn
- ap := p.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req)
- now := time.Now()
- if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
- evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
- ap.lastCleanupAt = now
- }
- pickStartedAt := time.Now()
- allowReuse := !req.ForceNewConn
- preferredConnID := stringsTrim(req.PreferredConnID)
- forcePreferredConn := allowReuse && req.ForcePreferredConn
-
- if allowReuse {
- if forcePreferredConn {
- if preferredConnID == "" {
- p.recordConnPickDuration(time.Since(pickStartedAt))
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSPreferredConnUnavailable
- }
- preferredConn, ok := ap.conns[preferredConnID]
- if !ok || preferredConn == nil {
- p.recordConnPickDuration(time.Since(pickStartedAt))
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSPreferredConnUnavailable
- }
- if preferredConn.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(preferredConn) {
- if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- preferredConn.close()
- p.evictConn(accountID, preferredConn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{
- pool: p,
- accountID: accountID,
- conn: preferredConn,
- connPick: connPick,
- reused: true,
- }
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
-
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- if int(preferredConn.waiters.Load()) >= p.queueLimitPerConn() {
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnQueueFull
- }
- preferredConn.waiters.Add(1)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- defer preferredConn.waiters.Add(-1)
- waitStart := time.Now()
- p.metrics.acquireQueueWaitTotal.Add(1)
-
- if err := preferredConn.acquire(ctx); err != nil {
- if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- if p.shouldHealthCheckConn(preferredConn) {
- if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- preferredConn.release()
- preferredConn.close()
- p.evictConn(accountID, preferredConn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
-
- queueWait := time.Since(waitStart)
- p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
- lease := &openAIWSConnLease{
- pool: p,
- accountID: accountID,
- conn: preferredConn,
- queueWait: queueWait,
- connPick: connPick,
- reused: true,
- }
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
-
- if preferredConnID != "" {
- if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(conn) {
- if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- conn.close()
- p.evictConn(accountID, conn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
- }
-
- best := p.pickLeastBusyConnLocked(ap, "")
- if best != nil && best.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(best) {
- if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- best.close()
- p.evictConn(accountID, best.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
- for _, conn := range ap.conns {
- if conn == nil || conn == best {
- continue
- }
- if conn.tryAcquire() {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- if p.shouldHealthCheckConn(conn) {
- if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- conn.close()
- p.evictConn(accountID, conn.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
- }
- }
-
- if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns {
- if idle := p.pickOldestIdleConnLocked(ap); idle != nil {
- delete(ap.conns, idle.id)
- evicted = append(evicted, idle)
- p.metrics.scaleDownTotal.Add(1)
- }
- }
-
- if len(ap.conns)+ap.creating < effectiveMaxConns {
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- ap.creating++
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
-
- conn, dialErr := p.dialConn(ctx, req)
-
- ap = p.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.creating--
- if dialErr != nil {
- ap.prewarmFails++
- ap.prewarmFailAt = time.Now()
- ap.mu.Unlock()
- return nil, dialErr
- }
- ap.conns[conn.id] = conn
- ap.prewarmFails = 0
- ap.prewarmFailAt = time.Time{}
- ap.mu.Unlock()
- p.metrics.acquireCreateTotal.Add(1)
-
- if !conn.tryAcquire() {
- if err := conn.acquire(ctx); err != nil {
- conn.close()
- p.evictConn(accountID, conn.id)
- return nil, err
- }
- }
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick}
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
- }
-
- if req.ForceNewConn {
- p.recordConnPickDuration(time.Since(pickStartedAt))
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnQueueFull
- }
-
- target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID)
- connPick := time.Since(pickStartedAt)
- p.recordConnPickDuration(connPick)
- if target == nil {
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnClosed
- }
- if int(target.waiters.Load()) >= p.queueLimitPerConn() {
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- return nil, errOpenAIWSConnQueueFull
- }
- target.waiters.Add(1)
- ap.mu.Unlock()
- closeOpenAIWSConns(evicted)
- defer target.waiters.Add(-1)
- waitStart := time.Now()
- p.metrics.acquireQueueWaitTotal.Add(1)
-
- if err := target.acquire(ctx); err != nil {
- if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- if p.shouldHealthCheckConn(target) {
- if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil {
- target.release()
- target.close()
- p.evictConn(accountID, target.id)
- if retry < 1 {
- return p.acquire(ctx, req, retry+1)
- }
- return nil, err
- }
- }
-
- queueWait := time.Since(waitStart)
- p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds())
- lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true}
- p.metrics.acquireReuseTotal.Add(1)
- p.ensureTargetIdleAsync(accountID)
- return lease, nil
-}
-
-func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) {
- if p == nil {
- return
- }
- if duration < 0 {
- duration = 0
- }
- p.metrics.connPickTotal.Add(1)
- p.metrics.connPickMs.Add(duration.Milliseconds())
-}
-
-func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn {
- if ap == nil || len(ap.conns) == 0 {
- return nil
- }
- var oldest *openAIWSConn
- for _, conn := range ap.conns {
- if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
- continue
- }
- if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) {
- oldest = conn
- }
- }
- return oldest
-}
-
-func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool {
- if p == nil || accountID <= 0 {
- return nil
- }
- if existing, ok := p.accounts.Load(accountID); ok {
- if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil {
- return ap
- }
- }
- ap := &openAIWSAccountPool{
- conns: make(map[string]*openAIWSConn),
- pinnedConns: make(map[string]int),
- }
- actual, _ := p.accounts.LoadOrStore(accountID, ap)
- if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil {
- return typed
- }
- return ap
-}
-
-// ensureAccountPoolLocked 兼容旧调用。
-func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool {
- return p.getOrCreateAccountPool(accountID)
-}
-
-func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) {
- if p == nil || accountID <= 0 {
- return nil, false
- }
- value, ok := p.accounts.Load(accountID)
- if !ok || value == nil {
- return nil, false
- }
- ap, typed := value.(*openAIWSAccountPool)
- return ap, typed && ap != nil
-}
-
-func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool {
- if ap == nil || connID == "" || len(ap.pinnedConns) == 0 {
- return false
- }
- return ap.pinnedConns[connID] > 0
-}
-
-func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn {
- if ap == nil {
- return nil
- }
- maxAge := p.maxConnAge()
-
- evicted := make([]*openAIWSConn, 0)
- for id, conn := range ap.conns {
- if conn == nil {
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- continue
- }
- select {
- case <-conn.closedCh:
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- evicted = append(evicted, conn)
- continue
- default:
- }
- if p.isConnPinnedLocked(ap, id) {
- continue
- }
- if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge {
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- evicted = append(evicted, conn)
- }
- }
-
- if maxConns <= 0 {
- maxConns = p.maxConnsHardCap()
- }
- maxIdle := p.maxIdlePerAccount()
- if maxIdle < 0 || maxIdle > maxConns {
- maxIdle = maxConns
- }
- if maxIdle >= 0 && len(ap.conns) > maxIdle {
- idleConns := make([]*openAIWSConn, 0, len(ap.conns))
- for id, conn := range ap.conns {
- if conn == nil {
- delete(ap.conns, id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, id)
- }
- continue
- }
- // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。
- if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) {
- continue
- }
- idleConns = append(idleConns, conn)
- }
- sort.SliceStable(idleConns, func(i, j int) bool {
- return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt())
- })
- redundant := len(ap.conns) - maxIdle
- if redundant > len(idleConns) {
- redundant = len(idleConns)
- }
- for i := 0; i < redundant; i++ {
- conn := idleConns[i]
- delete(ap.conns, conn.id)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, conn.id)
- }
- evicted = append(evicted, conn)
- }
- if redundant > 0 {
- p.metrics.scaleDownTotal.Add(int64(redundant))
- }
- }
-
- return evicted
-}
-
-func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn {
- if ap == nil || len(ap.conns) == 0 {
- return nil
- }
- preferredConnID = stringsTrim(preferredConnID)
- if preferredConnID != "" {
- if conn, ok := ap.conns[preferredConnID]; ok {
- return conn
- }
- }
- var best *openAIWSConn
- var bestWaiters int32
- var bestLastUsed time.Time
- for _, conn := range ap.conns {
- if conn == nil {
- continue
- }
- waiters := conn.waiters.Load()
- lastUsed := conn.lastUsedAt()
- if best == nil ||
- waiters < bestWaiters ||
- (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) {
- best = conn
- bestWaiters = waiters
- bestLastUsed = lastUsed
- }
- }
- return best
-}
-
-func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) {
- if ap == nil {
- return 0, 0
- }
- for _, conn := range ap.conns {
- if conn == nil {
- continue
- }
- if conn.isLeased() {
- inflight++
- }
- waiters += int(conn.waiters.Load())
- }
- return inflight, waiters
-}
-
-// AccountPoolLoad 返回指定账号连接池的并发与排队快照。
-func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) {
- if p == nil || accountID <= 0 {
- return 0, 0, 0
- }
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return 0, 0, 0
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- inflight, waiters = accountPoolLoadLocked(ap)
- return inflight, waiters, len(ap.conns)
-}
-
-func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) {
- if p == nil || accountID <= 0 {
- return
- }
-
- var req openAIWSAcquireRequest
- need := 0
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- if ap.lastAcquire == nil {
- return
- }
- if ap.prewarmActive {
- return
- }
- now := time.Now()
- if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) {
- return
- }
- if p.shouldSuppressPrewarmLocked(ap, now) {
- return
- }
- effectiveMaxConns := p.maxConnsHardCap()
- if ap.lastAcquire != nil && ap.lastAcquire.Account != nil {
- effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account)
- }
- target := p.targetConnCountLocked(ap, effectiveMaxConns)
- current := len(ap.conns) + ap.creating
- if current >= target {
- return
- }
- need = target - current
- if need <= 0 {
- return
- }
- req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire)
- ap.prewarmActive = true
- if cooldown := p.prewarmCooldown(); cooldown > 0 {
- ap.prewarmUntil = now.Add(cooldown)
- }
- ap.creating += need
- p.metrics.scaleUpTotal.Add(int64(need))
-
- go p.prewarmConns(accountID, req, need)
-}
-
-func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int {
- if ap == nil {
- return 0
- }
-
- if maxConns <= 0 {
- return 0
- }
-
- minIdle := p.minIdlePerAccount()
- if minIdle < 0 {
- minIdle = 0
- }
- if minIdle > maxConns {
- minIdle = maxConns
- }
-
- inflight, waiters := accountPoolLoadLocked(ap)
- utilization := p.targetUtilization()
- demand := inflight + waiters
- if demand <= 0 {
- return minIdle
- }
-
- target := 1
- if demand > 1 {
- target = int(math.Ceil(float64(demand) / utilization))
- }
- if waiters > 0 && target < len(ap.conns)+1 {
- target = len(ap.conns) + 1
- }
- if target < minIdle {
- target = minIdle
- }
- if target > maxConns {
- target = maxConns
- }
- return target
-}
-
-func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) {
- defer func() {
- if ap, ok := p.getAccountPool(accountID); ok && ap != nil {
- ap.mu.Lock()
- ap.prewarmActive = false
- ap.mu.Unlock()
- }
- }()
-
- for i := 0; i < total; i++ {
- ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay)
- conn, err := p.dialConn(ctx, req)
- cancel()
-
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- if conn != nil {
- conn.close()
- }
- return
- }
- ap.mu.Lock()
- if ap.creating > 0 {
- ap.creating--
- }
- if err != nil {
- ap.prewarmFails++
- ap.prewarmFailAt = time.Now()
- ap.mu.Unlock()
- continue
- }
- if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) {
- ap.mu.Unlock()
- conn.close()
- continue
- }
- ap.conns[conn.id] = conn
- ap.prewarmFails = 0
- ap.prewarmFailAt = time.Time{}
- ap.mu.Unlock()
- }
-}
-
-func (p *openAIWSConnPool) evictConn(accountID int64, connID string) {
- if p == nil || accountID <= 0 || stringsTrim(connID) == "" {
- return
- }
- var conn *openAIWSConn
- ap, ok := p.getAccountPool(accountID)
- if ok && ap != nil {
- ap.mu.Lock()
- if c, exists := ap.conns[connID]; exists {
- conn = c
- delete(ap.conns, connID)
- if len(ap.pinnedConns) > 0 {
- delete(ap.pinnedConns, connID)
- }
- }
- ap.mu.Unlock()
- }
- if conn != nil {
- conn.close()
- }
-}
-
-func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool {
- if p == nil || accountID <= 0 {
- return false
- }
- connID = stringsTrim(connID)
- if connID == "" {
- return false
- }
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- if _, exists := ap.conns[connID]; !exists {
- return false
- }
- if ap.pinnedConns == nil {
- ap.pinnedConns = make(map[string]int)
- }
- ap.pinnedConns[connID]++
- return true
-}
-
-func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) {
- if p == nil || accountID <= 0 {
- return
- }
- connID = stringsTrim(connID)
- if connID == "" {
- return
- }
- ap, ok := p.getAccountPool(accountID)
- if !ok || ap == nil {
- return
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- if len(ap.pinnedConns) == 0 {
- return
- }
- count := ap.pinnedConns[connID]
- if count <= 1 {
- delete(ap.pinnedConns, connID)
- return
- }
- ap.pinnedConns[connID] = count - 1
-}
-
-func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) {
- if p == nil || p.clientDialer == nil {
- return nil, errors.New("openai ws client dialer is nil")
- }
- conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL)
- if err != nil {
- return nil, &openAIWSDialError{
- StatusCode: status,
- ResponseHeaders: cloneHeader(handshakeHeaders),
- Err: err,
- }
- }
- if conn == nil {
- return nil, &openAIWSDialError{
- StatusCode: status,
- ResponseHeaders: cloneHeader(handshakeHeaders),
- Err: errors.New("openai ws dialer returned nil connection"),
- }
- }
- id := p.nextConnID(req.Account.ID)
- return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil
-}
-
-func (p *openAIWSConnPool) nextConnID(accountID int64) string {
- seq := p.seq.Add(1)
- buf := make([]byte, 0, 32)
- buf = append(buf, "oa_ws_"...)
- buf = strconv.AppendInt(buf, accountID, 10)
- buf = append(buf, '_')
- buf = strconv.AppendUint(buf, seq, 10)
- return string(buf)
-}
-
-func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool {
- if conn == nil {
- return false
- }
- return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle
-}
-
-func (p *openAIWSConnPool) maxConnsHardCap() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 {
- return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount
- }
- return 8
-}
-
-func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool {
- if p != nil && p.cfg != nil {
- return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled
- }
- return false
-}
-
-func (p *openAIWSConnPool) modeRouterV2Enabled() bool {
- if p != nil && p.cfg != nil {
- return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled
- }
- return false
-}
-
-func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 {
- if p == nil || p.cfg == nil || account == nil {
- return 1.0
- }
- switch account.Type {
- case AccountTypeOAuth:
- if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 {
- return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor
- }
- case AccountTypeAPIKey:
- if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 {
- return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor
- }
- }
- return 1.0
-}
-
-func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int {
- hardCap := p.maxConnsHardCap()
- if hardCap <= 0 {
- return 0
- }
- if p.modeRouterV2Enabled() {
- if account == nil {
- return hardCap
- }
- if account.Concurrency <= 0 {
- return 0
- }
- return account.Concurrency
- }
- if account == nil || !p.dynamicMaxConnsEnabled() {
- return hardCap
- }
- if account.Concurrency <= 0 {
- // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。
- return hardCap
- }
- factor := p.maxConnsFactorByAccount(account)
- if factor <= 0 {
- factor = 1.0
- }
- effective := int(math.Ceil(float64(account.Concurrency) * factor))
- if effective < 1 {
- effective = 1
- }
- if effective > hardCap {
- effective = hardCap
- }
- return effective
-}
-
-func (p *openAIWSConnPool) minIdlePerAccount() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 {
- return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount
- }
- return 0
-}
-
-func (p *openAIWSConnPool) maxIdlePerAccount() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 {
- return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount
- }
- return 4
-}
-
-func (p *openAIWSConnPool) maxConnAge() time.Duration {
- return openAIWSConnMaxAge
-}
-
-func (p *openAIWSConnPool) queueLimitPerConn() int {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 {
- return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn
- }
- return 256
-}
-
-func (p *openAIWSConnPool) targetUtilization() float64 {
- if p != nil && p.cfg != nil {
- ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization
- if ratio > 0 && ratio <= 1 {
- return ratio
- }
- }
- return 0.7
-}
-
-func (p *openAIWSConnPool) prewarmCooldown() time.Duration {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 {
- return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond
- }
- return 0
-}
-
-func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool {
- if ap == nil {
- return true
- }
- if ap.prewarmFails <= 0 {
- return false
- }
- if ap.prewarmFailAt.IsZero() {
- ap.prewarmFails = 0
- return false
- }
- if now.Sub(ap.prewarmFailAt) > openAIWSPrewarmFailureWindow {
- ap.prewarmFails = 0
- ap.prewarmFailAt = time.Time{}
- return false
- }
- return ap.prewarmFails >= openAIWSPrewarmFailureSuppress
-}
-
-func (p *openAIWSConnPool) dialTimeout() time.Duration {
- if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 {
- return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second
- }
- return 10 * time.Second
-}
-
-func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest {
- copied := req
- copied.Headers = cloneHeader(req.Headers)
- copied.WSURL = stringsTrim(req.WSURL)
- copied.ProxyURL = stringsTrim(req.ProxyURL)
- copied.PreferredConnID = stringsTrim(req.PreferredConnID)
- return copied
-}
-
-func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest {
- if req == nil {
- return nil
- }
- copied := cloneOpenAIWSAcquireRequest(*req)
- return &copied
-}
-
-func cloneHeader(src http.Header) http.Header {
- if src == nil {
- return nil
- }
- dst := make(http.Header, len(src))
- for k, vals := range src {
- if len(vals) == 0 {
- dst[k] = nil
- continue
- }
- copied := make([]string, len(vals))
- copy(copied, vals)
- dst[k] = copied
- }
- return dst
-}
-
-func closeOpenAIWSConns(conns []*openAIWSConn) {
- if len(conns) == 0 {
- return
- }
- for _, conn := range conns {
- if conn == nil {
- continue
- }
- conn.close()
- }
-}
-
-func stringsTrim(value string) string {
- return strings.TrimSpace(value)
-}
diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go
deleted file mode 100644
index bff74b626..000000000
--- a/backend/internal/service/openai_ws_pool_benchmark_test.go
+++ /dev/null
@@ -1,58 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "testing"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
-)
-
-func BenchmarkOpenAIWSPoolAcquire(b *testing.B) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
-
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(&openAIWSCountingDialer{})
-
- account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- req := openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ctx := context.Background()
-
- lease, err := pool.Acquire(ctx, req)
- if err != nil {
- b.Fatalf("warm acquire failed: %v", err)
- }
- lease.Release()
-
- b.ReportAllocs()
- b.ResetTimer()
- b.RunParallel(func(pb *testing.PB) {
- for pb.Next() {
- var (
- got *openAIWSConnLease
- acquireErr error
- )
- for retry := 0; retry < 3; retry++ {
- got, acquireErr = pool.Acquire(ctx, req)
- if acquireErr == nil {
- break
- }
- if !errors.Is(acquireErr, errOpenAIWSConnClosed) {
- break
- }
- }
- if acquireErr != nil {
- b.Fatalf("acquire failed: %v", acquireErr)
- }
- got.Release()
- }
- })
-}
diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go
deleted file mode 100644
index b2683ee04..000000000
--- a/backend/internal/service/openai_ws_pool_test.go
+++ /dev/null
@@ -1,1709 +0,0 @@
-package service
-
-import (
- "context"
- "errors"
- "net/http"
- "strings"
- "sync"
- "sync/atomic"
- "testing"
- "time"
-
- "github.com/Wei-Shaw/sub2api/internal/config"
- "github.com/stretchr/testify/require"
-)
-
-func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1
- pool := newOpenAIWSConnPool(cfg)
-
- accountID := int64(10)
- ap := pool.getOrCreateAccountPool(accountID)
-
- stale := newOpenAIWSConn("stale", accountID, nil, nil)
- stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
-
- idleOld := newOpenAIWSConn("idle_old", accountID, nil, nil)
- idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
-
- idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil)
- idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
-
- ap.conns[stale.id] = stale
- ap.conns[idleOld.id] = idleOld
- ap.conns[idleNew.id] = idleNew
-
- evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
- closeOpenAIWSConns(evicted)
-
- require.Nil(t, ap.conns["stale"], "stale connection should be rotated")
- require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle")
- require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept")
-}
-
-func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) {
- pool := newOpenAIWSConnPool(&config.Config{})
- id1 := pool.nextConnID(42)
- id2 := pool.nextConnID(42)
-
- require.True(t, strings.HasPrefix(id1, "oa_ws_42_"))
- require.True(t, strings.HasPrefix(id2, "oa_ws_42_"))
- require.NotEqual(t, id1, id2)
- require.Equal(t, "oa_ws_42_1", id1)
- require.Equal(t, "oa_ws_42_2", id2)
-}
-
-func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) {
- require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval)
- require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker)
-}
-
-func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) {
- conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
- require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0))
-
- var nilLease *openAIWSConnLease
- err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-
- err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) {
- probe := &openAIWSContextProbeConn{}
- conn := newOpenAIWSConn("ctx_probe", 1, probe, nil)
- require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0))
- require.NotNil(t, probe.lastWriteCtx)
-}
-
-func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5
-
- pool := newOpenAIWSConnPool(cfg)
- ap := pool.getOrCreateAccountPool(88)
-
- conn1 := newOpenAIWSConn("c1", 88, nil, nil)
- conn2 := newOpenAIWSConn("c2", 88, nil, nil)
- require.True(t, conn1.tryAcquire())
- require.True(t, conn2.tryAcquire())
- conn1.waiters.Store(1)
- conn2.waiters.Store(1)
-
- ap.conns[conn1.id] = conn1
- ap.conns[conn2.id] = conn2
-
- target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
- require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限")
-
- conn1.release()
- conn2.release()
- conn1.waiters.Store(0)
- conn2.waiters.Store(0)
- target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
- require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接")
-}
-
-func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
-
- pool := newOpenAIWSConnPool(cfg)
- ap := pool.getOrCreateAccountPool(66)
-
- target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap())
- require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0")
-}
-
-func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
-
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(&openAIWSFakeDialer{})
-
- accountID := int64(77)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
-
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return len(ap.conns) >= 2
- }, 2*time.Second, 20*time.Millisecond)
-
- metrics := pool.SnapshotMetrics()
- require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2))
-}
-
-func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
- cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500
-
- pool := newOpenAIWSConnPool(cfg)
- dialer := &openAIWSCountingDialer{}
- pool.setClientDialerForTest(dialer)
-
- accountID := int64(178)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return len(ap.conns) >= 2 && !ap.prewarmActive
- }, 2*time.Second, 20*time.Millisecond)
- firstDialCount := dialer.DialCount()
- require.GreaterOrEqual(t, firstDialCount, 2)
-
- // 人工制造缺口触发新一轮预热需求。
- ap, ok := pool.getAccountPool(accountID)
- require.True(t, ok)
- require.NotNil(t, ap)
- ap.mu.Lock()
- for id := range ap.conns {
- delete(ap.conns, id)
- break
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
- time.Sleep(120 * time.Millisecond)
- require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热")
-
- time.Sleep(450 * time.Millisecond)
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- return dialer.DialCount() > firstDialCount
- }, 2*time.Second, 20*time.Millisecond)
-}
-
-func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
- cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0
-
- pool := newOpenAIWSConnPool(cfg)
- dialer := &openAIWSAlwaysFailDialer{}
- pool.setClientDialerForTest(dialer)
-
- accountID := int64(279)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return !ap.prewarmActive
- }, 2*time.Second, 20*time.Millisecond)
-
- pool.ensureTargetIdleAsync(accountID)
- require.Eventually(t, func() bool {
- ap, ok := pool.getAccountPool(accountID)
- if !ok || ap == nil {
- return false
- }
- ap.mu.Lock()
- defer ap.mu.Unlock()
- return !ap.prewarmActive
- }, 2*time.Second, 20*time.Millisecond)
- require.Equal(t, 2, dialer.DialCount())
-
- // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。
- pool.ensureTargetIdleAsync(accountID)
- time.Sleep(120 * time.Millisecond)
- require.Equal(t, 2, dialer.DialCount())
-}
-
-func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
-
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(99)
- account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队
-
- ap := pool.ensureAccountPoolLocked(accountID)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- }
- ap.mu.Unlock()
-
- go func() {
- time.Sleep(60 * time.Millisecond)
- conn.release()
- }()
-
- lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.NoError(t, err)
- require.NotNil(t, lease)
- require.True(t, lease.Reused())
- require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond)
- lease.Release()
-
- metrics := pool.SnapshotMetrics()
- require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1))
- require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0))
- require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1))
-}
-
-func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
-
- pool := newOpenAIWSConnPool(cfg)
- dialer := &openAIWSCountingDialer{}
- pool.setClientDialerForTest(dialer)
-
- account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
-
- lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.NoError(t, err)
- require.NotNil(t, lease1)
- lease1.Release()
-
- lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- ForceNewConn: true,
- })
- require.NoError(t, err)
- require.NotNil(t, lease2)
- lease2.Release()
-
- require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接")
-}
-
-func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(account.ID)
- otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[otherConn.id] = otherConn
- ap.mu.Unlock()
-
- _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- ForcePreferredConn: true,
- })
- require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
-
- _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: "missing_conn",
- ForcePreferredConn: true,
- })
- require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable)
-}
-
-func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(account.ID)
- preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil)
- otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil)
- require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取")
- ap.mu.Lock()
- ap.conns[preferredConn.id] = preferredConn
- ap.conns[otherConn.id] = otherConn
- ap.lastCleanupAt = time.Now()
- ap.mu.Unlock()
-
- go func() {
- time.Sleep(60 * time.Millisecond)
- preferredConn.release()
- }()
-
- ctx, cancel := context.WithTimeout(context.Background(), time.Second)
- defer cancel()
- lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: preferredConn.id,
- ForcePreferredConn: true,
- })
- require.NoError(t, err)
- require.NotNil(t, lease)
- require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移")
- require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond)
- lease.Release()
- require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占")
- otherConn.release()
-}
-
-func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := pool.getOrCreateAccountPool(account.ID)
- preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
- otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[preferredConn.id] = preferredConn
- ap.conns[otherConn.id] = otherConn
- ap.lastCleanupAt = time.Now()
- ap.mu.Unlock()
-
- lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: preferredConn.id,
- ForcePreferredConn: true,
- })
- require.NoError(t, err)
- require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中")
- lease.Release()
-
- require.True(t, preferredConn.tryAcquire())
- preferredConn.waiters.Store(1)
- _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- PreferredConnID: preferredConn.id,
- ForcePreferredConn: true,
- })
- require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移")
- preferredConn.waiters.Store(0)
- preferredConn.release()
-}
-
-func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0
-
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(126)
- ap := pool.getOrCreateAccountPool(accountID)
- pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil)
- idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[pinnedConn.id] = pinnedConn
- ap.conns[idleConn.id] = idleConn
- ap.mu.Unlock()
-
- require.True(t, pool.PinConn(accountID, pinnedConn.id))
- evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
- closeOpenAIWSConns(evicted)
-
- ap.mu.Lock()
- _, pinnedExists := ap.conns[pinnedConn.id]
- _, idleExists := ap.conns[idleConn.id]
- ap.mu.Unlock()
- require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收")
- require.False(t, idleExists, "非绑定的空闲连接应被回收")
-
- pool.UnpinConn(accountID, pinnedConn.id)
- evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap())
- closeOpenAIWSConns(evicted)
- ap.mu.Lock()
- _, pinnedExists = ap.conns[pinnedConn.id]
- ap.mu.Unlock()
- require.False(t, pinnedExists, "解绑后连接应可被正常回收")
-}
-
-func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.False(t, nilPool.PinConn(1, "x"))
- nilPool.UnpinConn(1, "x")
-
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(128)
- ap := &openAIWSAccountPool{
- conns: map[string]*openAIWSConn{},
- }
- pool.accounts.Store(accountID, ap)
-
- require.False(t, pool.PinConn(0, "x"))
- require.False(t, pool.PinConn(999, "x"))
- require.False(t, pool.PinConn(accountID, ""))
- require.False(t, pool.PinConn(accountID, "missing"))
-
- conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
- require.True(t, pool.PinConn(accountID, conn.id))
- require.True(t, pool.PinConn(accountID, conn.id))
-
- ap.mu.Lock()
- require.Equal(t, 2, ap.pinnedConns[conn.id])
- ap.mu.Unlock()
-
- pool.UnpinConn(accountID, conn.id)
- ap.mu.Lock()
- require.Equal(t, 1, ap.pinnedConns[conn.id])
- ap.mu.Unlock()
-
- pool.UnpinConn(accountID, conn.id)
- ap.mu.Lock()
- _, exists := ap.pinnedConns[conn.id]
- ap.mu.Unlock()
- require.False(t, exists)
-
- pool.UnpinConn(accountID, conn.id)
- pool.UnpinConn(accountID, "")
- pool.UnpinConn(0, conn.id)
- pool.UnpinConn(999, conn.id)
-}
-
-func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
- cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
- cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
-
- pool := newOpenAIWSConnPool(cfg)
-
- oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10}
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束")
-
- oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3}
- require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow))
-
- apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10}
- require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放")
-
- apiKeyLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 1}
- require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1")
-
- unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限")
-
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限")
-}
-
-func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false
- cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0
- cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0
-
- pool := newOpenAIWSConnPool(cfg)
- account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2}
- require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为")
-}
-
-func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
- cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3
- cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6
-
- pool := newOpenAIWSConnPool(cfg)
-
- high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20}
- require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限")
-
- nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0}
- require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度")
-}
-
-func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8
- pool := newOpenAIWSConnPool(cfg)
-
- account := &Account{ID: 901, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0}
- _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
-}
-
-func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) {
- conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil)
- lease := &openAIWSConnLease{conn: conn}
-
- _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond)
- require.Error(t, err)
- require.ErrorIs(t, err, context.DeadlineExceeded)
-
- payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond)
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- parentCtx, cancel := context.WithCancel(context.Background())
- cancel()
- _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond)
- require.Error(t, err)
- require.ErrorIs(t, err, context.Canceled)
-}
-
-func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) {
- conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
-
- parentCtx, cancel := context.WithCancel(context.Background())
- go func() {
- time.Sleep(20 * time.Millisecond)
- cancel()
- }()
-
- start := time.Now()
- err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute)
- elapsed := time.Since(start)
-
- require.Error(t, err)
- require.ErrorIs(t, err, context.Canceled)
- require.Less(t, elapsed, 200*time.Millisecond)
-}
-
-func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) {
- conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
- require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
-
- var nilLease *openAIWSConnLease
- err := nilLease.PingWithTimeout(50 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) {
- conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil)
-
- readDone := make(chan error, 1)
- go func() {
- _, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond)
- readDone <- err
- }()
-
- // 让读取先占用 readMu。
- time.Sleep(20 * time.Millisecond)
-
- start := time.Now()
- err := conn.pingWithTimeout(50 * time.Millisecond)
- elapsed := time.Since(start)
-
- require.NoError(t, err)
- require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞")
- require.NoError(t, <-readDone)
-}
-
-func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- pool := newOpenAIWSConnPool(cfg)
-
- accountID := int64(301)
- ap := pool.getOrCreateAccountPool(accountID)
- conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
-
- pool.runBackgroundPingSweep()
-
- ap.mu.Lock()
- _, exists := ap.conns[conn.id]
- ap.mu.Unlock()
- require.False(t, exists, "后台 ping 失败的空闲连接应被回收")
-}
-
-func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2
- pool := newOpenAIWSConnPool(cfg)
-
- accountID := int64(302)
- ap := pool.getOrCreateAccountPool(accountID)
- stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil)
- stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- ap.mu.Lock()
- ap.conns[stale.id] = stale
- ap.mu.Unlock()
-
- pool.runBackgroundCleanupSweep(time.Now())
-
- ap.mu.Lock()
- _, exists := ap.conns[stale.id]
- ap.mu.Unlock()
- require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接")
-}
-
-func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.NotPanics(t, func() {
- nilPool.startBackgroundWorkers()
- nilPool.runBackgroundPingWorker()
- nilPool.runBackgroundPingSweep()
- _ = nilPool.snapshotIdleConnsForPing()
- nilPool.runBackgroundCleanupWorker()
- nilPool.runBackgroundCleanupSweep(time.Now())
- })
-
- poolNoStop := &openAIWSConnPool{}
- require.NotPanics(t, func() {
- poolNoStop.startBackgroundWorkers()
- })
-
- poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})}
- pingDone := make(chan struct{})
- go func() {
- poolStopPing.runBackgroundPingWorker()
- close(pingDone)
- }()
- close(poolStopPing.workerStopCh)
- select {
- case <-pingDone:
- case <-time.After(500 * time.Millisecond):
- t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出")
- }
-
- poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})}
- cleanupDone := make(chan struct{})
- go func() {
- poolStopCleanup.runBackgroundCleanupWorker()
- close(cleanupDone)
- }()
- close(poolStopCleanup.workerStopCh)
- select {
- case <-cleanupDone:
- case <-time.After(500 * time.Millisecond):
- t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出")
- }
-}
-
-func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) {
- pool := &openAIWSConnPool{}
- pool.accounts.Store("invalid-key", &openAIWSAccountPool{})
- pool.accounts.Store(int64(123), "invalid-value")
-
- accountID := int64(123)
- ap := &openAIWSAccountPool{
- conns: make(map[string]*openAIWSConn),
- }
- ap.conns["nil_conn"] = nil
-
- leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, leased.tryAcquire())
- ap.conns[leased.id] = leased
-
- waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil)
- waiting.waiters.Store(1)
- ap.conns[waiting.id] = waiting
-
- idle := newOpenAIWSConn("idle", accountID, &openAIWSFakeConn{}, nil)
- ap.conns[idle.id] = idle
-
- pool.accounts.Store(accountID, ap)
- candidates := pool.snapshotIdleConnsForPing()
- require.Len(t, candidates, 1)
- require.Equal(t, idle.id, candidates[0].conn.id)
-}
-
-func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4
- cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true
-
- pool := &openAIWSConnPool{cfg: cfg}
- pool.accounts.Store("bad-key", "bad-value")
-
- accountID := int64(2026)
- ap := &openAIWSAccountPool{
- conns: make(map[string]*openAIWSConn),
- }
- ap.conns["nil_conn"] = nil
- stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil)
- stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano())
- ap.conns[stale.id] = stale
- ap.lastAcquire = &openAIWSAcquireRequest{
- Account: &Account{
- ID: accountID,
- Platform: PlatformOpenAI,
- Type: AccountTypeAPIKey,
- Concurrency: 1,
- },
- }
- pool.accounts.Store(accountID, ap)
-
- now := time.Now()
- require.NotPanics(t, func() {
- pool.runBackgroundCleanupSweep(now)
- })
-
- ap.mu.Lock()
- _, nilConnExists := ap.conns["nil_conn"]
- _, exists := ap.conns[stale.id]
- lastCleanupAt := ap.lastCleanupAt
- ap.mu.Unlock()
-
- require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目")
- require.False(t, exists, "后台清理应清理过期连接")
- require.Equal(t, now, lastCleanupAt)
-}
-
-func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.Equal(t, 256, nilPool.queueLimitPerConn())
-
- pool := &openAIWSConnPool{cfg: &config.Config{}}
- require.Equal(t, 256, pool.queueLimitPerConn())
-
- pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9
- require.Equal(t, 9, pool.queueLimitPerConn())
-}
-
-func TestOpenAIWSConnPool_Close(t *testing.T) {
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
-
- // Close 应该可以安全调用
- pool.Close()
-
- // workerStopCh 应已关闭
- select {
- case <-pool.workerStopCh:
- // 预期:channel 已关闭
- default:
- t.Fatal("Close 后 workerStopCh 应已关闭")
- }
-
- // 多次调用 Close 不应 panic
- pool.Close()
-
- // nil pool 调用 Close 不应 panic
- var nilPool *openAIWSConnPool
- nilPool.Close()
-}
-
-func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) {
- baseErr := errors.New("boom")
- dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr}
- require.Contains(t, dialErr.Error(), "status=502")
- require.ErrorIs(t, dialErr.Unwrap(), baseErr)
-
- noStatus := &openAIWSDialError{Err: baseErr}
- require.Contains(t, noStatus.Error(), "boom")
-
- var nilDialErr *openAIWSDialError
- require.Equal(t, "", nilDialErr.Error())
- require.NoError(t, nilDialErr.Unwrap())
-}
-
-func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) {
- conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{
- "X-Test": []string{" value "},
- })
- lease := &openAIWSConnLease{conn: conn}
-
- require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"}))
- payload, err := lease.ReadMessage(100 * time.Millisecond)
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- payload, err = lease.ReadMessageContext(context.Background())
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- payload, err = conn.readMessageWithTimeout(100 * time.Millisecond)
- require.NoError(t, err)
- require.Contains(t, string(payload), "response.completed")
-
- require.Equal(t, "value", conn.handshakeHeader(" X-Test "))
- require.NotZero(t, conn.createdAt())
- require.NotZero(t, conn.lastUsedAt())
- require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0))
- require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0))
- require.False(t, conn.isLeased())
-
- // 覆盖空上下文路径
- _, err = conn.readMessage(context.Background())
- require.NoError(t, err)
-
- // 覆盖 nil 保护分支
- var nilConn *openAIWSConn
- require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed)
- _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) {
- pool := &openAIWSConnPool{}
- accountID := int64(404)
- ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
-
- idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil)
- idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano())
- idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil)
- idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano())
- leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, leased.tryAcquire())
- leased.waiters.Store(2)
-
- ap.conns[idleOld.id] = idleOld
- ap.conns[idleNew.id] = idleNew
- ap.conns[leased.id] = leased
-
- oldest := pool.pickOldestIdleConnLocked(ap)
- require.NotNil(t, oldest)
- require.Equal(t, idleOld.id, oldest.id)
-
- inflight, waiters := accountPoolLoadLocked(ap)
- require.Equal(t, 1, inflight)
- require.Equal(t, 2, waiters)
-
- pool.accounts.Store(accountID, ap)
- loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID)
- require.Equal(t, 1, loadInflight)
- require.Equal(t, 2, loadWaiters)
- require.Equal(t, 3, conns)
-
- zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0)
- require.Equal(t, 0, zeroInflight)
- require.Equal(t, 0, zeroWaiters)
- require.Equal(t, 0, zeroConns)
-}
-
-func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) {
- pool := &openAIWSConnPool{}
- release := make(chan struct{})
- pool.workerWg.Add(1)
- go func() {
- defer pool.workerWg.Done()
- <-release
- }()
-
- closed := make(chan struct{})
- go func() {
- pool.Close()
- close(closed)
- }()
-
- select {
- case <-closed:
- t.Fatal("Close 不应在 WaitGroup 未完成时提前返回")
- case <-time.After(30 * time.Millisecond):
- }
-
- close(release)
- select {
- case <-closed:
- case <-time.After(time.Second):
- t.Fatal("Close 未等待 workerWg 完成")
- }
-}
-
-func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) {
- pool := &openAIWSConnPool{
- workerStopCh: make(chan struct{}),
- }
-
- accountID := int64(606)
- ap := &openAIWSAccountPool{
- conns: map[string]*openAIWSConn{},
- }
- idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil)
- leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil)
- require.True(t, leased.tryAcquire())
-
- ap.conns[idle.id] = idle
- ap.conns[leased.id] = leased
- pool.accounts.Store(accountID, ap)
- pool.accounts.Store("invalid-key", "invalid-value")
-
- pool.Close()
-
- select {
- case <-idle.closedCh:
- // idle should be closed
- default:
- t.Fatal("空闲连接应在 Close 时被关闭")
- }
-
- select {
- case <-leased.closedCh:
- t.Fatal("已租赁连接不应在 Close 时被关闭")
- default:
- }
-
- leased.release()
- pool.Close()
-}
-
-func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) {
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
- accountID := int64(505)
- ap := pool.getOrCreateAccountPool(accountID)
-
- var current atomic.Int32
- var maxConcurrent atomic.Int32
- release := make(chan struct{})
- for i := 0; i < 25; i++ {
- conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{
- current: ¤t,
- maxConcurrent: &maxConcurrent,
- release: release,
- }, nil)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
- }
-
- done := make(chan struct{})
- go func() {
- pool.runBackgroundPingSweep()
- close(done)
- }()
-
- require.Eventually(t, func() bool {
- return maxConcurrent.Load() >= 10
- }, time.Second, 10*time.Millisecond)
-
- close(release)
- select {
- case <-done:
- case <-time.After(2 * time.Second):
- t.Fatal("runBackgroundPingSweep 未在释放后完成")
- }
-
- require.LessOrEqual(t, maxConcurrent.Load(), int32(10))
-}
-
-func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) {
- var nilLease *openAIWSConnLease
- require.Equal(t, "", nilLease.ConnID())
- require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration())
- require.Equal(t, time.Duration(0), nilLease.ConnPickDuration())
- require.False(t, nilLease.Reused())
- require.Equal(t, "", nilLease.HandshakeHeader("x-test"))
- require.False(t, nilLease.IsPrewarmed())
- nilLease.MarkPrewarmed()
- nilLease.Release()
-
- conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}})
- lease := &openAIWSConnLease{
- conn: conn,
- queueWait: 3 * time.Millisecond,
- connPick: 4 * time.Millisecond,
- reused: true,
- }
- require.Equal(t, "getter_conn", lease.ConnID())
- require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration())
- require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration())
- require.True(t, lease.Reused())
- require.Equal(t, "ok", lease.HandshakeHeader("x-test"))
- require.False(t, lease.IsPrewarmed())
- lease.MarkPrewarmed()
- require.True(t, lease.IsPrewarmed())
- lease.Release()
-}
-
-func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics())
- require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics())
-
- pool := &openAIWSConnPool{cfg: &config.Config{}}
- pool.metrics.acquireTotal.Store(7)
- pool.metrics.acquireReuseTotal.Store(3)
- metrics := pool.SnapshotMetrics()
- require.Equal(t, int64(7), metrics.AcquireTotal)
- require.Equal(t, int64(3), metrics.AcquireReuseTotal)
-
- // 非 transport metrics dialer 路径
- pool.clientDialer = &openAIWSFakeDialer{}
- require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics())
- pool.setClientDialerForTest(nil)
- require.NotNil(t, pool.clientDialer)
-
- require.Equal(t, 8, nilPool.maxConnsHardCap())
- require.False(t, nilPool.dynamicMaxConnsEnabled())
- require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil))
- require.Equal(t, 0, nilPool.minIdlePerAccount())
- require.Equal(t, 4, nilPool.maxIdlePerAccount())
- require.Equal(t, 256, nilPool.queueLimitPerConn())
- require.Equal(t, 0.7, nilPool.targetUtilization())
- require.Equal(t, time.Duration(0), nilPool.prewarmCooldown())
- require.Equal(t, 10*time.Second, nilPool.dialTimeout())
-
- // shouldSuppressPrewarmLocked 覆盖 3 条分支
- now := time.Now()
- apNilFail := &openAIWSAccountPool{prewarmFails: 1}
- require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now))
- apZeroTime := &openAIWSAccountPool{prewarmFails: 2}
- require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now))
- require.Equal(t, 0, apZeroTime.prewarmFails)
- apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)}
- require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now))
- apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now}
- require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now))
-
- // recordConnPickDuration 的保护分支
- nilPool.recordConnPickDuration(10 * time.Millisecond)
- pool.recordConnPickDuration(-10 * time.Millisecond)
- require.Equal(t, int64(1), pool.metrics.connPickTotal.Load())
-
- // account pool 读写分支
- require.Nil(t, nilPool.getOrCreateAccountPool(1))
- require.Nil(t, pool.getOrCreateAccountPool(0))
- pool.accounts.Store(int64(7), "invalid")
- ap := pool.getOrCreateAccountPool(7)
- require.NotNil(t, ap)
- _, ok := pool.getAccountPool(0)
- require.False(t, ok)
- _, ok = pool.getAccountPool(12345)
- require.False(t, ok)
- pool.accounts.Store(int64(8), "bad-type")
- _, ok = pool.getAccountPool(8)
- require.False(t, ok)
-
- // health check 条件
- require.False(t, pool.shouldHealthCheckConn(nil))
- conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil)
- conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano())
- require.True(t, pool.shouldHealthCheckConn(conn))
-}
-
-func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) {
- var nilConn *openAIWSConn
- nilConn.touch()
- require.Equal(t, time.Time{}, nilConn.createdAt())
- require.Equal(t, time.Time{}, nilConn.lastUsedAt())
- require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now()))
- require.Equal(t, time.Duration(0), nilConn.age(time.Now()))
- require.False(t, nilConn.isLeased())
- require.False(t, nilConn.isPrewarmed())
- nilConn.markPrewarmed()
-
- conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil)
- require.True(t, conn.tryAcquire())
- require.True(t, conn.isLeased())
- conn.release()
- require.False(t, conn.isLeased())
- conn.close()
- require.False(t, conn.tryAcquire())
-
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- err := conn.acquire(ctx)
- require.Error(t, err)
-}
-
-func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) {
- lease := &openAIWSConnLease{}
- require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
- require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
- _, err := lease.ReadMessage(10 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageContext(context.Background())
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) {
- conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil)
- lease := &openAIWSConnLease{conn: conn}
-
- require.NoError(t, lease.PingWithTimeout(50*time.Millisecond))
-
- lease.Release()
- lease.Release() // idempotent
-
- require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
- require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed)
- require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed)
-
- _, err := lease.ReadMessage(10 * time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageContext(context.Background())
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-
- require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed)
-}
-
-func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) {
- conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil)
- ap := &openAIWSAccountPool{
- conns: map[string]*openAIWSConn{
- conn.id: conn,
- },
- }
- pool := &openAIWSConnPool{}
- pool.accounts.Store(int64(7), ap)
-
- lease := &openAIWSConnLease{
- pool: pool,
- accountID: 7,
- conn: conn,
- }
-
- lease.Release()
- lease.MarkBroken()
-
- ap.mu.Lock()
- _, exists := ap.conns[conn.id]
- ap.mu.Unlock()
- require.True(t, exists, "released lease should not evict active pool connection")
-}
-
-func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) {
- var nilConn *openAIWSConn
- require.False(t, nilConn.tryAcquire())
- require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed)
- nilConn.release()
- nilConn.close()
- require.Equal(t, "", nilConn.handshakeHeader("x-test"))
-
- connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil)
- require.True(t, connBusy.tryAcquire())
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled)
- connBusy.release()
-
- connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil)
- connClosed.close()
- require.ErrorIs(
- t,
- connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second),
- errOpenAIWSConnClosed,
- )
- _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second)
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
-
- connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil)
- require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed)
- _, err = connNoWS.readMessage(context.Background())
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
- require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed)
- require.Equal(t, "", connNoWS.handshakeHeader("x-test"))
-
- connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil)
- require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil))
- _, err = connOK.readMessageWithContextTimeout(context.Background(), 0)
- require.NoError(t, err)
- require.NoError(t, connOK.pingWithTimeout(0))
-
- connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil)
- connZero.createdAtNano.Store(0)
- connZero.lastUsedNano.Store(0)
- require.True(t, connZero.createdAt().IsZero())
- require.True(t, connZero.lastUsedAt().IsZero())
- require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now()))
- require.Equal(t, time.Duration(0), connZero.age(time.Now()))
-
- require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil))
- copied := cloneHeader(http.Header{
- "X-Empty": []string{},
- "X-Test": []string{"v1"},
- })
- require.Contains(t, copied, "X-Empty")
- require.Nil(t, copied["X-Empty"])
- require.Equal(t, "v1", copied.Get("X-Test"))
-
- closeOpenAIWSConns([]*openAIWSConn{nil, connOK})
-}
-
-func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) {
- pool := newOpenAIWSConnPool(&config.Config{})
- accountID := int64(5001)
- conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil)
- ap := pool.getOrCreateAccountPool(accountID)
- ap.mu.Lock()
- ap.conns[conn.id] = conn
- ap.mu.Unlock()
-
- lease := &openAIWSConnLease{
- pool: pool,
- accountID: accountID,
- conn: conn,
- }
- lease.MarkBroken()
-
- ap.mu.Lock()
- _, exists := ap.conns[conn.id]
- ap.mu.Unlock()
- require.False(t, exists)
- require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭")
-}
-
-func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- pool := newOpenAIWSConnPool(cfg)
-
- require.Equal(t, 0, pool.targetConnCountLocked(nil, 1))
- ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}}
- require.Equal(t, 0, pool.targetConnCountLocked(ap, 0))
-
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3
- require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断")
-
- // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支
- cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0
- cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9
- busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil)
- require.True(t, busy.tryAcquire())
- busy.waiters.Store(1)
- ap.conns[busy.id] = busy
- target := pool.targetConnCountLocked(ap, 4)
- require.GreaterOrEqual(t, target, len(ap.conns)+1)
-
- // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回
- req := openAIWSAcquireRequest{
- Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey},
- WSURL: "wss://example.com/v1/responses",
- }
- pool.prewarmConns(999, req, 1)
-
- // prewarm: 拨号失败分支(prewarmFails 累加)
- accountID := int64(1000)
- failPool := newOpenAIWSConnPool(cfg)
- failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{})
- apFail := failPool.getOrCreateAccountPool(accountID)
- apFail.mu.Lock()
- apFail.creating = 1
- apFail.mu.Unlock()
- req.Account.ID = accountID
- failPool.prewarmConns(accountID, req, 1)
- apFail.mu.Lock()
- require.GreaterOrEqual(t, apFail.prewarmFails, 1)
- apFail.mu.Unlock()
-}
-
-func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) {
- var nilPool *openAIWSConnPool
- _, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{})
- require.Error(t, err)
-
- pool := newOpenAIWSConnPool(&config.Config{})
- _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: &Account{ID: 1},
- WSURL: " ",
- })
- require.Error(t, err)
- require.Contains(t, err.Error(), "ws url is empty")
-
- // target=nil 分支:池满且仅有 nil 连接
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1
- cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1
- fullPool := newOpenAIWSConnPool(cfg)
- account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap := fullPool.getOrCreateAccountPool(account.ID)
- ap.mu.Lock()
- ap.conns["nil"] = nil
- ap.lastCleanupAt = time.Now()
- ap.mu.Unlock()
- _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.ErrorIs(t, err, errOpenAIWSConnClosed)
-
- // queue full 分支:waiters 达上限
- account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
- ap2 := fullPool.getOrCreateAccountPool(account2.ID)
- conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil)
- require.True(t, conn.tryAcquire())
- conn.waiters.Store(1)
- ap2.mu.Lock()
- ap2.conns[conn.id] = conn
- ap2.lastCleanupAt = time.Now()
- ap2.mu.Unlock()
- _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account2,
- WSURL: "wss://example.com/v1/responses",
- })
- require.ErrorIs(t, err, errOpenAIWSConnQueueFull)
-}
-
-type openAIWSFakeDialer struct{}
-
-func (d *openAIWSFakeDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- return &openAIWSFakeConn{}, 0, nil, nil
-}
-
-type openAIWSCountingDialer struct {
- mu sync.Mutex
- dialCount int
-}
-
-type openAIWSAlwaysFailDialer struct {
- mu sync.Mutex
- dialCount int
-}
-
-type openAIWSPingBlockingConn struct {
- current *atomic.Int32
- maxConcurrent *atomic.Int32
- release <-chan struct{}
-}
-
-func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error {
- return nil
-}
-
-func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil
-}
-
-func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error {
- if c.current == nil || c.maxConcurrent == nil {
- return nil
- }
-
- now := c.current.Add(1)
- for {
- prev := c.maxConcurrent.Load()
- if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) {
- break
- }
- }
- defer c.current.Add(-1)
-
- select {
- case <-ctx.Done():
- return ctx.Err()
- case <-c.release:
- return nil
- }
-}
-
-func (c *openAIWSPingBlockingConn) Close() error {
- return nil
-}
-
-func (d *openAIWSCountingDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- d.mu.Lock()
- d.dialCount++
- d.mu.Unlock()
- return &openAIWSFakeConn{}, 0, nil, nil
-}
-
-func (d *openAIWSCountingDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-func (d *openAIWSAlwaysFailDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- d.mu.Lock()
- d.dialCount++
- d.mu.Unlock()
- return nil, 503, nil, errors.New("dial failed")
-}
-
-func (d *openAIWSAlwaysFailDialer) DialCount() int {
- d.mu.Lock()
- defer d.mu.Unlock()
- return d.dialCount
-}
-
-type openAIWSFakeConn struct {
- mu sync.Mutex
- closed bool
- payload [][]byte
-}
-
-func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
- _ = ctx
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return errors.New("closed")
- }
- c.payload = append(c.payload, []byte("ok"))
- _ = value
- return nil
-}
-
-func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
- _ = ctx
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return nil, errors.New("closed")
- }
- return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
-}
-
-func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
- _ = ctx
- return nil
-}
-
-func (c *openAIWSFakeConn) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- c.closed = true
- return nil
-}
-
-type openAIWSBlockingConn struct {
- readDelay time.Duration
-}
-
-func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error {
- _ = ctx
- _ = value
- return nil
-}
-
-func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) {
- delay := c.readDelay
- if delay <= 0 {
- delay = 10 * time.Millisecond
- }
- timer := time.NewTimer(delay)
- defer timer.Stop()
-
- select {
- case <-ctx.Done():
- return nil, ctx.Err()
- case <-timer.C:
- return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil
- }
-}
-
-func (c *openAIWSBlockingConn) Ping(ctx context.Context) error {
- _ = ctx
- return nil
-}
-
-func (c *openAIWSBlockingConn) Close() error {
- return nil
-}
-
-type openAIWSWriteBlockingConn struct{}
-
-func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error {
- <-ctx.Done()
- return ctx.Err()
-}
-
-func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil
-}
-
-func (c *openAIWSWriteBlockingConn) Ping(context.Context) error {
- return nil
-}
-
-func (c *openAIWSWriteBlockingConn) Close() error {
- return nil
-}
-
-type openAIWSPingFailConn struct{}
-
-func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error {
- return nil
-}
-
-func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil
-}
-
-func (c *openAIWSPingFailConn) Ping(context.Context) error {
- return errors.New("ping failed")
-}
-
-func (c *openAIWSPingFailConn) Close() error {
- return nil
-}
-
-type openAIWSContextProbeConn struct {
- lastWriteCtx context.Context
-}
-
-func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error {
- c.lastWriteCtx = ctx
- return nil
-}
-
-func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) {
- return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil
-}
-
-func (c *openAIWSContextProbeConn) Ping(context.Context) error {
- return nil
-}
-
-func (c *openAIWSContextProbeConn) Close() error {
- return nil
-}
-
-type openAIWSNilConnDialer struct{}
-
-func (d *openAIWSNilConnDialer) Dial(
- ctx context.Context,
- wsURL string,
- headers http.Header,
- proxyURL string,
-) (openAIWSClientConn, int, http.Header, error) {
- _ = ctx
- _ = wsURL
- _ = headers
- _ = proxyURL
- return nil, 200, nil, nil
-}
-
-func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) {
- cfg := &config.Config{}
- cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2
- cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1
-
- pool := newOpenAIWSConnPool(cfg)
- pool.setClientDialerForTest(&openAIWSNilConnDialer{})
- account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}
-
- _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{
- Account: account,
- WSURL: "wss://example.com/v1/responses",
- })
- require.Error(t, err)
- require.Contains(t, err.Error(), "nil connection")
-}
-
-func TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) {
- cfg := &config.Config{}
- pool := newOpenAIWSConnPool(cfg)
-
- dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer)
- require.True(t, ok)
-
- _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080")
- require.NoError(t, err)
- _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080")
- require.NoError(t, err)
- _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081")
- require.NoError(t, err)
-
- snapshot := pool.SnapshotTransportMetrics()
- require.Equal(t, int64(1), snapshot.ProxyClientCacheHits)
- require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses)
- require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001)
-}
diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go
index df4d4871c..69aaeaec0 100644
--- a/backend/internal/service/openai_ws_protocol_forward_test.go
+++ b/backend/internal/service/openai_ws_protocol_forward_test.go
@@ -30,6 +30,7 @@ func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -201,7 +202,9 @@ func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *
func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
gin.SetMode(gin.TestMode)
+ var wsAttempts atomic.Int32
ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ wsAttempts.Add(1)
w.WriteHeader(http.StatusUpgradeRequired)
_, _ = w.Write([]byte(`upgrade required`))
}))
@@ -211,6 +214,7 @@ func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -260,6 +264,7 @@ func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) {
require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP")
require.Equal(t, http.StatusUpgradeRequired, rec.Code)
require.Contains(t, rec.Body.String(), "426")
+ require.Equal(t, int32(1), wsAttempts.Load(), "426 upgrade_required 应快速失败,不应进行 WS 重试")
}
func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
@@ -273,6 +278,7 @@ func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) {
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -332,6 +338,7 @@ func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -419,6 +426,7 @@ func TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenRetu
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
c.String(http.StatusAccepted, "already-written")
upstream := &httpUpstreamRecorder{
@@ -508,6 +516,7 @@ func TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testin
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -590,6 +599,7 @@ func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *test
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -672,6 +682,7 @@ func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *tes
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -759,6 +770,7 @@ func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbac
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -866,6 +878,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDrop
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -966,6 +979,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryF
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -1064,6 +1078,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryW
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
@@ -1161,6 +1176,7 @@ func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOn
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil)
c.Request.Header.Set("User-Agent", "custom-client/1.0")
+ SetOpenAIClientTransport(c, OpenAIClientTransportWS)
upstream := &httpUpstreamRecorder{
resp: &http.Response{
diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go
index 368643bea..11c7baa81 100644
--- a/backend/internal/service/openai_ws_protocol_resolver.go
+++ b/backend/internal/service/openai_ws_protocol_resolver.go
@@ -69,7 +69,7 @@ func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProt
switch mode {
case OpenAIWSIngressModeOff:
return openAIWSHTTPDecision("account_mode_off")
- case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated:
+ case OpenAIWSIngressModeCtxPool:
// continue
default:
return openAIWSHTTPDecision("account_mode_off")
diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go
index 5be76e28f..cdd3ef07a 100644
--- a/backend/internal/service/openai_ws_protocol_resolver_test.go
+++ b/backend/internal/service/openai_ws_protocol_resolver_test.go
@@ -143,21 +143,35 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
cfg.Gateway.OpenAIWS.APIKeyEnabled = true
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true
cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true
- cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared
+ cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeOff
- account := &Account{
- Platform: PlatformOpenAI,
- Type: AccountTypeOAuth,
- Concurrency: 1,
- Extra: map[string]any{
- "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
- },
- }
-
- t.Run("dedicated mode routes to ws v2", func(t *testing.T) {
+ t.Run("dedicated mode is blocked and routes to http", func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
+ },
+ }
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account)
+ // dedicated is now mapped to ctx_pool for backward compatibility
+ require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
+ require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
+ })
+
+ t.Run("ctx_pool mode routes to ws v2", func(t *testing.T) {
+ ctxPoolAccount := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Concurrency: 1,
+ Extra: map[string]any{
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
+ },
+ }
+ decision := NewOpenAIWSProtocolResolver(cfg).Resolve(ctxPoolAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
- require.Equal(t, "ws_v2_mode_dedicated", decision.Reason)
+ require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("off mode routes to http", func(t *testing.T) {
@@ -174,7 +188,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
require.Equal(t, "account_mode_off", decision.Reason)
})
- t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) {
+ t.Run("legacy boolean maps to ctx_pool in v2 router", func(t *testing.T) {
legacyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
@@ -185,7 +199,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount)
require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport)
- require.Equal(t, "ws_v2_mode_shared", decision.Reason)
+ require.Equal(t, "ws_v2_mode_ctx_pool", decision.Reason)
})
t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) {
@@ -193,7 +207,7 @@ func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) {
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
- "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
+ "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeCtxPool,
},
}
decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency)
diff --git a/backend/internal/service/openai_ws_recovery.go b/backend/internal/service/openai_ws_recovery.go
new file mode 100644
index 000000000..b49329140
--- /dev/null
+++ b/backend/internal/service/openai_ws_recovery.go
@@ -0,0 +1,758 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "io"
+ "net"
+ "net/http"
+ "strings"
+ "time"
+
+ coderws "github.com/coder/websocket"
+)
+
+// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。
+type openAIWSFallbackError struct {
+ Reason string
+ Err error
+}
+
+func (e *openAIWSFallbackError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.Err == nil {
+ return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason))
+ }
+ return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err)
+}
+
+func (e *openAIWSFallbackError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.Err
+}
+
+func wrapOpenAIWSFallback(reason string, err error) error {
+ return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err}
+}
+
+// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。
+type OpenAIWSClientCloseError struct {
+ statusCode coderws.StatusCode
+ reason string
+ err error
+}
+
+type openAIWSIngressTurnError struct {
+ stage string
+ cause error
+ wroteDownstream bool
+ partialResult *OpenAIForwardResult
+}
+
+type openAIWSIngressUpstreamLease interface {
+ ConnID() string
+ QueueWaitDuration() time.Duration
+ ConnPickDuration() time.Duration
+ Reused() bool
+ ScheduleLayer() string
+ StickinessLevel() string
+ MigrationUsed() bool
+ HandshakeHeader(name string) string
+ IsPrewarmed() bool
+ MarkPrewarmed()
+ WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error
+ ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error)
+ PingWithTimeout(timeout time.Duration) error
+ MarkBroken()
+ Yield()
+ Release()
+}
+
+func (e *openAIWSIngressTurnError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.cause == nil {
+ return strings.TrimSpace(e.stage)
+ }
+ return e.cause.Error()
+}
+
+func (e *openAIWSIngressTurnError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.cause
+}
+
+func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error {
+ return wrapOpenAIWSIngressTurnErrorWithPartial(stage, cause, wroteDownstream, nil)
+}
+
+func cloneOpenAIForwardResult(result *OpenAIForwardResult) *OpenAIForwardResult {
+ if result == nil {
+ return nil
+ }
+ cloned := *result
+ if result.PendingFunctionCallIDs != nil {
+ cloned.PendingFunctionCallIDs = make([]string, len(result.PendingFunctionCallIDs))
+ copy(cloned.PendingFunctionCallIDs, result.PendingFunctionCallIDs)
+ }
+ return &cloned
+}
+
+func wrapOpenAIWSIngressTurnErrorWithPartial(stage string, cause error, wroteDownstream bool, partialResult *OpenAIForwardResult) error {
+ if cause == nil {
+ return nil
+ }
+ return &openAIWSIngressTurnError{
+ stage: strings.TrimSpace(stage),
+ cause: cause,
+ wroteDownstream: wroteDownstream,
+ partialResult: cloneOpenAIForwardResult(partialResult),
+ }
+}
+
+// OpenAIWSIngressTurnPartialResult returns usage-bearing partial turn result
+// when WS ingress turn aborts after receiving upstream events.
+func OpenAIWSIngressTurnPartialResult(err error) (*OpenAIForwardResult, bool) {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil || turnErr.partialResult == nil {
+ return nil, false
+ }
+ return cloneOpenAIForwardResult(turnErr.partialResult), true
+}
+
+func isOpenAIWSIngressTurnRetryable(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) {
+ return false
+ }
+ if turnErr.wroteDownstream {
+ return false
+ }
+ switch turnErr.stage {
+ case "write_upstream", "read_upstream":
+ return true
+ default:
+ return false
+ }
+}
+
+func openAIWSIngressTurnRetryReason(err error) string {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return "unknown"
+ }
+ if turnErr.stage == "" {
+ return "unknown"
+ }
+ return turnErr.stage
+}
+
+func isOpenAIWSIngressPreviousResponseNotFound(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound {
+ return false
+ }
+ return !turnErr.wroteDownstream
+}
+
+func isOpenAIWSIngressToolOutputNotFound(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if strings.TrimSpace(turnErr.stage) != openAIWSIngressStageToolOutputNotFound {
+ return false
+ }
+ return !turnErr.wroteDownstream
+}
+
+// openAIWSIngressTurnWroteDownstream 返回本次 turn 是否已向客户端写入过数据。
+// 用于 ContinueTurn abort 时判断是否需要补发 error 事件。
+func openAIWSIngressTurnWroteDownstream(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ return turnErr.wroteDownstream
+}
+
+func isOpenAIWSIngressUpstreamErrorEvent(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ return strings.TrimSpace(turnErr.stage) == "upstream_error_event"
+}
+
+func isOpenAIWSContinuationUnavailableCloseError(err error) bool {
+ var closeErr *OpenAIWSClientCloseError
+ if !errors.As(err, &closeErr) || closeErr == nil {
+ return false
+ }
+ if closeErr.StatusCode() != coderws.StatusPolicyViolation {
+ return false
+ }
+ return strings.Contains(closeErr.Reason(), openAIWSContinuationUnavailableReason)
+}
+
+// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。
+func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error {
+ return &OpenAIWSClientCloseError{
+ statusCode: statusCode,
+ reason: strings.TrimSpace(reason),
+ err: err,
+ }
+}
+
+func (e *OpenAIWSClientCloseError) Error() string {
+ if e == nil {
+ return ""
+ }
+ if e.err == nil {
+ return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason))
+ }
+ return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err)
+}
+
+func (e *OpenAIWSClientCloseError) Unwrap() error {
+ if e == nil {
+ return nil
+ }
+ return e.err
+}
+
+func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode {
+ if e == nil {
+ return coderws.StatusInternalError
+ }
+ return e.statusCode
+}
+
+func (e *OpenAIWSClientCloseError) Reason() string {
+ if e == nil {
+ return ""
+ }
+ return strings.TrimSpace(e.reason)
+}
+
+func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) {
+ if err == nil {
+ return "-", "-"
+ }
+ statusCode := coderws.CloseStatus(err)
+ if statusCode == -1 {
+ return "-", "-"
+ }
+ closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
+ closeReason := "-"
+ var closeErr coderws.CloseError
+ if errors.As(err, &closeErr) {
+ reasonText := strings.TrimSpace(closeErr.Reason)
+ if reasonText != "" {
+ closeReason = normalizeOpenAIWSLogValue(reasonText)
+ }
+ }
+ return normalizeOpenAIWSLogValue(closeStatus), closeReason
+}
+
+func unwrapOpenAIWSDialBaseError(err error) error {
+ if err == nil {
+ return nil
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil {
+ return dialErr.Err
+ }
+ return err
+}
+
+func openAIWSDialRespHeaderForLog(err error, key string) string {
+ var dialErr *openAIWSDialError
+ if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil {
+ return "-"
+ }
+ return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+func classifyOpenAIWSDialError(err error) string {
+ if err == nil {
+ return "-"
+ }
+ baseErr := unwrapOpenAIWSDialBaseError(err)
+ if baseErr == nil {
+ return "-"
+ }
+ if errors.Is(baseErr, context.DeadlineExceeded) {
+ return "ctx_deadline_exceeded"
+ }
+ if errors.Is(baseErr, context.Canceled) {
+ return "ctx_canceled"
+ }
+ var netErr net.Error
+ if errors.As(baseErr, &netErr) && netErr.Timeout() {
+ return "net_timeout"
+ }
+ if status := coderws.CloseStatus(baseErr); status != -1 {
+ return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status)))
+ }
+ message := strings.ToLower(strings.TrimSpace(baseErr.Error()))
+ switch {
+ case strings.Contains(message, "handshake not finished"):
+ return "handshake_not_finished"
+ case strings.Contains(message, "bad handshake"):
+ return "bad_handshake"
+ case strings.Contains(message, "connection refused"):
+ return "connection_refused"
+ case strings.Contains(message, "no such host"):
+ return "dns_not_found"
+ case strings.Contains(message, "tls"):
+ return "tls_error"
+ case strings.Contains(message, "i/o timeout"):
+ return "io_timeout"
+ case strings.Contains(message, "context deadline exceeded"):
+ return "ctx_deadline_exceeded"
+ default:
+ return "dial_error"
+ }
+}
+
+func summarizeOpenAIWSDialError(err error) (
+ statusCode int,
+ dialClass string,
+ closeStatus string,
+ closeReason string,
+ respServer string,
+ respVia string,
+ respCFRay string,
+ respRequestID string,
+) {
+ dialClass = "-"
+ closeStatus = "-"
+ closeReason = "-"
+ respServer = "-"
+ respVia = "-"
+ respCFRay = "-"
+ respRequestID = "-"
+ if err == nil {
+ return
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) && dialErr != nil {
+ statusCode = dialErr.StatusCode
+ respServer = openAIWSDialRespHeaderForLog(err, "server")
+ respVia = openAIWSDialRespHeaderForLog(err, "via")
+ respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray")
+ respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id")
+ }
+ dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err))
+ closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err))
+ return
+}
+
+func isOpenAIWSClientDisconnectError(err error) bool {
+ if err == nil {
+ return false
+ }
+ if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) {
+ return true
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusNormalClosure, coderws.StatusGoingAway, coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure:
+ return true
+ }
+ message := strings.ToLower(strings.TrimSpace(err.Error()))
+ if message == "" {
+ return false
+ }
+ return strings.Contains(message, "failed to read frame header: eof") ||
+ strings.Contains(message, "unexpected eof") ||
+ strings.Contains(message, "use of closed network connection") ||
+ strings.Contains(message, "connection reset by peer") ||
+ strings.Contains(message, "broken pipe")
+}
+
+func classifyOpenAIWSIngressReadErrorClass(err error) string {
+ if err == nil {
+ return "unknown"
+ }
+ if errors.Is(err, context.Canceled) {
+ return "context_canceled"
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return "deadline_exceeded"
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusServiceRestart:
+ return "service_restart"
+ case coderws.StatusTryAgainLater:
+ return "try_again_later"
+ }
+ if isOpenAIWSClientDisconnectError(err) {
+ return "upstream_closed"
+ }
+ if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
+ return "eof"
+ }
+ return "unknown"
+}
+
+func isOpenAIWSStreamWriteDisconnectError(err error, reqCtx context.Context) bool {
+ if err == nil {
+ return false
+ }
+ if reqCtx != nil && reqCtx.Err() != nil {
+ return true
+ }
+ return isOpenAIWSClientDisconnectError(err)
+}
+
+func openAIWSIngressResolveDrainReadTimeout(
+ baseTimeout time.Duration,
+ disconnectDeadline time.Time,
+ now time.Time,
+) (time.Duration, bool) {
+ if disconnectDeadline.IsZero() {
+ return baseTimeout, false
+ }
+ remaining := disconnectDeadline.Sub(now)
+ if remaining <= 0 {
+ return 0, true
+ }
+ if baseTimeout <= 0 || remaining < baseTimeout {
+ return remaining, false
+ }
+ return baseTimeout, false
+}
+
+func openAIWSIngressClientDisconnectedDrainTimeoutError(timeout time.Duration) error {
+ if timeout <= 0 {
+ timeout = openAIWSIngressClientDisconnectDrainTimeout
+ }
+ return fmt.Errorf("client disconnected before upstream terminal event (drain timeout=%s): %w", timeout, context.Canceled)
+}
+
+func openAIWSIngressPumpClosedTurnError(
+ clientDisconnected bool,
+ wroteDownstream bool,
+ partialResult *OpenAIForwardResult,
+) error {
+ if clientDisconnected {
+ return wrapOpenAIWSIngressTurnErrorWithPartial(
+ "client_disconnected_drain_timeout",
+ openAIWSIngressClientDisconnectedDrainTimeoutError(openAIWSIngressClientDisconnectDrainTimeout),
+ wroteDownstream,
+ partialResult,
+ )
+ }
+ return wrapOpenAIWSIngressTurnErrorWithPartial(
+ "read_upstream",
+ errors.New("upstream event pump closed unexpectedly"),
+ wroteDownstream,
+ partialResult,
+ )
+}
+
+func shouldFlushOpenAIWSBufferedEventsOnError(reqStream bool, wroteDownstream bool, clientDisconnected bool) bool {
+ return reqStream && wroteDownstream && !clientDisconnected
+}
+
+// errOpenAIWSClientPreempted 表示客户端在当前 turn 尚未完成时发送了新的 response.create 请求。
+var errOpenAIWSClientPreempted = errors.New("client preempted current turn with new request")
+
+var errOpenAIWSAdvanceClientReadUnavailable = errors.New("client reader channels unavailable")
+
+func openAIWSAdvanceConsumePendingClientReadErr(pendingErr *error) error {
+ if pendingErr == nil || *pendingErr == nil {
+ return nil
+ }
+ readErr := *pendingErr
+ *pendingErr = nil
+ return fmt.Errorf("read client websocket request: %w", readErr)
+}
+
+func openAIWSAdvanceClientReadUnavailable(clientMsgCh <-chan []byte, clientReadErrCh <-chan error) bool {
+ return clientMsgCh == nil && clientReadErrCh == nil
+}
+
+// isOpenAIWSUpstreamRestartCloseError 检测上游是否因服务重启/维护关闭了连接。
+// 1012=ServiceRestart, 1013=TryAgainLater,都是临时性上游维护,proxy 应视为可恢复错误。
+func isOpenAIWSUpstreamRestartCloseError(err error) bool {
+ var turnErr *openAIWSIngressTurnError
+ if !errors.As(err, &turnErr) || turnErr == nil {
+ return false
+ }
+ if turnErr.stage != "read_upstream" {
+ return false
+ }
+ status := coderws.CloseStatus(turnErr.cause)
+ return status == 1012 || status == 1013 // ServiceRestart, TryAgainLater
+}
+
+func classifyOpenAIWSIngressTurnAbortReason(err error) (openAIWSIngressTurnAbortReason, bool) {
+ if err == nil {
+ return openAIWSIngressTurnAbortReasonUnknown, false
+ }
+ if isOpenAIWSIngressPreviousResponseNotFound(err) {
+ return openAIWSIngressTurnAbortReasonPreviousResponse, true
+ }
+ if isOpenAIWSIngressToolOutputNotFound(err) {
+ return openAIWSIngressTurnAbortReasonToolOutput, true
+ }
+ if isOpenAIWSIngressUpstreamErrorEvent(err) {
+ return openAIWSIngressTurnAbortReasonUpstreamError, true
+ }
+ if isOpenAIWSContinuationUnavailableCloseError(err) {
+ return openAIWSIngressTurnAbortReasonContinuationUnavailable, true
+ }
+ if errors.Is(err, errOpenAIWSClientPreempted) {
+ return openAIWSIngressTurnAbortReasonClientPreempted, true
+ }
+ if errors.Is(err, context.Canceled) {
+ return openAIWSIngressTurnAbortReasonContextCanceled, true
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return openAIWSIngressTurnAbortReasonContextDeadline, false
+ }
+ if isOpenAIWSClientDisconnectError(err) {
+ return openAIWSIngressTurnAbortReasonClientClosed, true
+ }
+ // 上游 ServiceRestart/TryAgainLater:必须在 stage-based 分类之前检测,
+ // 否则会被 "read_upstream" 分支兜底为 FailRequest。
+ if isOpenAIWSUpstreamRestartCloseError(err) {
+ return openAIWSIngressTurnAbortReasonUpstreamRestart, true
+ }
+
+ var turnErr *openAIWSIngressTurnError
+ if errors.As(err, &turnErr) && turnErr != nil {
+ switch strings.TrimSpace(turnErr.stage) {
+ case "write_upstream":
+ return openAIWSIngressTurnAbortReasonWriteUpstream, false
+ case "read_upstream":
+ return openAIWSIngressTurnAbortReasonReadUpstream, false
+ case "write_client":
+ return openAIWSIngressTurnAbortReasonWriteClient, false
+ }
+ }
+ return openAIWSIngressTurnAbortReasonUnknown, false
+}
+
+func openAIWSIngressTurnAbortDispositionForReason(reason openAIWSIngressTurnAbortReason) openAIWSIngressTurnAbortDisposition {
+ switch reason {
+ case openAIWSIngressTurnAbortReasonPreviousResponse,
+ openAIWSIngressTurnAbortReasonToolOutput,
+ openAIWSIngressTurnAbortReasonUpstreamError,
+ openAIWSIngressTurnAbortReasonClientPreempted,
+ openAIWSIngressTurnAbortReasonUpstreamRestart:
+ return openAIWSIngressTurnAbortDispositionContinueTurn
+ case openAIWSIngressTurnAbortReasonContextCanceled,
+ openAIWSIngressTurnAbortReasonClientClosed:
+ return openAIWSIngressTurnAbortDispositionCloseGracefully
+ default:
+ return openAIWSIngressTurnAbortDispositionFailRequest
+ }
+}
+
+func classifyOpenAIWSReadFallbackReason(err error) string {
+ if err == nil {
+ return "read_event"
+ }
+ switch coderws.CloseStatus(err) {
+ case coderws.StatusServiceRestart:
+ return "service_restart"
+ case coderws.StatusTryAgainLater:
+ return "try_again_later"
+ case coderws.StatusPolicyViolation:
+ return "policy_violation"
+ case coderws.StatusMessageTooBig:
+ return "message_too_big"
+ default:
+ return "read_event"
+ }
+}
+
+func classifyOpenAIWSAcquireError(err error) string {
+ if err == nil {
+ return "acquire_conn"
+ }
+ var dialErr *openAIWSDialError
+ if errors.As(err, &dialErr) {
+ switch dialErr.StatusCode {
+ case 426:
+ return "upgrade_required"
+ case 401, 403:
+ return "auth_failed"
+ case 429:
+ return "upstream_rate_limited"
+ }
+ if dialErr.StatusCode >= 500 {
+ return "upstream_5xx"
+ }
+ return "dial_failed"
+ }
+ if errors.Is(err, errOpenAIWSConnQueueFull) {
+ return "conn_queue_full"
+ }
+ if errors.Is(err, errOpenAIWSPreferredConnUnavailable) {
+ return "preferred_conn_unavailable"
+ }
+ if errors.Is(err, context.DeadlineExceeded) {
+ return "acquire_timeout"
+ }
+ return "acquire_conn"
+}
+
+func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
+ code := strings.ToLower(strings.TrimSpace(codeRaw))
+ errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
+ msg := strings.ToLower(strings.TrimSpace(msgRaw))
+
+ switch code {
+ case "upgrade_required":
+ return "upgrade_required", true
+ case "websocket_not_supported", "websocket_unsupported":
+ return "ws_unsupported", true
+ case "websocket_connection_limit_reached":
+ return "ws_connection_limit_reached", true
+ case "previous_response_not_found":
+ return "previous_response_not_found", true
+ }
+ if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
+ return "upgrade_required", true
+ }
+ if strings.Contains(errType, "upgrade") {
+ return "upgrade_required", true
+ }
+ if strings.Contains(msg, "websocket") && strings.Contains(msg, "unsupported") {
+ return "ws_unsupported", true
+ }
+ if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") {
+ return "ws_connection_limit_reached", true
+ }
+ if strings.Contains(msg, "previous_response_not_found") ||
+ (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) {
+ return "previous_response_not_found", true
+ }
+ // "No tool output found for function call " / "No tool call found for function call output..."
+ // 表示 previous_response_id 指向的 response 包含未完成的 function_call(例如用户在 Codex CLI
+ // 按 ESC 取消 function_call 后重新发送消息)。此时 previous_response_id 本身就是问题,需要移除后重放。
+ if strings.Contains(msg, "no tool output found") ||
+ strings.Contains(msg, "no tool call found for function call output") ||
+ (strings.Contains(msg, "no tool call found") && strings.Contains(msg, "function call output")) {
+ return openAIWSIngressStageToolOutputNotFound, true
+ }
+ if strings.Contains(msg, "without its required following item") ||
+ strings.Contains(msg, "without its required preceding item") {
+ return openAIWSIngressStageToolOutputNotFound, true
+ }
+ if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") {
+ return "upstream_error_event", true
+ }
+ return "event_error", false
+}
+
+func classifyOpenAIWSErrorEvent(message []byte) (string, bool) {
+ if len(message) == 0 {
+ return "event_error", false
+ }
+ return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message))
+}
+
+func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
+ code := strings.ToLower(strings.TrimSpace(codeRaw))
+ errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
+ switch {
+ case strings.Contains(errType, "invalid_request"),
+ strings.Contains(code, "invalid_request"),
+ strings.Contains(code, "bad_request"),
+ code == "previous_response_not_found":
+ return http.StatusBadRequest
+ case strings.Contains(errType, "authentication"),
+ strings.Contains(code, "invalid_api_key"),
+ strings.Contains(code, "unauthorized"):
+ return http.StatusUnauthorized
+ case strings.Contains(errType, "permission"),
+ strings.Contains(code, "forbidden"):
+ return http.StatusForbidden
+ case strings.Contains(errType, "rate_limit"),
+ strings.Contains(code, "rate_limit"),
+ strings.Contains(code, "insufficient_quota"):
+ return http.StatusTooManyRequests
+ default:
+ return http.StatusBadGateway
+ }
+}
+
+func openAIWSErrorHTTPStatus(message []byte) int {
+ if len(message) == 0 {
+ return http.StatusBadGateway
+ }
+ codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message)
+ return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw)
+}
+
+func (s *OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration {
+ if s == nil || s.cfg == nil {
+ return 30 * time.Second
+ }
+ seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds
+ if seconds <= 0 {
+ return 0
+ }
+ return time.Duration(seconds) * time.Second
+}
+
+func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool {
+ if s == nil || accountID <= 0 {
+ return false
+ }
+ cooldown := s.openAIWSFallbackCooldown()
+ if cooldown <= 0 {
+ return false
+ }
+ rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID)
+ if !ok || rawUntil == nil {
+ return false
+ }
+ until, ok := rawUntil.(time.Time)
+ if !ok || until.IsZero() {
+ s.openaiWSFallbackUntil.Delete(accountID)
+ return false
+ }
+ if time.Now().Before(until) {
+ return true
+ }
+ s.openaiWSFallbackUntil.Delete(accountID)
+ return false
+}
+
+func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ cooldown := s.openAIWSFallbackCooldown()
+ if cooldown <= 0 {
+ return
+ }
+ s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown))
+}
+
+func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) {
+ if s == nil || accountID <= 0 {
+ return
+ }
+ s.openaiWSFallbackUntil.Delete(accountID)
+}
diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go
index b606baa1a..cff5ab6cd 100644
--- a/backend/internal/service/openai_ws_state_store.go
+++ b/backend/internal/service/openai_ws_state_store.go
@@ -4,11 +4,13 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
- "fmt"
+ "strconv"
"strings"
"sync"
"sync/atomic"
"time"
+
+ "github.com/cespare/xxhash/v2"
)
const (
@@ -17,6 +19,7 @@ const (
openAIWSStateStoreCleanupMaxPerMap = 512
openAIWSStateStoreMaxEntriesPerMap = 65536
openAIWSStateStoreRedisTimeout = 3 * time.Second
+ openAIWSStateStoreHotCacheTTL = time.Minute
)
type openAIWSAccountBinding struct {
@@ -29,6 +32,11 @@ type openAIWSConnBinding struct {
expiresAt time.Time
}
+type openAIWSResponsePendingToolCallsBinding struct {
+ callIDs []string
+ expiresAt time.Time
+}
+
type openAIWSTurnStateBinding struct {
turnState string
expiresAt time.Time
@@ -39,6 +47,23 @@ type openAIWSSessionConnBinding struct {
expiresAt time.Time
}
+type openAIWSSessionLastResponseBinding struct {
+ responseID string
+ expiresAt time.Time
+}
+
+type openAIWSStateStoreSessionLastResponseCache interface {
+ SetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash, responseID string, ttl time.Duration) error
+ GetOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) (string, error)
+ DeleteOpenAIWSSessionLastResponseID(ctx context.Context, groupID int64, sessionHash string) error
+}
+
+type openAIWSStateStoreResponsePendingToolCallsCache interface {
+ SetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string, callIDs []string, ttl time.Duration) error
+ GetOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) ([]string, error)
+ DeleteOpenAIWSResponsePendingToolCalls(ctx context.Context, groupID int64, responseID string) error
+}
+
// OpenAIWSStateStore 管理 WSv2 的粘连状态。
// - response_id -> account_id 用于续链路由
// - response_id -> conn_id 用于连接内上下文复用
@@ -53,44 +78,111 @@ type OpenAIWSStateStore interface {
BindResponseConn(responseID, connID string, ttl time.Duration)
GetResponseConn(responseID string) (string, bool)
DeleteResponseConn(responseID string)
+ BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration)
+ GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool)
+ DeleteResponsePendingToolCalls(groupID int64, responseID string)
BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration)
GetSessionTurnState(groupID int64, sessionHash string) (string, bool)
DeleteSessionTurnState(groupID int64, sessionHash string)
+ BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration)
+ GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool)
+ DeleteSessionLastResponseID(groupID int64, sessionHash string)
+
BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration)
GetSessionConn(groupID int64, sessionHash string) (string, bool)
DeleteSessionConn(groupID int64, sessionHash string)
}
+const openAIWSStateStoreConnShards = 16
+
+type openAIWSConnBindingShard struct {
+ mu sync.RWMutex
+ m map[string]openAIWSConnBinding
+}
+
type defaultOpenAIWSStateStore struct {
cache GatewayCache
- responseToAccountMu sync.RWMutex
- responseToAccount map[string]openAIWSAccountBinding
- responseToConnMu sync.RWMutex
- responseToConn map[string]openAIWSConnBinding
- sessionToTurnStateMu sync.RWMutex
- sessionToTurnState map[string]openAIWSTurnStateBinding
- sessionToConnMu sync.RWMutex
- sessionToConn map[string]openAIWSSessionConnBinding
+ responseToAccountMu sync.RWMutex
+ responseToAccount map[string]openAIWSAccountBinding
+ responseToConnShards [openAIWSStateStoreConnShards]openAIWSConnBindingShard
+ responsePendingToolMu sync.RWMutex
+ responsePendingTool map[string]openAIWSResponsePendingToolCallsBinding
+ sessionToTurnStateMu sync.RWMutex
+ sessionToTurnState map[string]openAIWSTurnStateBinding
+ sessionToLastRespMu sync.RWMutex
+ sessionToLastResp map[string]openAIWSSessionLastResponseBinding
+ sessionToConnMu sync.RWMutex
+ sessionToConn map[string]openAIWSSessionConnBinding
lastCleanupUnixNano atomic.Int64
+ stopCh chan struct{}
+ stopOnce sync.Once
+ workerWg sync.WaitGroup
+}
+
+func (s *defaultOpenAIWSStateStore) connShard(key string) *openAIWSConnBindingShard {
+ h := xxhash.Sum64String(key)
+ return &s.responseToConnShards[h%openAIWSStateStoreConnShards]
}
// NewOpenAIWSStateStore 创建默认 WS 状态存储。
func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore {
+ return newOpenAIWSStateStore(cache, openAIWSStateStoreCleanupInterval)
+}
+
+func newOpenAIWSStateStore(cache GatewayCache, cleanupInterval time.Duration) *defaultOpenAIWSStateStore {
store := &defaultOpenAIWSStateStore{
- cache: cache,
- responseToAccount: make(map[string]openAIWSAccountBinding, 256),
- responseToConn: make(map[string]openAIWSConnBinding, 256),
- sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
- sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
+ cache: cache,
+ responseToAccount: make(map[string]openAIWSAccountBinding, 256),
+ responsePendingTool: make(map[string]openAIWSResponsePendingToolCallsBinding, 256),
+ sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256),
+ sessionToLastResp: make(map[string]openAIWSSessionLastResponseBinding, 256),
+ sessionToConn: make(map[string]openAIWSSessionConnBinding, 256),
+ stopCh: make(chan struct{}),
+ }
+ for i := range store.responseToConnShards {
+ store.responseToConnShards[i].m = make(map[string]openAIWSConnBinding, 16)
}
store.lastCleanupUnixNano.Store(time.Now().UnixNano())
+ store.startCleanupWorker(cleanupInterval)
return store
}
+func (s *defaultOpenAIWSStateStore) startCleanupWorker(interval time.Duration) {
+ if s == nil || interval <= 0 {
+ return
+ }
+ s.workerWg.Add(1)
+ go func() {
+ defer s.workerWg.Done()
+ ticker := time.NewTicker(interval)
+ defer ticker.Stop()
+ for {
+ select {
+ case <-s.stopCh:
+ return
+ case <-ticker.C:
+ s.maybeCleanup()
+ }
+ }
+ }()
+}
+
+func (s *defaultOpenAIWSStateStore) Close() {
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ if s.stopCh != nil {
+ close(s.stopCh)
+ }
+ })
+ s.workerWg.Wait()
+}
+
func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error {
id := normalizeOpenAIWSResponseID(responseID)
if id == "" || accountID <= 0 {
@@ -99,19 +191,31 @@ func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, gro
ttl = normalizeOpenAIWSTTL(ttl)
s.maybeCleanup()
- expiresAt := time.Now().Add(ttl)
+ var redisErr error
+ if s.cache != nil {
+ cacheKey := openAIWSResponseAccountCacheKey(id)
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ redisErr = s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
+ cancel()
+ if redisErr != nil {
+ logOpenAIWSModeInfo(
+ "state_store_bind_response_account_redis_fail group_id=%d response_id=%s account_id=%d cause=%s",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), accountID, truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ }
+
+ // 无论 Redis 是否写成功,都写入本地缓存作为降级保障。
+ localTTL := openAIWSStateStoreLocalHotTTL(ttl)
s.responseToAccountMu.Lock()
ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap)
- s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt}
+ s.responseToAccount[id] = openAIWSAccountBinding{
+ accountID: accountID,
+ expiresAt: time.Now().Add(localTTL),
+ }
s.responseToAccountMu.Unlock()
- if s.cache == nil {
- return nil
- }
- cacheKey := openAIWSResponseAccountCacheKey(id)
- cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
- defer cancel()
- return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl)
+ return redisErr
}
func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) {
@@ -119,7 +223,6 @@ func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, grou
if id == "" {
return 0, nil
}
- s.maybeCleanup()
now := time.Now()
s.responseToAccountMu.RLock()
@@ -138,13 +241,33 @@ func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, grou
cacheKey := openAIWSResponseAccountCacheKey(id)
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
- defer cancel()
accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey)
- if err != nil || accountID <= 0 {
+ cancel()
+ if err == nil && accountID > 0 {
+ return accountID, nil
+ }
+
+ // Compatibility fallback for pre-v2 cache keys.
+ legacyCacheKey := openAIWSResponseAccountLegacyCacheKey(id)
+ legacyCtx, legacyCancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ legacyAccountID, legacyErr := s.cache.GetSessionAccountID(legacyCtx, groupID, legacyCacheKey)
+ legacyCancel()
+ if legacyErr != nil || legacyAccountID <= 0 {
// 缓存读取失败不阻断主流程,按未命中降级。
return 0, nil
}
- return accountID, nil
+
+ logOpenAIWSModeInfo(
+ "state_store_get_response_account_legacy_fallback group_id=%d response_id=%s account_id=%d",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), legacyAccountID,
+ )
+
+ // Best effort: backfill v2 key so subsequent reads avoid legacy fallback.
+ backfillCtx, backfillCancel := withOpenAIWSStateStoreRedisTimeout(ctx)
+ _ = s.cache.SetSessionAccountID(backfillCtx, groupID, cacheKey, legacyAccountID, openAIWSStateStoreHotCacheTTL)
+ backfillCancel()
+
+ return legacyAccountID, nil
}
func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error {
@@ -161,7 +284,15 @@ func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, g
}
cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx)
defer cancel()
- return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id))
+ primaryKey := openAIWSResponseAccountCacheKey(id)
+ if err := s.cache.DeleteSessionAccountID(cacheCtx, groupID, primaryKey); err != nil {
+ return err
+ }
+ legacyKey := openAIWSResponseAccountLegacyCacheKey(id)
+ if legacyKey == "" || legacyKey == primaryKey {
+ return nil
+ }
+ return s.cache.DeleteSessionAccountID(cacheCtx, groupID, legacyKey)
}
func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) {
@@ -173,13 +304,14 @@ func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string,
ttl = normalizeOpenAIWSTTL(ttl)
s.maybeCleanup()
- s.responseToConnMu.Lock()
- ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap)
- s.responseToConn[id] = openAIWSConnBinding{
+ shard := s.connShard(id)
+ shard.mu.Lock()
+ ensureBindingCapacity(shard.m, id, openAIWSStateStoreMaxEntriesPerMap/openAIWSStateStoreConnShards)
+ shard.m[id] = openAIWSConnBinding{
connID: conn,
expiresAt: time.Now().Add(ttl),
}
- s.responseToConnMu.Unlock()
+ shard.mu.Unlock()
}
func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) {
@@ -187,13 +319,13 @@ func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string,
if id == "" {
return "", false
}
- s.maybeCleanup()
now := time.Now()
- s.responseToConnMu.RLock()
- binding, ok := s.responseToConn[id]
- s.responseToConnMu.RUnlock()
- if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" {
+ shard := s.connShard(id)
+ shard.mu.RLock()
+ binding, ok := shard.m[id]
+ shard.mu.RUnlock()
+ if !ok || now.After(binding.expiresAt) || binding.connID == "" {
return "", false
}
return binding.connID, true
@@ -204,9 +336,115 @@ func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) {
if id == "" {
return
}
- s.responseToConnMu.Lock()
- delete(s.responseToConn, id)
- s.responseToConnMu.Unlock()
+ shard := s.connShard(id)
+ shard.mu.Lock()
+ delete(shard.m, id)
+ shard.mu.Unlock()
+}
+
+func (s *defaultOpenAIWSStateStore) BindResponsePendingToolCalls(groupID int64, responseID string, callIDs []string, ttl time.Duration) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs)
+ if id == "" || len(normalizedCallIDs) == 0 {
+ return
+ }
+ key := openAIWSResponsePendingToolCallsBindingKey(groupID, id)
+ if key == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.responsePendingToolMu.Lock()
+ ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{
+ callIDs: append([]string(nil), normalizedCallIDs...),
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.responsePendingToolMu.Unlock()
+
+ if cache := s.responsePendingToolCallsCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ if redisErr := cache.SetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id, normalizedCallIDs, ttl); redisErr != nil {
+ logOpenAIWSModeInfo(
+ "state_store_bind_response_pending_tool_calls_redis_fail group_id=%d response_id=%s call_count=%d cause=%s",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ }
+}
+
+func (s *defaultOpenAIWSStateStore) GetResponsePendingToolCalls(groupID int64, responseID string) ([]string, bool) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return nil, false
+ }
+ key := openAIWSResponsePendingToolCallsBindingKey(groupID, id)
+ if key == "" {
+ return nil, false
+ }
+
+ now := time.Now()
+ s.responsePendingToolMu.RLock()
+ binding, ok := s.responsePendingTool[key]
+ s.responsePendingToolMu.RUnlock()
+ if !ok || now.After(binding.expiresAt) || len(binding.callIDs) == 0 {
+ cache := s.responsePendingToolCallsCache()
+ if cache == nil {
+ return nil, false
+ }
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ callIDs, err := cache.GetOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id)
+ normalizedCallIDs := normalizeOpenAIWSPendingToolCallIDs(callIDs)
+ if err != nil || len(normalizedCallIDs) == 0 {
+ if err != nil {
+ logOpenAIWSModeInfo(
+ "state_store_get_response_pending_tool_calls_redis_fail group_id=%d response_id=%s cause=%s",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ return nil, false
+ }
+
+ logOpenAIWSModeInfo(
+ "state_store_get_response_pending_tool_calls_redis_hit group_id=%d response_id=%s call_count=%d",
+ groupID, truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), len(normalizedCallIDs),
+ )
+
+ // Redis 命中后回填本地热缓存,降低后续访问开销。
+ s.responsePendingToolMu.Lock()
+ ensureBindingCapacity(s.responsePendingTool, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.responsePendingTool[key] = openAIWSResponsePendingToolCallsBinding{
+ callIDs: append([]string(nil), normalizedCallIDs...),
+ expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL),
+ }
+ s.responsePendingToolMu.Unlock()
+ return normalizedCallIDs, true
+ }
+ // binding.callIDs was already copied at bind time; return directly (callers are read-only).
+ return binding.callIDs, true
+}
+
+func (s *defaultOpenAIWSStateStore) DeleteResponsePendingToolCalls(groupID int64, responseID string) {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return
+ }
+ key := openAIWSResponsePendingToolCallsBindingKey(groupID, id)
+ if key == "" {
+ return
+ }
+ s.responsePendingToolMu.Lock()
+ delete(s.responsePendingTool, key)
+ s.responsePendingToolMu.Unlock()
+
+ if cache := s.responsePendingToolCallsCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ _ = cache.DeleteOpenAIWSResponsePendingToolCalls(cacheCtx, groupID, id)
+ }
}
func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) {
@@ -232,7 +470,6 @@ func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHa
if key == "" {
return "", false
}
- s.maybeCleanup()
now := time.Now()
s.sessionToTurnStateMu.RLock()
@@ -254,6 +491,99 @@ func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessio
s.sessionToTurnStateMu.Unlock()
}
+func (s *defaultOpenAIWSStateStore) BindSessionLastResponseID(groupID int64, sessionHash, responseID string, ttl time.Duration) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ id := normalizeOpenAIWSResponseID(responseID)
+ if key == "" || id == "" {
+ return
+ }
+ ttl = normalizeOpenAIWSTTL(ttl)
+ s.maybeCleanup()
+
+ s.sessionToLastRespMu.Lock()
+ ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{
+ responseID: id,
+ expiresAt: time.Now().Add(ttl),
+ }
+ s.sessionToLastRespMu.Unlock()
+
+ if cache := s.sessionLastResponseCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ if redisErr := cache.SetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash), id, ttl); redisErr != nil {
+ logOpenAIWSModeInfo(
+ "state_store_bind_session_last_response_redis_fail group_id=%d session_hash=%s response_id=%s cause=%s",
+ groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(id, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(redisErr.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ }
+}
+
+func (s *defaultOpenAIWSStateStore) GetSessionLastResponseID(groupID int64, sessionHash string) (string, bool) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return "", false
+ }
+
+ now := time.Now()
+ s.sessionToLastRespMu.RLock()
+ binding, ok := s.sessionToLastResp[key]
+ s.sessionToLastRespMu.RUnlock()
+ if ok && now.Before(binding.expiresAt) && strings.TrimSpace(binding.responseID) != "" {
+ return binding.responseID, true
+ }
+
+ cache := s.sessionLastResponseCache()
+ if cache == nil {
+ return "", false
+ }
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ responseID, err := cache.GetOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash))
+ responseID = normalizeOpenAIWSResponseID(responseID)
+ if err != nil || responseID == "" {
+ if err != nil {
+ logOpenAIWSModeInfo(
+ "state_store_get_session_last_response_redis_fail group_id=%d session_hash=%s cause=%s",
+ groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen),
+ )
+ }
+ return "", false
+ }
+
+ logOpenAIWSModeInfo(
+ "state_store_get_session_last_response_redis_hit group_id=%d session_hash=%s response_id=%s",
+ groupID, truncateOpenAIWSLogValue(sessionHash, openAIWSIDValueMaxLen), truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen),
+ )
+
+ // Redis 命中后回填本地热缓存,降低后续访问开销。
+ s.sessionToLastRespMu.Lock()
+ ensureBindingCapacity(s.sessionToLastResp, key, openAIWSStateStoreMaxEntriesPerMap)
+ s.sessionToLastResp[key] = openAIWSSessionLastResponseBinding{
+ responseID: responseID,
+ expiresAt: time.Now().Add(openAIWSStateStoreHotCacheTTL),
+ }
+ s.sessionToLastRespMu.Unlock()
+ return responseID, true
+}
+
+func (s *defaultOpenAIWSStateStore) DeleteSessionLastResponseID(groupID int64, sessionHash string) {
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ if key == "" {
+ return
+ }
+ s.sessionToLastRespMu.Lock()
+ delete(s.sessionToLastResp, key)
+ s.sessionToLastRespMu.Unlock()
+
+ if cache := s.sessionLastResponseCache(); cache != nil {
+ cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
+ defer cancel()
+ _ = cache.DeleteOpenAIWSSessionLastResponseID(cacheCtx, groupID, strings.TrimSpace(sessionHash))
+ }
+}
+
func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) {
key := openAIWSSessionTurnStateKey(groupID, sessionHash)
conn := strings.TrimSpace(connID)
@@ -277,7 +607,6 @@ func (s *defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash st
if key == "" {
return "", false
}
- s.maybeCleanup()
now := time.Now()
s.sessionToConnMu.RLock()
@@ -317,14 +646,29 @@ func (s *defaultOpenAIWSStateStore) maybeCleanup() {
cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap)
s.responseToAccountMu.Unlock()
- s.responseToConnMu.Lock()
- cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap)
- s.responseToConnMu.Unlock()
+ perShardLimit := openAIWSStateStoreCleanupMaxPerMap / openAIWSStateStoreConnShards
+ if perShardLimit < 32 {
+ perShardLimit = 32
+ }
+ for i := range s.responseToConnShards {
+ shard := &s.responseToConnShards[i]
+ shard.mu.Lock()
+ cleanupExpiredConnBindings(shard.m, now, perShardLimit)
+ shard.mu.Unlock()
+ }
+
+ s.responsePendingToolMu.Lock()
+ cleanupExpiredResponsePendingToolCallsBindings(s.responsePendingTool, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.responsePendingToolMu.Unlock()
s.sessionToTurnStateMu.Lock()
cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap)
s.sessionToTurnStateMu.Unlock()
+ s.sessionToLastRespMu.Lock()
+ cleanupExpiredSessionLastResponseBindings(s.sessionToLastResp, now, openAIWSStateStoreCleanupMaxPerMap)
+ s.sessionToLastRespMu.Unlock()
+
s.sessionToConnMu.Lock()
cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap)
s.sessionToConnMu.Unlock()
@@ -362,6 +706,22 @@ func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now tim
}
}
+func cleanupExpiredResponsePendingToolCallsBindings(bindings map[string]openAIWSResponsePendingToolCallsBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) {
if len(bindings) == 0 || maxScan <= 0 {
return
@@ -378,6 +738,22 @@ func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBindin
}
}
+func cleanupExpiredSessionLastResponseBindings(bindings map[string]openAIWSSessionLastResponseBinding, now time.Time, maxScan int) {
+ if len(bindings) == 0 || maxScan <= 0 {
+ return
+ }
+ scanned := 0
+ for key, binding := range bindings {
+ if now.After(binding.expiresAt) {
+ delete(bindings, key)
+ }
+ scanned++
+ if scanned >= maxScan {
+ break
+ }
+ }
+}
+
func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) {
if len(bindings) == 0 || maxScan <= 0 {
return
@@ -394,17 +770,56 @@ func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBi
}
}
-func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) {
+type expiringBinding interface {
+ getExpiresAt() time.Time
+}
+
+func (b openAIWSAccountBinding) getExpiresAt() time.Time { return b.expiresAt }
+func (b openAIWSConnBinding) getExpiresAt() time.Time { return b.expiresAt }
+func (b openAIWSResponsePendingToolCallsBinding) getExpiresAt() time.Time { return b.expiresAt }
+func (b openAIWSTurnStateBinding) getExpiresAt() time.Time { return b.expiresAt }
+func (b openAIWSSessionConnBinding) getExpiresAt() time.Time { return b.expiresAt }
+func (b openAIWSSessionLastResponseBinding) getExpiresAt() time.Time { return b.expiresAt }
+
+func ensureBindingCapacity[T expiringBinding](bindings map[string]T, incomingKey string, maxEntries int) {
if len(bindings) < maxEntries || maxEntries <= 0 {
return
}
if _, exists := bindings[incomingKey]; exists {
return
}
- // 固定上限保护:淘汰任意一项,优先保证内存有界。
- for key := range bindings {
- delete(bindings, key)
- return
+ // 优先驱逐已过期条目;若不存在过期项,则按 expiresAt 最早驱逐,避免随机删除活跃绑定。
+ now := time.Now()
+ for key, val := range bindings {
+ if !val.getExpiresAt().IsZero() && now.After(val.getExpiresAt()) {
+ delete(bindings, key)
+ return
+ }
+ }
+ var (
+ evictKey string
+ evictExpireAt time.Time
+ hasCandidate bool
+ )
+ for key, val := range bindings {
+ expiresAt := val.getExpiresAt()
+ if !hasCandidate {
+ evictKey = key
+ evictExpireAt = expiresAt
+ hasCandidate = true
+ continue
+ }
+ switch {
+ case expiresAt.IsZero() && !evictExpireAt.IsZero():
+ evictKey = key
+ evictExpireAt = expiresAt
+ case !expiresAt.IsZero() && !evictExpireAt.IsZero() && expiresAt.Before(evictExpireAt):
+ evictKey = key
+ evictExpireAt = expiresAt
+ }
+ }
+ if hasCandidate {
+ delete(bindings, evictKey)
}
}
@@ -412,7 +827,46 @@ func normalizeOpenAIWSResponseID(responseID string) string {
return strings.TrimSpace(responseID)
}
+func normalizeOpenAIWSPendingToolCallIDs(callIDs []string) []string {
+ if len(callIDs) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(callIDs))
+ normalized := make([]string, 0, len(callIDs))
+ for _, callID := range callIDs {
+ id := strings.TrimSpace(callID)
+ if id == "" {
+ continue
+ }
+ if _, ok := seen[id]; ok {
+ continue
+ }
+ seen[id] = struct{}{}
+ normalized = append(normalized, id)
+ }
+ return normalized
+}
+
+func openAIWSResponsePendingToolCallsBindingKey(groupID int64, responseID string) string {
+ id := normalizeOpenAIWSResponseID(responseID)
+ if id == "" {
+ return ""
+ }
+ return strconv.FormatInt(groupID, 10) + ":" + id
+}
+
func openAIWSResponseAccountCacheKey(responseID string) string {
+ h := xxhash.Sum64String(responseID)
+ // Pad to 16 hex chars for consistent key length.
+ hex := strconv.FormatUint(h, 16)
+ const pad = "0000000000000000"
+ if len(hex) < 16 {
+ hex = pad[:16-len(hex)] + hex
+ }
+ return openAIWSResponseAccountCachePrefix + "v2:" + hex
+}
+
+func openAIWSResponseAccountLegacyCacheKey(responseID string) string {
sum := sha256.Sum256([]byte(responseID))
return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:])
}
@@ -424,12 +878,42 @@ func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration {
return ttl
}
+func openAIWSStateStoreLocalHotTTL(ttl time.Duration) time.Duration {
+ ttl = normalizeOpenAIWSTTL(ttl)
+ if ttl > openAIWSStateStoreHotCacheTTL {
+ return openAIWSStateStoreHotCacheTTL
+ }
+ return ttl
+}
+
+func (s *defaultOpenAIWSStateStore) sessionLastResponseCache() openAIWSStateStoreSessionLastResponseCache {
+ if s == nil || s.cache == nil {
+ return nil
+ }
+ cache, ok := s.cache.(openAIWSStateStoreSessionLastResponseCache)
+ if !ok {
+ return nil
+ }
+ return cache
+}
+
+func (s *defaultOpenAIWSStateStore) responsePendingToolCallsCache() openAIWSStateStoreResponsePendingToolCallsCache {
+ if s == nil || s.cache == nil {
+ return nil
+ }
+ cache, ok := s.cache.(openAIWSStateStoreResponsePendingToolCallsCache)
+ if !ok {
+ return nil
+ }
+ return cache
+}
+
func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string {
hash := strings.TrimSpace(sessionHash)
if hash == "" {
return ""
}
- return fmt.Sprintf("%d:%s", groupID, hash)
+ return strconv.FormatInt(groupID, 10) + ":" + hash
}
func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) {
diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go
index 235d42331..f11a05ec9 100644
--- a/backend/internal/service/openai_ws_state_store_test.go
+++ b/backend/internal/service/openai_ws_state_store_test.go
@@ -41,6 +41,27 @@ func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) {
require.False(t, ok)
}
+func TestOpenAIWSStateStore_ResponsePendingToolCallsTTL(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ groupID := int64(9)
+ store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_1", []string{"call_1", "call_2", "call_1", " "}, 30*time.Millisecond)
+
+ callIDs, ok := store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1")
+ require.True(t, ok)
+ require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs)
+ _, ok = store.GetResponsePendingToolCalls(groupID+1, "resp_pending_tool_1")
+ require.False(t, ok, "pending tool calls should be group-isolated")
+
+ store.DeleteResponsePendingToolCalls(groupID, "resp_pending_tool_1")
+ _, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_1")
+ require.False(t, ok)
+
+ store.BindResponsePendingToolCalls(groupID, "resp_pending_tool_2", []string{"call_3"}, 30*time.Millisecond)
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetResponsePendingToolCalls(groupID, "resp_pending_tool_2")
+ require.False(t, ok)
+}
+
func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) {
store := NewOpenAIWSStateStore(nil)
store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond)
@@ -75,6 +96,178 @@ func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) {
require.False(t, ok)
}
+func TestOpenAIWSStateStore_SessionLastResponseIDTTL(t *testing.T) {
+ store := NewOpenAIWSStateStore(nil)
+ store.BindSessionLastResponseID(9, "session_hash_resp_1", "resp_1", 30*time.Millisecond)
+
+ responseID, ok := store.GetSessionLastResponseID(9, "session_hash_resp_1")
+ require.True(t, ok)
+ require.Equal(t, "resp_1", responseID)
+
+ // group 隔离
+ _, ok = store.GetSessionLastResponseID(10, "session_hash_resp_1")
+ require.False(t, ok)
+
+ time.Sleep(60 * time.Millisecond)
+ _, ok = store.GetSessionLastResponseID(9, "session_hash_resp_1")
+ require.False(t, ok)
+}
+
+type openAIWSSessionLastResponseProbeCache struct {
+ sessionData map[string]string
+ setCalled bool
+ getCalled bool
+ delCalled bool
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) {
+ return 0, nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) DeleteSessionAccountID(context.Context, int64, string) error {
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) SetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash, responseID string, _ time.Duration) error {
+ if c.sessionData == nil {
+ c.sessionData = make(map[string]string)
+ }
+ c.setCalled = true
+ c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)] = responseID
+ return nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) GetOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) (string, error) {
+ c.getCalled = true
+ return c.sessionData[fmt.Sprintf("%d:%s", groupID, sessionHash)], nil
+}
+
+func (c *openAIWSSessionLastResponseProbeCache) DeleteOpenAIWSSessionLastResponseID(_ context.Context, groupID int64, sessionHash string) error {
+ c.delCalled = true
+ delete(c.sessionData, fmt.Sprintf("%d:%s", groupID, sessionHash))
+ return nil
+}
+
+func TestOpenAIWSStateStore_SessionLastResponseID_UsesOptionalCacheFallback(t *testing.T) {
+ probe := &openAIWSSessionLastResponseProbeCache{sessionData: make(map[string]string)}
+ raw := NewOpenAIWSStateStore(probe)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ groupID := int64(9)
+ sessionHash := "session_hash_resp_cache_1"
+ responseID := "resp_cache_1"
+ store.BindSessionLastResponseID(groupID, sessionHash, responseID, time.Minute)
+ require.True(t, probe.setCalled, "绑定 session last_response_id 时应写入可选缓存")
+
+ key := openAIWSSessionTurnStateKey(groupID, sessionHash)
+ store.sessionToLastRespMu.Lock()
+ delete(store.sessionToLastResp, key)
+ store.sessionToLastRespMu.Unlock()
+
+ gotResponseID, found := store.GetSessionLastResponseID(groupID, sessionHash)
+ require.True(t, found, "本地缓存缺失时应降级读取可选缓存")
+ require.Equal(t, responseID, gotResponseID)
+ require.True(t, probe.getCalled)
+
+ store.DeleteSessionLastResponseID(groupID, sessionHash)
+ require.True(t, probe.delCalled, "删除 session last_response_id 时应同步删除可选缓存")
+ _, found = store.GetSessionLastResponseID(groupID, sessionHash)
+ require.False(t, found)
+}
+
+type openAIWSResponsePendingToolCallsProbeCache struct {
+ pendingData map[string][]string
+ setCalls int
+ getCalls int
+ delCalls int
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) GetSessionAccountID(context.Context, int64, string) (int64, error) {
+ return 0, nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) SetSessionAccountID(context.Context, int64, string, int64, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error {
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteSessionAccountID(context.Context, int64, string) error {
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) SetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string, callIDs []string, _ time.Duration) error {
+ if c.pendingData == nil {
+ c.pendingData = make(map[string][]string)
+ }
+ key := fmt.Sprintf("%d:%s", groupID, responseID)
+ normalized := normalizeOpenAIWSPendingToolCallIDs(callIDs)
+ if len(normalized) == 0 {
+ delete(c.pendingData, key)
+ } else {
+ c.pendingData[key] = append([]string(nil), normalized...)
+ }
+ c.setCalls++
+ return nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) GetOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) ([]string, error) {
+ c.getCalls++
+ callIDs := c.pendingData[fmt.Sprintf("%d:%s", groupID, responseID)]
+ return append([]string(nil), callIDs...), nil
+}
+
+func (c *openAIWSResponsePendingToolCallsProbeCache) DeleteOpenAIWSResponsePendingToolCalls(_ context.Context, groupID int64, responseID string) error {
+ c.delCalls++
+ delete(c.pendingData, fmt.Sprintf("%d:%s", groupID, responseID))
+ return nil
+}
+
+func TestOpenAIWSStateStore_ResponsePendingToolCalls_UsesOptionalCacheFallback(t *testing.T) {
+ probe := &openAIWSResponsePendingToolCallsProbeCache{pendingData: make(map[string][]string)}
+ raw := NewOpenAIWSStateStore(probe)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ groupID := int64(11)
+ responseID := "resp_pending_tool_cache_1"
+ store.BindResponsePendingToolCalls(groupID, responseID, []string{"call_1", "call_2", "call_1"}, time.Minute)
+ require.Equal(t, 1, probe.setCalls, "绑定 pending_tool_calls 时应写入可选缓存")
+
+ store.responsePendingToolMu.Lock()
+ delete(store.responsePendingTool, openAIWSResponsePendingToolCallsBindingKey(groupID, responseID))
+ store.responsePendingToolMu.Unlock()
+
+ callIDs, found := store.GetResponsePendingToolCalls(groupID, responseID)
+ require.True(t, found, "本地缓存缺失时应降级读取可选缓存")
+ require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs)
+ require.Equal(t, 1, probe.getCalls)
+
+ // 回填后再次读取应命中本地缓存,不再触发 Redis 回源。
+ callIDs, found = store.GetResponsePendingToolCalls(groupID, responseID)
+ require.True(t, found)
+ require.ElementsMatch(t, []string{"call_1", "call_2"}, callIDs)
+ require.Equal(t, 1, probe.getCalls)
+ _, found = store.GetResponsePendingToolCalls(groupID+1, responseID)
+ require.False(t, found, "optional cache fallback should remain group-isolated")
+
+ store.DeleteResponsePendingToolCalls(groupID, responseID)
+ require.Equal(t, 1, probe.delCalls, "删除 pending_tool_calls 时应同步删除可选缓存")
+ _, found = store.GetResponsePendingToolCalls(groupID, responseID)
+ require.False(t, found)
+}
+
func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) {
cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
store := NewOpenAIWSStateStore(cache)
@@ -94,6 +287,42 @@ func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.
require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射")
}
+func TestOpenAIWSStateStore_GetResponseAccount_LegacyKeyFallback(t *testing.T) {
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ store := NewOpenAIWSStateStore(cache)
+ ctx := context.Background()
+ groupID := int64(18)
+ responseID := "resp_cache_legacy_fallback"
+
+ legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID)
+ v2Key := openAIWSResponseAccountCacheKey(responseID)
+ cache.sessionBindings[legacyKey] = 601
+
+ accountID, err := store.GetResponseAccount(ctx, groupID, responseID)
+ require.NoError(t, err)
+ require.Equal(t, int64(601), accountID, "应支持 legacy cache key 回读")
+ require.Equal(t, int64(601), cache.sessionBindings[v2Key], "legacy 回读后应回填 v2 cache key")
+}
+
+func TestOpenAIWSStateStore_DeleteResponseAccount_DeletesLegacyAndV2Keys(t *testing.T) {
+ cache := &stubGatewayCache{sessionBindings: map[string]int64{}}
+ store := NewOpenAIWSStateStore(cache)
+ ctx := context.Background()
+ groupID := int64(19)
+ responseID := "resp_cache_delete_both_keys"
+
+ legacyKey := openAIWSResponseAccountLegacyCacheKey(responseID)
+ v2Key := openAIWSResponseAccountCacheKey(responseID)
+ cache.sessionBindings[legacyKey] = 701
+ cache.sessionBindings[v2Key] = 701
+
+ require.NoError(t, store.DeleteResponseAccount(ctx, groupID, responseID))
+ _, legacyExists := cache.sessionBindings[legacyKey]
+ _, v2Exists := cache.sessionBindings[v2Key]
+ require.False(t, legacyExists, "删除 response account 绑定时应清理 legacy key")
+ require.False(t, v2Exists, "删除 response account 绑定时应清理 v2 key")
+}
+
func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) {
raw := NewOpenAIWSStateStore(nil)
store, ok := raw.(*defaultOpenAIWSStateStore)
@@ -101,21 +330,27 @@ func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T
expiredAt := time.Now().Add(-time.Minute)
total := 2048
- store.responseToConnMu.Lock()
for i := 0; i < total; i++ {
- store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{
+ key := fmt.Sprintf("resp_%d", i)
+ shard := store.connShard(key)
+ shard.mu.Lock()
+ shard.m[key] = openAIWSConnBinding{
connID: "conn_incremental",
expiresAt: expiredAt,
}
+ shard.mu.Unlock()
}
- store.responseToConnMu.Unlock()
store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano())
store.maybeCleanup()
- store.responseToConnMu.RLock()
- remainingAfterFirst := len(store.responseToConn)
- store.responseToConnMu.RUnlock()
+ remainingAfterFirst := 0
+ for i := range store.responseToConnShards {
+ shard := &store.responseToConnShards[i]
+ shard.mu.RLock()
+ remainingAfterFirst += len(shard.m)
+ shard.mu.RUnlock()
+ }
require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展")
require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键")
@@ -124,36 +359,99 @@ func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T
store.maybeCleanup()
}
- store.responseToConnMu.RLock()
- remaining := len(store.responseToConn)
- store.responseToConnMu.RUnlock()
+ remaining := 0
+ for i := range store.responseToConnShards {
+ shard := &store.responseToConnShards[i]
+ shard.mu.RLock()
+ remaining += len(shard.m)
+ shard.mu.RUnlock()
+ }
require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键")
}
+func TestOpenAIWSStateStore_BackgroundCleanupRemovesExpiredWithoutNewWrites(t *testing.T) {
+ store := newOpenAIWSStateStore(nil, 20*time.Millisecond)
+ defer store.Close()
+
+ expiredAt := time.Now().Add(-time.Minute)
+ store.responseToAccountMu.Lock()
+ for i := 0; i < 64; i++ {
+ key := fmt.Sprintf("bg_cleanup_resp_%d", i)
+ store.responseToAccount[key] = openAIWSAccountBinding{
+ accountID: int64(i + 1),
+ expiresAt: expiredAt,
+ }
+ }
+ store.responseToAccountMu.Unlock()
+
+ // Backdate cleanup watermark so the worker can run immediately on next tick.
+ store.lastCleanupUnixNano.Store(time.Now().Add(-time.Minute).UnixNano())
+
+ require.Eventually(t, func() bool {
+ store.responseToAccountMu.RLock()
+ remaining := len(store.responseToAccount)
+ store.responseToAccountMu.RUnlock()
+ return remaining == 0
+ }, 600*time.Millisecond, 10*time.Millisecond, "后台 cleanup 应在无新写入时清理过期项")
+}
+
func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) {
- bindings := map[string]int{
- "a": 1,
- "b": 2,
+ bindings := map[string]openAIWSAccountBinding{
+ "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)},
+ "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)},
}
ensureBindingCapacity(bindings, "c", 2)
- bindings["c"] = 3
+ bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)}
require.Len(t, bindings, 2)
- require.Equal(t, 3, bindings["c"])
+ require.Equal(t, int64(3), bindings["c"].accountID)
}
func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) {
- bindings := map[string]int{
- "a": 1,
- "b": 2,
+ bindings := map[string]openAIWSAccountBinding{
+ "a": {accountID: 1, expiresAt: time.Now().Add(time.Hour)},
+ "b": {accountID: 2, expiresAt: time.Now().Add(time.Hour)},
}
ensureBindingCapacity(bindings, "a", 2)
- bindings["a"] = 9
+ bindings["a"] = openAIWSAccountBinding{accountID: 9, expiresAt: time.Now().Add(time.Hour)}
+
+ require.Len(t, bindings, 2)
+ require.Equal(t, int64(9), bindings["a"].accountID)
+}
+
+func TestEnsureBindingCapacity_PrefersExpiredEntry(t *testing.T) {
+ bindings := map[string]openAIWSAccountBinding{
+ "expired": {accountID: 1, expiresAt: time.Now().Add(-time.Hour)},
+ "active": {accountID: 2, expiresAt: time.Now().Add(time.Hour)},
+ }
+
+ ensureBindingCapacity(bindings, "c", 2)
+ bindings["c"] = openAIWSAccountBinding{accountID: 3, expiresAt: time.Now().Add(time.Hour)}
require.Len(t, bindings, 2)
- require.Equal(t, 9, bindings["a"])
+ _, hasExpired := bindings["expired"]
+ require.False(t, hasExpired, "expired entry should have been evicted")
+ require.Equal(t, int64(2), bindings["active"].accountID)
+ require.Equal(t, int64(3), bindings["c"].accountID)
+}
+
+func TestEnsureBindingCapacity_EvictsEarliestExpiryWhenNoExpired(t *testing.T) {
+ now := time.Now()
+ bindings := map[string]openAIWSAccountBinding{
+ "soon": {accountID: 1, expiresAt: now.Add(30 * time.Second)},
+ "later": {accountID: 2, expiresAt: now.Add(5 * time.Minute)},
+ }
+
+ ensureBindingCapacity(bindings, "new", 2)
+ bindings["new"] = openAIWSAccountBinding{accountID: 3, expiresAt: now.Add(10 * time.Minute)}
+
+ require.Len(t, bindings, 2)
+ _, hasSoon := bindings["soon"]
+ require.False(t, hasSoon, "entry with earliest expiresAt should be evicted")
+ require.Equal(t, int64(2), bindings["later"].accountID)
+ require.Equal(t, int64(3), bindings["new"].accountID)
}
type openAIWSStateStoreTimeoutProbeCache struct {
@@ -204,13 +502,14 @@ func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe")
require.NoError(t, getErr)
- require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号")
+ require.Equal(t, int64(11), accountID, "Redis Set 失败时本地缓存仍应保留作为降级保障")
require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe"))
require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文")
require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文")
- require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取")
+ // 本地缓存作为降级保障保留,Get 直接命中本地缓存不会穿透到 Redis
+ require.False(t, probe.getHasDeadline, "本地缓存命中时不应穿透到 Redis 读取")
require.Greater(t, probe.setDeadlineDelta, 2*time.Second)
require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second)
require.Greater(t, probe.delDeadlineDelta, 2*time.Second)
@@ -226,6 +525,26 @@ func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) {
require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second)
}
+func TestOpenAIWSStateStore_BindResponseAccount_UsesShortLocalHotTTL(t *testing.T) {
+ cache := &stubGatewayCache{}
+ raw := NewOpenAIWSStateStore(cache)
+ store, ok := raw.(*defaultOpenAIWSStateStore)
+ require.True(t, ok)
+
+ groupID := int64(23)
+ responseID := "resp_local_hot_ttl"
+ require.NoError(t, store.BindResponseAccount(context.Background(), groupID, responseID, 902, 24*time.Hour))
+
+ id := normalizeOpenAIWSResponseID(responseID)
+ require.NotEmpty(t, id)
+ store.responseToAccountMu.RLock()
+ binding, exists := store.responseToAccount[id]
+ store.responseToAccountMu.RUnlock()
+ require.True(t, exists)
+ require.Equal(t, int64(902), binding.accountID)
+ require.WithinDuration(t, time.Now().Add(openAIWSStateStoreHotCacheTTL), binding.expiresAt, 1500*time.Millisecond)
+}
+
func TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) {
ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background())
defer cancel()
diff --git a/backend/internal/service/openai_ws_test_helpers_test.go b/backend/internal/service/openai_ws_test_helpers_test.go
new file mode 100644
index 000000000..d719c0eaf
--- /dev/null
+++ b/backend/internal/service/openai_ws_test_helpers_test.go
@@ -0,0 +1,277 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "net/http"
+ "sync"
+ "time"
+)
+
+type openAIWSQueueDialer struct {
+ mu sync.Mutex
+ conns []openAIWSClientConn
+ dialCount int
+}
+
+func (d *openAIWSQueueDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ d.dialCount++
+ if len(d.conns) == 0 {
+ return nil, 503, nil, errors.New("no test conn")
+ }
+ conn := d.conns[0]
+ if len(d.conns) > 1 {
+ d.conns = d.conns[1:]
+ }
+ return conn, 0, nil, nil
+}
+
+func (d *openAIWSQueueDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSCaptureConn struct {
+ mu sync.Mutex
+ readDelays []time.Duration
+ events [][]byte
+ writes []map[string]any
+ closed bool
+}
+
+func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return errOpenAIWSConnClosed
+ }
+ switch payload := value.(type) {
+ case map[string]any:
+ c.writes = append(c.writes, cloneMapStringAny(payload))
+ case json.RawMessage:
+ var parsed map[string]any
+ if err := json.Unmarshal(payload, &parsed); err == nil {
+ c.writes = append(c.writes, cloneMapStringAny(parsed))
+ }
+ case []byte:
+ var parsed map[string]any
+ if err := json.Unmarshal(payload, &parsed); err == nil {
+ c.writes = append(c.writes, cloneMapStringAny(parsed))
+ }
+ }
+ return nil
+}
+
+func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return nil, errOpenAIWSConnClosed
+ }
+ if len(c.events) == 0 {
+ c.mu.Unlock()
+ return nil, io.EOF
+ }
+ delay := time.Duration(0)
+ if len(c.readDelays) > 0 {
+ delay = c.readDelays[0]
+ c.readDelays = c.readDelays[1:]
+ }
+ event := c.events[0]
+ c.events = c.events[1:]
+ c.mu.Unlock()
+ if delay > 0 {
+ timer := time.NewTimer(delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer.C:
+ }
+ }
+ return event, nil
+}
+
+func (c *openAIWSCaptureConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSCaptureConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func (c *openAIWSCaptureConn) Closed() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.closed
+}
+
+func cloneMapStringAny(src map[string]any) map[string]any {
+ if src == nil {
+ return nil
+ }
+ dst := make(map[string]any, len(src))
+ for k, v := range src {
+ dst[k] = v
+ }
+ return dst
+}
+
+type openAIWSAlwaysFailDialer struct {
+ mu sync.Mutex
+ dialCount int
+}
+
+func (d *openAIWSAlwaysFailDialer) Dial(
+ ctx context.Context,
+ wsURL string,
+ headers http.Header,
+ proxyURL string,
+) (openAIWSClientConn, int, http.Header, error) {
+ _ = ctx
+ _ = wsURL
+ _ = headers
+ _ = proxyURL
+ d.mu.Lock()
+ d.dialCount++
+ d.mu.Unlock()
+ return nil, 503, nil, errors.New("dial failed")
+}
+
+func (d *openAIWSAlwaysFailDialer) DialCount() int {
+ d.mu.Lock()
+ defer d.mu.Unlock()
+ return d.dialCount
+}
+
+type openAIWSFakeConn struct {
+ mu sync.Mutex
+ closed bool
+ payload [][]byte
+}
+
+func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error {
+ _ = ctx
+ _ = value
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return errors.New("closed")
+ }
+ c.payload = append(c.payload, []byte("ok"))
+ return nil
+}
+
+func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ _ = ctx
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if c.closed {
+ return nil, errors.New("closed")
+ }
+ return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil
+}
+
+func (c *openAIWSFakeConn) Ping(ctx context.Context) error {
+ _ = ctx
+ return nil
+}
+
+func (c *openAIWSFakeConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+type openAIWSPingFailConn struct{}
+
+func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error {
+ return nil
+}
+
+func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil
+}
+
+func (c *openAIWSPingFailConn) Ping(context.Context) error {
+ return errors.New("ping failed")
+}
+
+func (c *openAIWSPingFailConn) Close() error {
+ return nil
+}
+
+// openAIWSDelayedPingFailConn 是带可控延迟的 Ping 失败连接,
+// 用于模拟"Ping 执行期间连接被重建"的竞态场景。
+type openAIWSDelayedPingFailConn struct {
+ delay time.Duration
+ pingDone chan struct{} // Ping 开始执行时关闭,通知测试可以进行下一步
+ mu sync.Mutex
+ closed bool
+}
+
+func newOpenAIWSDelayedPingFailConn(delay time.Duration) *openAIWSDelayedPingFailConn {
+ return &openAIWSDelayedPingFailConn{
+ delay: delay,
+ pingDone: make(chan struct{}),
+ }
+}
+
+func (c *openAIWSDelayedPingFailConn) WriteJSON(context.Context, any) error { return nil }
+func (c *openAIWSDelayedPingFailConn) ReadMessage(context.Context) ([]byte, error) {
+ return []byte(`{"type":"response.completed","response":{"id":"resp_delayed_ping"}}`), nil
+}
+
+func (c *openAIWSDelayedPingFailConn) Ping(ctx context.Context) error {
+ // 通知测试 Ping 已开始
+ select {
+ case <-c.pingDone:
+ default:
+ close(c.pingDone)
+ }
+ // 等待延迟或上下文取消
+ timer := time.NewTimer(c.delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return ctx.Err()
+ case <-timer.C:
+ }
+ return errors.New("ping failed after delay")
+}
+
+func (c *openAIWSDelayedPingFailConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.closed = true
+ return nil
+}
+
+func (c *openAIWSDelayedPingFailConn) Closed() bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ return c.closed
+}
diff --git a/backend/internal/service/openai_ws_turn.go b/backend/internal/service/openai_ws_turn.go
new file mode 100644
index 000000000..1ce3a14a5
--- /dev/null
+++ b/backend/internal/service/openai_ws_turn.go
@@ -0,0 +1,1430 @@
+package service
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "net/http"
+ "sort"
+ "strings"
+
+ "github.com/Wei-Shaw/sub2api/internal/pkg/logger"
+ "github.com/gin-gonic/gin"
+ "github.com/tidwall/gjson"
+ "github.com/tidwall/sjson"
+ "go.uber.org/zap"
+)
+
+// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。
+type OpenAIWSIngressHooks struct {
+ BeforeTurn func(turn int) error
+ AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error)
+}
+
+func normalizeOpenAIWSLogValue(value string) string {
+ trimmed := strings.TrimSpace(value)
+ if trimmed == "" {
+ return "-"
+ }
+ return openAIWSLogValueReplacer.Replace(trimmed)
+}
+
+func truncateOpenAIWSLogValue(value string, maxLen int) string {
+ normalized := normalizeOpenAIWSLogValue(value)
+ if normalized == "-" || maxLen <= 0 {
+ return normalized
+ }
+ if len(normalized) <= maxLen {
+ return normalized
+ }
+ return normalized[:maxLen] + "..."
+}
+
+func openAIWSHeaderValueForLog(headers http.Header, key string) string {
+ if headers == nil {
+ return "-"
+ }
+ return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen)
+}
+
+func hasOpenAIWSHeader(headers http.Header, key string) bool {
+ if headers == nil {
+ return false
+ }
+ return strings.TrimSpace(headers.Get(key)) != ""
+}
+
+type openAIWSSessionHeaderResolution struct {
+ SessionID string
+ ConversationID string
+ SessionSource string
+ ConversationSource string
+}
+
+func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution {
+ resolution := openAIWSSessionHeaderResolution{
+ SessionSource: "none",
+ ConversationSource: "none",
+ }
+ if c != nil && c.Request != nil {
+ if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" {
+ resolution.SessionID = sessionID
+ resolution.SessionSource = "header_session_id"
+ }
+ if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" {
+ resolution.ConversationID = conversationID
+ resolution.ConversationSource = "header_conversation_id"
+ if resolution.SessionID == "" {
+ resolution.SessionID = conversationID
+ resolution.SessionSource = "header_conversation_id"
+ }
+ }
+ }
+
+ cacheKey := strings.TrimSpace(promptCacheKey)
+ if cacheKey != "" {
+ if resolution.SessionID == "" {
+ resolution.SessionID = cacheKey
+ resolution.SessionSource = "prompt_cache_key"
+ }
+ }
+ return resolution
+}
+
+func openAIWSIngressSessionScopeFromContext(c *gin.Context) string {
+ if c == nil {
+ return ""
+ }
+ value, exists := c.Get("api_key")
+ if !exists || value == nil {
+ return ""
+ }
+ apiKey, ok := value.(*APIKey)
+ if !ok || apiKey == nil {
+ return ""
+ }
+ userID := apiKey.UserID
+ if userID <= 0 && apiKey.User != nil {
+ userID = apiKey.User.ID
+ }
+ apiKeyID := apiKey.ID
+ if userID <= 0 && apiKeyID <= 0 {
+ return ""
+ }
+ return fmt.Sprintf("u%d:k%d", userID, apiKeyID)
+}
+
+func openAIWSApplySessionScope(sessionHash, scope string) string {
+ hash := strings.TrimSpace(sessionHash)
+ if hash == "" {
+ return ""
+ }
+ scope = strings.TrimSpace(scope)
+ if scope == "" {
+ return hash
+ }
+ return scope + "|" + hash
+}
+
+func shouldLogOpenAIWSEvent(idx int, eventType string) bool {
+ if idx <= openAIWSEventLogHeadLimit {
+ return true
+ }
+ if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 {
+ return true
+ }
+ if eventType == "error" || isOpenAIWSTerminalEvent(eventType) {
+ return true
+ }
+ return false
+}
+
+func shouldLogOpenAIWSBufferedEvent(idx int) bool {
+ if idx <= openAIWSBufferLogHeadLimit {
+ return true
+ }
+ if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 {
+ return true
+ }
+ return false
+}
+
+func openAIWSEventMayContainModel(eventType string) bool {
+ switch eventType {
+ case "response.created",
+ "response.in_progress",
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled":
+ return true
+ default:
+ trimmed := strings.TrimSpace(eventType)
+ if trimmed == eventType {
+ return false
+ }
+ switch trimmed {
+ case "response.created",
+ "response.in_progress",
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled":
+ return true
+ default:
+ return false
+ }
+ }
+}
+
+func openAIWSEventMayContainToolCalls(eventType string) bool {
+ if eventType == "" {
+ return false
+ }
+ if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") {
+ return true
+ }
+ switch eventType {
+ case "response.output_item.added", "response.output_item.done", "response.completed", "response.done":
+ return true
+ default:
+ return false
+ }
+}
+
+// openAIWSEventShouldParseUsage 判断是否应解析 usage。
+// 调用方需确保 eventType 已经过 TrimSpace(如 parseOpenAIWSEventType 的返回值)。
+func openAIWSEventShouldParseUsage(eventType string) bool {
+ switch eventType {
+ case "response.completed", "response.done", "response.failed":
+ return true
+ default:
+ return false
+ }
+}
+
+// parseOpenAIWSEventType extracts only the event type and response ID from a WS message.
+// Use this lightweight version on hot paths where the full response body is not needed.
+func parseOpenAIWSEventType(message []byte) (eventType string, responseID string) {
+ if len(message) == 0 {
+ return "", ""
+ }
+ values := gjson.GetManyBytes(message, "type", "response.id", "id")
+ eventType = strings.TrimSpace(values[0].String())
+ if id := strings.TrimSpace(values[1].String()); id != "" {
+ responseID = id
+ } else {
+ responseID = strings.TrimSpace(values[2].String())
+ }
+ return eventType, responseID
+}
+
+func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) {
+ if len(message) == 0 {
+ return "", "", gjson.Result{}
+ }
+ values := gjson.GetManyBytes(message, "type", "response.id", "id", "response")
+ eventType = strings.TrimSpace(values[0].String())
+ if id := strings.TrimSpace(values[1].String()); id != "" {
+ responseID = id
+ } else {
+ responseID = strings.TrimSpace(values[2].String())
+ }
+ return eventType, responseID, values[3]
+}
+
+func openAIWSMessageLikelyContainsToolCalls(message []byte) bool {
+ if len(message) == 0 {
+ return false
+ }
+ return bytes.Contains(message, []byte(`"tool_calls"`)) ||
+ bytes.Contains(message, []byte(`"tool_call"`)) ||
+ bytes.Contains(message, []byte(`"function_call"`))
+}
+
+func openAIWSCollectPendingFunctionCallIDsFromJSONResult(result gjson.Result, callIDSet map[string]struct{}, depth int) {
+ if !result.Exists() || callIDSet == nil || depth > 8 || result.Type != gjson.JSON {
+ return
+ }
+ itemType := strings.TrimSpace(result.Get("type").String())
+ if itemType == "function_call" || itemType == "tool_call" {
+ callID := strings.TrimSpace(result.Get("call_id").String())
+ if callID == "" {
+ fallbackID := strings.TrimSpace(result.Get("id").String())
+ if strings.HasPrefix(fallbackID, "call_") {
+ callID = fallbackID
+ }
+ }
+ if callID != "" {
+ callIDSet[callID] = struct{}{}
+ }
+ }
+ result.ForEach(func(_, child gjson.Result) bool {
+ openAIWSCollectPendingFunctionCallIDsFromJSONResult(child, callIDSet, depth+1)
+ return true
+ })
+}
+
+func openAIWSExtractPendingFunctionCallIDsFromEvent(message []byte) []string {
+ if len(message) == 0 {
+ return nil
+ }
+ callIDSet := make(map[string]struct{}, 4)
+ openAIWSCollectPendingFunctionCallIDsFromJSONResult(gjson.ParseBytes(message), callIDSet, 0)
+ if len(callIDSet) == 0 {
+ return nil
+ }
+ callIDs := make([]string, 0, len(callIDSet))
+ for callID := range callIDSet {
+ callIDs = append(callIDs, callID)
+ }
+ sort.Strings(callIDs)
+ return callIDs
+}
+
+func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) {
+ if usage == nil || len(message) == 0 {
+ return
+ }
+ values := gjson.GetManyBytes(
+ message,
+ "response.usage.input_tokens",
+ "response.usage.output_tokens",
+ "response.usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(values[0].Int())
+ usage.OutputTokens = int(values[1].Int())
+ usage.CacheReadInputTokens = int(values[2].Int())
+}
+
+func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
+ if len(message) == 0 {
+ return "", "", ""
+ }
+ values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message")
+ return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String())
+}
+
+func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) {
+ code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen)
+ errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen)
+ errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen)
+ return code, errType, errMessage
+}
+
+func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) {
+ if len(message) == 0 {
+ return "-", "-", "-"
+ }
+ return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message))
+}
+
+func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string {
+ if len(payload) == 0 {
+ return "-"
+ }
+ type keySize struct {
+ Key string
+ Size int
+ }
+ sizes := make([]keySize, 0, len(payload))
+ for key, value := range payload {
+ size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth)
+ sizes = append(sizes, keySize{Key: key, Size: size})
+ }
+ sort.Slice(sizes, func(i, j int) bool {
+ if sizes[i].Size == sizes[j].Size {
+ return sizes[i].Key < sizes[j].Key
+ }
+ return sizes[i].Size > sizes[j].Size
+ })
+
+ if topN <= 0 || topN > len(sizes) {
+ topN = len(sizes)
+ }
+ parts := make([]string, 0, topN)
+ for idx := 0; idx < topN; idx++ {
+ item := sizes[idx]
+ parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size))
+ }
+ return strings.Join(parts, ",")
+}
+
+func estimateOpenAIWSPayloadValueSize(value any, depth int) int {
+ if depth <= 0 {
+ return -1
+ }
+ switch v := value.(type) {
+ case nil:
+ return 0
+ case string:
+ return len(v)
+ case []byte:
+ return len(v)
+ case bool:
+ return 1
+ case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
+ return 8
+ case float32, float64:
+ return 8
+ case map[string]any:
+ if len(v) == 0 {
+ return 2
+ }
+ total := 2
+ count := 0
+ for key, item := range v {
+ count++
+ if count > openAIWSPayloadSizeEstimateMaxItems {
+ return -1
+ }
+ itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1)
+ if itemSize < 0 {
+ return -1
+ }
+ total += len(key) + itemSize + 3
+ if total > openAIWSPayloadSizeEstimateMaxBytes {
+ return -1
+ }
+ }
+ return total
+ case []any:
+ if len(v) == 0 {
+ return 2
+ }
+ total := 2
+ limit := len(v)
+ if limit > openAIWSPayloadSizeEstimateMaxItems {
+ return -1
+ }
+ for i := 0; i < limit; i++ {
+ itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1)
+ if itemSize < 0 {
+ return -1
+ }
+ total += itemSize + 1
+ if total > openAIWSPayloadSizeEstimateMaxBytes {
+ return -1
+ }
+ }
+ return total
+ default:
+ raw, err := json.Marshal(v)
+ if err != nil {
+ return -1
+ }
+ if len(raw) > openAIWSPayloadSizeEstimateMaxBytes {
+ return -1
+ }
+ return len(raw)
+ }
+}
+
+func openAIWSPayloadString(payload map[string]any, key string) string {
+ if len(payload) == 0 {
+ return ""
+ }
+ raw, ok := payload[key]
+ if !ok {
+ return ""
+ }
+ switch v := raw.(type) {
+ case nil:
+ return ""
+ case string:
+ return strings.TrimSpace(v)
+ case []byte:
+ return strings.TrimSpace(string(v))
+ default:
+ return ""
+ }
+}
+
+func openAIWSPayloadStringFromRaw(payload []byte, key string) string {
+ if len(payload) == 0 || strings.TrimSpace(key) == "" {
+ return ""
+ }
+ return strings.TrimSpace(gjson.GetBytes(payload, key).String())
+}
+
+func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool {
+ if len(payload) == 0 || strings.TrimSpace(key) == "" {
+ return defaultValue
+ }
+ value := gjson.GetBytes(payload, key)
+ if !value.Exists() {
+ return defaultValue
+ }
+ if value.Type != gjson.True && value.Type != gjson.False {
+ return defaultValue
+ }
+ return value.Bool()
+}
+
+func openAIWSSessionHashesFromID(sessionID string) (string, string) {
+ return deriveOpenAISessionHashes(sessionID)
+}
+
+func extractOpenAIWSImageURL(value any) string {
+ switch v := value.(type) {
+ case string:
+ return strings.TrimSpace(v)
+ case map[string]any:
+ if raw, ok := v["url"].(string); ok {
+ return strings.TrimSpace(raw)
+ }
+ }
+ return ""
+}
+
+func summarizeOpenAIWSInput(input any) string {
+ items, ok := input.([]any)
+ if !ok || len(items) == 0 {
+ return "-"
+ }
+
+ itemCount := len(items)
+ textChars := 0
+ imageDataURLs := 0
+ imageDataURLChars := 0
+ imageRemoteURLs := 0
+
+ handleContentItem := func(contentItem map[string]any) {
+ contentType, _ := contentItem["type"].(string)
+ switch strings.TrimSpace(contentType) {
+ case "input_text", "output_text", "text":
+ if text, ok := contentItem["text"].(string); ok {
+ textChars += len(text)
+ }
+ case "input_image":
+ imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
+ if imageURL == "" {
+ return
+ }
+ if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
+ imageDataURLs++
+ imageDataURLChars += len(imageURL)
+ return
+ }
+ imageRemoteURLs++
+ }
+ }
+
+ handleInputItem := func(inputItem map[string]any) {
+ if content, ok := inputItem["content"].([]any); ok {
+ for _, rawContent := range content {
+ contentItem, ok := rawContent.(map[string]any)
+ if !ok {
+ continue
+ }
+ handleContentItem(contentItem)
+ }
+ return
+ }
+
+ itemType, _ := inputItem["type"].(string)
+ switch strings.TrimSpace(itemType) {
+ case "input_text", "output_text", "text":
+ if text, ok := inputItem["text"].(string); ok {
+ textChars += len(text)
+ }
+ case "input_image":
+ imageURL := extractOpenAIWSImageURL(inputItem["image_url"])
+ if imageURL == "" {
+ return
+ }
+ if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
+ imageDataURLs++
+ imageDataURLChars += len(imageURL)
+ return
+ }
+ imageRemoteURLs++
+ }
+ }
+
+ for _, rawItem := range items {
+ inputItem, ok := rawItem.(map[string]any)
+ if !ok {
+ continue
+ }
+ handleInputItem(inputItem)
+ }
+
+ return fmt.Sprintf(
+ "items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
+ itemCount,
+ textChars,
+ imageDataURLs,
+ imageDataURLChars,
+ imageRemoteURLs,
+ )
+}
+
+func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
+ if len(payload) == 0 || strings.TrimSpace(key) == "" {
+ return
+ }
+ if _, exists := payload[key]; !exists {
+ return
+ }
+ delete(payload, key)
+ *removed = append(*removed, key)
+}
+
+// applyOpenAIWSRetryPayloadStrategy 在 WS 连续失败时仅移除无语义字段,
+// 避免重试成功却改变原始请求语义。
+// 注意:prompt_cache_key 不应在重试中移除;它常用于会话稳定标识(session_id 兜底)。
+func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
+ if len(payload) == 0 {
+ return "empty", nil
+ }
+ if attempt <= 1 {
+ return "full", nil
+ }
+
+ removed := make([]string, 0, 2)
+ if attempt >= 2 {
+ dropOpenAIWSPayloadKey(payload, "include", &removed)
+ }
+
+ if len(removed) == 0 {
+ return "full", nil
+ }
+ sort.Strings(removed)
+ return "trim_optional_fields", removed
+}
+
+func logOpenAIWSModeInfo(format string, args ...any) {
+ logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
+}
+
+func isOpenAIWSModeDebugEnabled() bool {
+ return logger.L().Core().Enabled(zap.DebugLevel)
+}
+
+func logOpenAIWSModeDebug(format string, args ...any) {
+ if !isOpenAIWSModeDebugEnabled() {
+ return
+ }
+ logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
+}
+
+func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) {
+ if err == nil {
+ return
+ }
+ logger.L().Warn(
+ "openai.ws_bind_response_account_failed",
+ zap.Int64("group_id", groupID),
+ zap.Int64("account_id", accountID),
+ zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)),
+ zap.Error(err),
+ )
+}
+
+func logOpenAIWSIngressTurnAbort(
+ accountID int64,
+ turn int,
+ connID string,
+ reason openAIWSIngressTurnAbortReason,
+ expected bool,
+ cause error,
+) {
+ causeValue := "-"
+ if cause != nil {
+ causeValue = truncateOpenAIWSLogValue(cause.Error(), openAIWSLogValueMaxLen)
+ }
+ logOpenAIWSModeInfo(
+ "ingress_ws_turn_aborted account_id=%d turn=%d conn_id=%s reason=%s expected=%v cause=%s",
+ accountID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ normalizeOpenAIWSLogValue(string(reason)),
+ expected,
+ causeValue,
+ )
+}
+
+func sortedKeys(m map[string]any) []string {
+ if len(m) == 0 {
+ return nil
+ }
+ keys := make([]string, 0, len(m))
+ for k := range m {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ return keys
+}
+
+func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
+ return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
+}
+
+func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
+ payload []byte,
+ deleteFn func([]byte, string) ([]byte, error),
+) ([]byte, bool, error) {
+ if len(payload) == 0 {
+ return payload, false, nil
+ }
+ if !gjson.GetBytes(payload, "previous_response_id").Exists() {
+ return payload, false, nil
+ }
+ if deleteFn == nil {
+ deleteFn = sjson.DeleteBytes
+ }
+
+ updated := payload
+ for i := 0; i < openAIWSMaxPrevResponseIDDeletePasses &&
+ gjson.GetBytes(updated, "previous_response_id").Exists(); i++ {
+ next, err := deleteFn(updated, "previous_response_id")
+ if err != nil {
+ return payload, false, err
+ }
+ updated = next
+ }
+ return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
+}
+
+func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
+ normalizedPrevID := strings.TrimSpace(previousResponseID)
+ if len(payload) == 0 || normalizedPrevID == "" {
+ return payload, nil
+ }
+ if current := openAIWSPayloadStringFromRaw(payload, "previous_response_id"); current == normalizedPrevID {
+ return payload, nil
+ }
+ updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
+ if err == nil {
+ return updated, nil
+ }
+
+ var reqBody map[string]any
+ if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
+ return nil, err
+ }
+ reqBody["previous_response_id"] = normalizedPrevID
+ rebuilt, marshalErr := json.Marshal(reqBody)
+ if marshalErr != nil {
+ return nil, marshalErr
+ }
+ return rebuilt, nil
+}
+
+func shouldInferIngressFunctionCallOutputPreviousResponseID(
+ storeDisabled bool,
+ turn int,
+ hasFunctionCallOutput bool,
+ currentPreviousResponseID string,
+ expectedPreviousResponseID string,
+) bool {
+ if !storeDisabled || turn <= 0 || !hasFunctionCallOutput {
+ return false
+ }
+ if strings.TrimSpace(currentPreviousResponseID) != "" {
+ return false
+ }
+ return strings.TrimSpace(expectedPreviousResponseID) != ""
+}
+
+func alignStoreDisabledPreviousResponseID(
+ payload []byte,
+ expectedPreviousResponseID string,
+) ([]byte, bool, error) {
+ if len(payload) == 0 {
+ return payload, false, nil
+ }
+ expected := strings.TrimSpace(expectedPreviousResponseID)
+ if expected == "" {
+ return payload, false, nil
+ }
+ current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ if current == "" || current == expected {
+ return payload, false, nil
+ }
+
+ // 常见路径(无重复 key)直接 set,避免先 delete 再 set 的双遍处理。
+ // 仅在检测到重复 key 时走 drop+set 慢路径,确保最终语义一致。
+ if bytes.Count(payload, []byte(`"previous_response_id"`)) <= 1 {
+ updated, setErr := setPreviousResponseIDToRawPayload(payload, expected)
+ if setErr != nil {
+ return payload, false, setErr
+ }
+ return updated, !bytes.Equal(updated, payload), nil
+ }
+
+ withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ if dropErr != nil {
+ return payload, false, dropErr
+ }
+ if !removed {
+ return payload, false, nil
+ }
+ updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
+ if setErr != nil {
+ return payload, false, setErr
+ }
+ return updated, true, nil
+}
+
+func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
+ if len(payload) == 0 {
+ return nil
+ }
+ cloned := make([]byte, len(payload))
+ copy(cloned, payload)
+ return cloned
+}
+
+func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
+ if items == nil {
+ return nil
+ }
+ cloned := make([]json.RawMessage, 0, len(items))
+ for idx := range items {
+ cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
+ }
+ return cloned
+}
+
+func cloneOpenAIWSJSONRawString(raw string) []byte {
+ if strings.TrimSpace(raw) == "" {
+ return nil
+ }
+ cloned := make([]byte, len(raw))
+ copy(cloned, raw)
+ return cloned
+}
+
+func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
+ trimmed := bytes.TrimSpace(raw)
+ if len(trimmed) == 0 {
+ return nil, errors.New("json is empty")
+ }
+ var decoded any
+ if err := json.Unmarshal(trimmed, &decoded); err != nil {
+ return nil, err
+ }
+ return json.Marshal(decoded)
+}
+
+func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
+ normalized, err := normalizeOpenAIWSJSONForCompare(raw)
+ if err != nil {
+ return bytes.TrimSpace(raw)
+ }
+ return normalized
+}
+
+func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
+ if len(payload) == 0 {
+ return nil, errors.New("payload is empty")
+ }
+ var decoded map[string]any
+ if err := json.Unmarshal(payload, &decoded); err != nil {
+ return nil, err
+ }
+ delete(decoded, "input")
+ delete(decoded, "previous_response_id")
+ return json.Marshal(decoded)
+}
+
+func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
+ if len(payload) == 0 {
+ return nil, false, nil
+ }
+ inputValue := gjson.GetBytes(payload, "input")
+ if !inputValue.Exists() {
+ return nil, false, nil
+ }
+ if inputValue.Type == gjson.JSON {
+ raw := strings.TrimSpace(inputValue.Raw)
+ if strings.HasPrefix(raw, "[") {
+ var items []json.RawMessage
+ if err := json.Unmarshal([]byte(raw), &items); err != nil {
+ return nil, true, err
+ }
+ return items, true, nil
+ }
+ return []json.RawMessage{json.RawMessage(raw)}, true, nil
+ }
+ if inputValue.Type == gjson.String {
+ encoded, _ := json.Marshal(inputValue.String())
+ return []json.RawMessage{encoded}, true, nil
+ }
+ return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
+}
+
+func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
+ previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
+ if prevErr != nil {
+ return false, prevErr
+ }
+ currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
+ if currentErr != nil {
+ return false, currentErr
+ }
+ if !previousExists && !currentExists {
+ return true, nil
+ }
+ if !previousExists {
+ return len(currentItems) == 0, nil
+ }
+ if !currentExists {
+ return len(previousItems) == 0, nil
+ }
+ if len(currentItems) < len(previousItems) {
+ return false, nil
+ }
+
+ for idx := range previousItems {
+ previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
+ currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
+ if !bytes.Equal(previousNormalized, currentNormalized) {
+ return false, nil
+ }
+ }
+ return true, nil
+}
+
+func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
+ if len(prefix) == 0 {
+ return true
+ }
+ if len(items) < len(prefix) {
+ return false
+ }
+ for idx := range prefix {
+ previousNormalized := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
+ currentNormalized := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
+ if !bytes.Equal(previousNormalized, currentNormalized) {
+ return false
+ }
+ }
+ return true
+}
+
+func limitOpenAIWSReplayInputSequenceByBytes(items []json.RawMessage, maxBytes int) []json.RawMessage {
+ if len(items) == 0 {
+ return nil
+ }
+ if maxBytes <= 0 {
+ return cloneOpenAIWSRawMessages(items)
+ }
+
+ start := len(items)
+ total := 2 // "[]"
+ for idx := len(items) - 1; idx >= 0; idx-- {
+ itemBytes := len(items[idx])
+ if start != len(items) {
+ itemBytes++ // comma
+ }
+ if total+itemBytes > maxBytes {
+ // Keep at least the newest item to avoid creating an empty replay input.
+ if start == len(items) {
+ start = idx
+ }
+ break
+ }
+ total += itemBytes
+ start = idx
+ }
+ if start < 0 || start > len(items) {
+ start = len(items) - 1
+ }
+ return cloneOpenAIWSRawMessages(items[start:])
+}
+
+func buildOpenAIWSReplayInputSequence(
+ previousFullInput []json.RawMessage,
+ previousFullInputExists bool,
+ currentPayload []byte,
+ hasPreviousResponseID bool,
+) ([]json.RawMessage, bool, error) {
+ currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
+ if currentErr != nil {
+ return nil, false, currentErr
+ }
+ candidate := []json.RawMessage(nil)
+ exists := false
+ if !hasPreviousResponseID {
+ candidate = cloneOpenAIWSRawMessages(currentItems)
+ exists = currentExists
+ if !exists {
+ return candidate, false, nil
+ }
+ return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil
+ }
+ if !previousFullInputExists {
+ candidate = cloneOpenAIWSRawMessages(currentItems)
+ exists = currentExists
+ if !exists {
+ return candidate, false, nil
+ }
+ return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), true, nil
+ }
+ if !currentExists || len(currentItems) == 0 {
+ candidate = cloneOpenAIWSRawMessages(previousFullInput)
+ exists = true
+ return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
+ }
+ if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
+ candidate = cloneOpenAIWSRawMessages(currentItems)
+ exists = true
+ return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
+ }
+ merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
+ merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
+ merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
+ candidate = merged
+ exists = true
+ return limitOpenAIWSReplayInputSequenceByBytes(candidate, openAIWSIngressReplayInputMaxBytes), exists, nil
+}
+
+func openAIWSInputAppearsEditedFromPreviousFullInput(
+ previousFullInput []json.RawMessage,
+ previousFullInputExists bool,
+ currentPayload []byte,
+ hasPreviousResponseID bool,
+) (bool, error) {
+ if !hasPreviousResponseID || !previousFullInputExists {
+ return false, nil
+ }
+ currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
+ if currentErr != nil {
+ return false, currentErr
+ }
+ if !currentExists || len(currentItems) == 0 {
+ return false, nil
+ }
+ if len(previousFullInput) < 2 {
+ // Single-item turns are ambiguous (could be a normal incremental replace), avoid false positives.
+ return false, nil
+ }
+ if len(currentItems) < len(previousFullInput) {
+ // Most delta appends only send the latest one/few items.
+ return false, nil
+ }
+ if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
+ // Full snapshot append or unchanged snapshot.
+ return false, nil
+ }
+ return true, nil
+}
+
+func setOpenAIWSPayloadInputSequence(
+ payload []byte,
+ fullInput []json.RawMessage,
+ fullInputExists bool,
+) ([]byte, error) {
+ if !fullInputExists {
+ return payload, nil
+ }
+ // Preserve [] vs null semantics when input exists but is empty.
+ inputForMarshal := fullInput
+ if inputForMarshal == nil {
+ inputForMarshal = []json.RawMessage{}
+ }
+ inputRaw, marshalErr := json.Marshal(inputForMarshal)
+ if marshalErr != nil {
+ return nil, marshalErr
+ }
+ return sjson.SetRawBytes(payload, "input", inputRaw)
+}
+
+func openAIWSNormalizeCallIDs(callIDs []string) []string {
+ if len(callIDs) == 0 {
+ return nil
+ }
+ seen := make(map[string]struct{}, len(callIDs))
+ normalized := make([]string, 0, len(callIDs))
+ for _, callID := range callIDs {
+ id := strings.TrimSpace(callID)
+ if id == "" {
+ continue
+ }
+ if _, exists := seen[id]; exists {
+ continue
+ }
+ seen[id] = struct{}{}
+ normalized = append(normalized, id)
+ }
+ sort.Strings(normalized)
+ return normalized
+}
+
+func openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload []byte) []string {
+ if len(payload) == 0 {
+ return nil
+ }
+ input := gjson.GetBytes(payload, "input")
+ if !input.Exists() {
+ return nil
+ }
+ callIDSet := make(map[string]struct{}, 4)
+ collect := func(item gjson.Result) {
+ if item.Type != gjson.JSON {
+ return
+ }
+ if strings.TrimSpace(item.Get("type").String()) != "function_call_output" {
+ return
+ }
+ callID := strings.TrimSpace(item.Get("call_id").String())
+ if callID == "" {
+ return
+ }
+ callIDSet[callID] = struct{}{}
+ }
+ if input.IsArray() {
+ input.ForEach(func(_, item gjson.Result) bool {
+ collect(item)
+ return true
+ })
+ } else {
+ collect(input)
+ }
+ if len(callIDSet) == 0 {
+ return nil
+ }
+ callIDs := make([]string, 0, len(callIDSet))
+ for callID := range callIDSet {
+ callIDs = append(callIDs, callID)
+ }
+ sort.Strings(callIDs)
+ return callIDs
+}
+
+func openAIWSFindMissingCallIDs(requiredCallIDs []string, actualCallIDs []string) []string {
+ required := openAIWSNormalizeCallIDs(requiredCallIDs)
+ if len(required) == 0 {
+ return nil
+ }
+ actualSet := make(map[string]struct{}, len(actualCallIDs))
+ for _, callID := range actualCallIDs {
+ id := strings.TrimSpace(callID)
+ if id == "" {
+ continue
+ }
+ actualSet[id] = struct{}{}
+ }
+ missing := make([]string, 0, len(required))
+ for _, callID := range required {
+ if _, ok := actualSet[callID]; ok {
+ continue
+ }
+ missing = append(missing, callID)
+ }
+ return missing
+}
+
+func openAIWSInjectFunctionCallOutputItems(payload []byte, callIDs []string, outputValue string) ([]byte, int, error) {
+ normalizedCallIDs := openAIWSNormalizeCallIDs(callIDs)
+ if len(normalizedCallIDs) == 0 {
+ return payload, 0, nil
+ }
+ inputItems, inputExists, inputErr := openAIWSExtractNormalizedInputSequence(payload)
+ if inputErr != nil {
+ return nil, 0, inputErr
+ }
+ if !inputExists {
+ inputItems = []json.RawMessage{}
+ }
+ updatedInput := make([]json.RawMessage, 0, len(inputItems)+len(normalizedCallIDs))
+ updatedInput = append(updatedInput, cloneOpenAIWSRawMessages(inputItems)...)
+ for _, callID := range normalizedCallIDs {
+ rawItem, marshalErr := json.Marshal(map[string]any{
+ "type": "function_call_output",
+ "call_id": callID,
+ "output": outputValue,
+ })
+ if marshalErr != nil {
+ return nil, 0, marshalErr
+ }
+ updatedInput = append(updatedInput, json.RawMessage(rawItem))
+ }
+ updatedPayload, setErr := setOpenAIWSPayloadInputSequence(payload, updatedInput, true)
+ if setErr != nil {
+ return nil, 0, setErr
+ }
+ return updatedPayload, len(normalizedCallIDs), nil
+}
+
+func shouldKeepIngressPreviousResponseID(
+ previousPayload []byte,
+ currentPayload []byte,
+ lastTurnResponseID string,
+ hasFunctionCallOutput bool,
+ expectedPendingCallIDs []string,
+ functionCallOutputCallIDs []string,
+) (bool, string, error) {
+ if hasFunctionCallOutput {
+ if len(expectedPendingCallIDs) == 0 {
+ return true, "has_function_call_output", nil
+ }
+ if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 {
+ return false, "function_call_output_call_id_mismatch", nil
+ }
+ return true, "function_call_output_call_id_match", nil
+ }
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
+ if currentPreviousResponseID == "" {
+ return false, "missing_previous_response_id", nil
+ }
+ expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
+ if expectedPreviousResponseID == "" {
+ return false, "missing_last_turn_response_id", nil
+ }
+ if currentPreviousResponseID != expectedPreviousResponseID {
+ return false, "previous_response_id_mismatch", nil
+ }
+ if len(previousPayload) == 0 {
+ return false, "missing_previous_turn_payload", nil
+ }
+
+ previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
+ if previousComparableErr != nil {
+ return false, "non_input_compare_error", previousComparableErr
+ }
+ currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
+ if currentComparableErr != nil {
+ return false, "non_input_compare_error", currentComparableErr
+ }
+ if !bytes.Equal(previousComparable, currentComparable) {
+ return false, "non_input_changed", nil
+ }
+ return true, "strict_incremental_ok", nil
+}
+
+type openAIWSIngressPreviousTurnStrictState struct {
+ nonInputComparable []byte
+}
+
+func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) {
+ if len(payload) == 0 {
+ return nil, nil
+ }
+ nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload)
+ if nonInputErr != nil {
+ return nil, nonInputErr
+ }
+ return &openAIWSIngressPreviousTurnStrictState{
+ nonInputComparable: nonInputComparable,
+ }, nil
+}
+
+func shouldKeepIngressPreviousResponseIDWithStrictState(
+ previousState *openAIWSIngressPreviousTurnStrictState,
+ currentPayload []byte,
+ lastTurnResponseID string,
+ hasFunctionCallOutput bool,
+ expectedPendingCallIDs []string,
+ functionCallOutputCallIDs []string,
+) (bool, string, error) {
+ if hasFunctionCallOutput {
+ if len(expectedPendingCallIDs) == 0 {
+ return true, "has_function_call_output", nil
+ }
+ if len(openAIWSFindMissingCallIDs(expectedPendingCallIDs, functionCallOutputCallIDs)) > 0 {
+ return false, "function_call_output_call_id_mismatch", nil
+ }
+ return true, "function_call_output_call_id_match", nil
+ }
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
+ if currentPreviousResponseID == "" {
+ return false, "missing_previous_response_id", nil
+ }
+ expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
+ if expectedPreviousResponseID == "" {
+ return false, "missing_last_turn_response_id", nil
+ }
+ if currentPreviousResponseID != expectedPreviousResponseID {
+ return false, "previous_response_id_mismatch", nil
+ }
+ if previousState == nil {
+ return false, "missing_previous_turn_payload", nil
+ }
+
+ currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
+ if currentComparableErr != nil {
+ return false, "non_input_compare_error", currentComparableErr
+ }
+ if !bytes.Equal(previousState.nonInputComparable, currentComparable) {
+ return false, "non_input_changed", nil
+ }
+ return true, "strict_incremental_ok", nil
+}
+
+func payloadAsJSON(payload map[string]any) string {
+ return string(payloadAsJSONBytes(payload))
+}
+
+func normalizeOpenAIWSPreferredConnID(connID string) (string, bool) {
+ trimmed := strings.TrimSpace(connID)
+ if trimmed == "" {
+ return "", false
+ }
+ if strings.HasPrefix(trimmed, openAIWSConnIDPrefixCtx) {
+ return trimmed, true
+ }
+ if strings.HasPrefix(trimmed, openAIWSConnIDPrefixLegacy) {
+ return trimmed, true
+ }
+ return "", false
+}
+
+func openAIWSPreferredConnIDFromResponse(stateStore OpenAIWSStateStore, responseID string) string {
+ if stateStore == nil {
+ return ""
+ }
+ normalizedResponseID := strings.TrimSpace(responseID)
+ if normalizedResponseID == "" {
+ return ""
+ }
+ connID, ok := stateStore.GetResponseConn(normalizedResponseID)
+ if !ok {
+ return ""
+ }
+ normalizedConnID, ok := normalizeOpenAIWSPreferredConnID(connID)
+ if !ok {
+ return ""
+ }
+ return normalizedConnID
+}
+
+func payloadAsJSONBytes(payload map[string]any) []byte {
+ if len(payload) == 0 {
+ return []byte("{}")
+ }
+ body, err := json.Marshal(payload)
+ if err != nil {
+ return []byte("{}")
+ }
+ return body
+}
+
+func isOpenAIWSTerminalEvent(eventType string) bool {
+ switch eventType {
+ case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled":
+ return true
+ default:
+ return false
+ }
+}
+
+func shouldPersistOpenAIWSLastResponseID(terminalEventType string) bool {
+ switch terminalEventType {
+ case "response.completed", "response.done":
+ return true
+ default:
+ return false
+ }
+}
+
+func isOpenAIWSTokenEvent(eventType string) bool {
+ if eventType == "" {
+ return false
+ }
+ switch eventType {
+ case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done":
+ return false
+ }
+ if strings.Contains(eventType, ".delta") {
+ return true
+ }
+ if strings.HasPrefix(eventType, "response.output_text") {
+ return true
+ }
+ if strings.HasPrefix(eventType, "response.output") {
+ return true
+ }
+ return eventType == "response.completed" || eventType == "response.done"
+}
+
+func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte {
+ if len(message) == 0 {
+ return message
+ }
+ if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel {
+ return message
+ }
+ if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) {
+ return message
+ }
+ modelValues := gjson.GetManyBytes(message, "model", "response.model")
+ replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel
+ replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel
+ if !replaceModel && !replaceResponseModel {
+ return message
+ }
+ updated := message
+ if replaceModel {
+ if next, err := sjson.SetBytes(updated, "model", toModel); err == nil {
+ updated = next
+ }
+ }
+ if replaceResponseModel {
+ if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil {
+ updated = next
+ }
+ }
+ return updated
+}
+
+func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) {
+ if usage == nil || len(body) == 0 {
+ return
+ }
+ values := gjson.GetManyBytes(
+ body,
+ "usage.input_tokens",
+ "usage.output_tokens",
+ "usage.input_tokens_details.cached_tokens",
+ )
+ usage.InputTokens = int(values[0].Int())
+ usage.OutputTokens = int(values[1].Int())
+ usage.CacheReadInputTokens = int(values[2].Int())
+}
+
+func getOpenAIGroupIDFromContext(c *gin.Context) int64 {
+ if c == nil {
+ return 0
+ }
+ value, exists := c.Get("api_key")
+ if !exists {
+ return 0
+ }
+ apiKey, ok := value.(*APIKey)
+ if !ok || apiKey == nil || apiKey.GroupID == nil {
+ return 0
+ }
+ return *apiKey.GroupID
+}
+
+func openAIWSIngressFallbackSessionSeedFromContext(c *gin.Context) string {
+ if c == nil {
+ return ""
+ }
+ value, exists := c.Get("api_key")
+ if !exists {
+ return ""
+ }
+ apiKey, ok := value.(*APIKey)
+ if !ok || apiKey == nil {
+ return ""
+ }
+ gid := int64(0)
+ if apiKey.GroupID != nil {
+ gid = *apiKey.GroupID
+ }
+ userID := int64(0)
+ if apiKey.User != nil {
+ userID = apiKey.User.ID
+ }
+ return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKey.ID)
+}
diff --git a/backend/internal/service/openai_ws_upstream_pump_test.go b/backend/internal/service/openai_ws_upstream_pump_test.go
new file mode 100644
index 000000000..26a7439dc
--- /dev/null
+++ b/backend/internal/service/openai_ws_upstream_pump_test.go
@@ -0,0 +1,1894 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "io"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// 辅助:构造测试用上游事件 JSON
+// ---------------------------------------------------------------------------
+
+func pumpTestEvent(eventType string) []byte {
+ m := map[string]any{"type": eventType}
+ b, _ := json.Marshal(m)
+ return b
+}
+
+func pumpTestEventWithResponseID(eventType, responseID string) []byte {
+ m := map[string]any{"type": eventType, "response": map[string]any{"id": responseID}}
+ b, _ := json.Marshal(m)
+ return b
+}
+
+// ---------------------------------------------------------------------------
+// 辅助:模拟上游连接(支持按序返回事件、延迟、错误注入)
+// ---------------------------------------------------------------------------
+
+type pumpTestConn struct {
+ mu sync.Mutex
+ events []pumpTestConnEvent
+ readCount int
+ closed bool
+ closedCh chan struct{}
+ ignoreCtx bool
+ pingErr error
+ writeErr error
+ writeCount int
+}
+
+type pumpTestConnEvent struct {
+ data []byte
+ err error
+ delay time.Duration
+}
+
+func newPumpTestConn(events ...pumpTestConnEvent) *pumpTestConn {
+ return &pumpTestConn{
+ events: events,
+ closedCh: make(chan struct{}),
+ }
+}
+
+func (c *pumpTestConn) WriteJSON(_ context.Context, _ any) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ c.writeCount++
+ return c.writeErr
+}
+
+func (c *pumpTestConn) ReadMessage(ctx context.Context) ([]byte, error) {
+ if ctx == nil {
+ ctx = context.Background()
+ }
+ c.mu.Lock()
+ if c.closed {
+ c.mu.Unlock()
+ return nil, errOpenAIWSConnClosed
+ }
+ if len(c.events) == 0 {
+ c.mu.Unlock()
+ if c.ignoreCtx {
+ <-c.closedCh
+ return nil, io.EOF
+ }
+ // 阻塞直到上下文取消,模拟上游无更多事件
+ <-ctx.Done()
+ return nil, ctx.Err()
+ }
+ evt := c.events[0]
+ c.events = c.events[1:]
+ c.readCount++
+ c.mu.Unlock()
+
+ if evt.delay > 0 {
+ timer := time.NewTimer(evt.delay)
+ defer timer.Stop()
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case <-timer.C:
+ }
+ }
+ return evt.data, evt.err
+}
+
+func (c *pumpTestConn) Ping(_ context.Context) error { return c.pingErr }
+
+func (c *pumpTestConn) Close() error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ if !c.closed {
+ c.closed = true
+ close(c.closedCh)
+ }
+ return nil
+}
+
+// ---------------------------------------------------------------------------
+// 辅助:模拟 lease 接口(仅泵测试所需的读写方法)
+// ---------------------------------------------------------------------------
+
+type pumpTestLease struct {
+ conn *pumpTestConn
+ broken atomic.Bool
+}
+
+func (l *pumpTestLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) {
+ readCtx, cancel := context.WithTimeout(ctx, timeout)
+ defer cancel()
+ return l.conn.ReadMessage(readCtx)
+}
+
+func (l *pumpTestLease) MarkBroken() {
+ l.broken.Store(true)
+ if l.conn != nil {
+ _ = l.conn.Close()
+ }
+}
+
+func (l *pumpTestLease) IsBroken() bool { return l.broken.Load() }
+
+// ---------------------------------------------------------------------------
+// 辅助:运行泵 goroutine 并收集所有产出的事件
+// ---------------------------------------------------------------------------
+
+// startPump 模拟 sendAndRelay 中的泵 goroutine,返回事件 channel 和取消函数。
+func startPump(ctx context.Context, lease *pumpTestLease, readTimeout time.Duration) (chan openAIWSUpstreamPumpEvent, context.CancelFunc) {
+ pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize)
+ pumpCtx, pumpCancel := context.WithCancel(ctx)
+ go func() {
+ defer close(pumpEventCh)
+ for {
+ msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, readTimeout)
+ select {
+ case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}:
+ case <-pumpCtx.Done():
+ return
+ }
+ if readErr != nil {
+ return
+ }
+ evtType, _ := parseOpenAIWSEventType(msg)
+ if isOpenAIWSTerminalEvent(evtType) || evtType == "error" {
+ return
+ }
+ }
+ }()
+ return pumpEventCh, pumpCancel
+}
+
+// collectAll 从 channel 读取所有事件直到关闭。
+func collectAll(ch chan openAIWSUpstreamPumpEvent) []openAIWSUpstreamPumpEvent {
+ var result []openAIWSUpstreamPumpEvent
+ for evt := range ch {
+ result = append(result, evt)
+ }
+ return result
+}
+
+// ---------------------------------------------------------------------------
+// 测试:openAIWSUpstreamPumpEvent 结构体
+// ---------------------------------------------------------------------------
+
+func TestOpenAIWSUpstreamPumpEvent_Fields(t *testing.T) {
+ t.Parallel()
+
+ t.Run("message_only", func(t *testing.T) {
+ evt := openAIWSUpstreamPumpEvent{message: []byte("hello")}
+ assert.Equal(t, []byte("hello"), evt.message)
+ assert.NoError(t, evt.err)
+ })
+
+ t.Run("error_only", func(t *testing.T) {
+ evt := openAIWSUpstreamPumpEvent{err: io.EOF}
+ assert.Nil(t, evt.message)
+ assert.ErrorIs(t, evt.err, io.EOF)
+ })
+
+ t.Run("both_fields", func(t *testing.T) {
+ evt := openAIWSUpstreamPumpEvent{message: []byte("partial"), err: io.ErrUnexpectedEOF}
+ assert.Equal(t, []byte("partial"), evt.message)
+ assert.ErrorIs(t, evt.err, io.ErrUnexpectedEOF)
+ })
+}
+
+func TestOpenAIWSUpstreamPumpBufferSize(t *testing.T) {
+ t.Parallel()
+ assert.Equal(t, 16, openAIWSUpstreamPumpBufferSize, "缓冲大小应为 16")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵 goroutine 正常事件流
+// ---------------------------------------------------------------------------
+
+func TestPump_NormalEventFlow(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+
+ require.Len(t, events, 4)
+ for _, evt := range events {
+ assert.NoError(t, evt.err)
+ assert.NotEmpty(t, evt.message)
+ }
+ // 验证最后一个是终端事件
+ lastType, _ := parseOpenAIWSEventType(events[3].message)
+ assert.True(t, isOpenAIWSTerminalEvent(lastType))
+}
+
+func TestPump_TerminalEventStopsPump(t *testing.T) {
+ t.Parallel()
+ terminalTypes := []string{
+ "response.completed",
+ "response.done",
+ "response.failed",
+ "response.incomplete",
+ "response.cancelled",
+ "response.canceled",
+ }
+ for _, tt := range terminalTypes {
+ tt := tt
+ t.Run(tt, func(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent(tt)},
+ // 以下事件不应该被读取
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 2, "终端事件 %s 后泵应停止", tt)
+ assert.NoError(t, events[0].err)
+ assert.NoError(t, events[1].err)
+ })
+ }
+}
+
+func TestPump_ErrorEventStopsPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("error")},
+ // 不应被读取
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 2, "error 事件后泵应停止")
+ evtType, _ := parseOpenAIWSEventType(events[1].message)
+ assert.Equal(t, "error", evtType)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵 goroutine 读取错误传播
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadErrorPropagated(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{err: io.ErrUnexpectedEOF},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 2)
+ assert.NoError(t, events[0].err)
+ assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF)
+}
+
+func TestPump_ReadErrorOnFirstEvent(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{err: errors.New("connection refused")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 1)
+ assert.Error(t, events[0].err)
+ assert.Contains(t, events[0].err.Error(), "connection refused")
+}
+
+func TestPump_EOFError(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{err: io.EOF},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 3)
+ assert.ErrorIs(t, events[2].err, io.EOF)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:上下文取消终止泵
+// ---------------------------------------------------------------------------
+
+func TestPump_ContextCancellationStopsPump(t *testing.T) {
+ t.Parallel()
+ // 连接永远阻塞在第二次读取
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 无更多事件,ReadMessage 将阻塞直到 ctx 取消
+ )
+ lease := &pumpTestLease{conn: conn}
+ ctx, ctxCancel := context.WithCancel(context.Background())
+ ch, pumpCancel := startPump(ctx, lease, 30*time.Second)
+ defer pumpCancel()
+
+ // 读取第一个事件
+ evt := <-ch
+ assert.NoError(t, evt.err)
+
+ // 取消上下文
+ ctxCancel()
+
+ // 泵应该退出,channel 应该关闭
+ events := collectAll(ch)
+ // 可能收到一个 context.Canceled 错误事件
+ for _, e := range events {
+ assert.Error(t, e.err)
+ }
+}
+
+func TestPump_PumpCancelStopsPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+ evt := <-ch
+ assert.NoError(t, evt.err)
+
+ // 调用 pumpCancel 应终止泵
+ pumpCancel()
+
+ // channel 应被关闭
+ events := collectAll(ch)
+ for _, e := range events {
+ assert.Error(t, e.err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 测试:缓冲行为
+// ---------------------------------------------------------------------------
+
+func TestPump_BufferAllowsConcurrentReadWrite(t *testing.T) {
+ t.Parallel()
+ // 生成超过缓冲大小的事件,验证不会死锁
+ numEvents := openAIWSUpstreamPumpBufferSize + 5
+ connEvents := make([]pumpTestConnEvent, 0, numEvents)
+ for i := 0; i < numEvents-1; i++ {
+ connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")})
+ }
+ connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")})
+
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, numEvents)
+ for _, evt := range events {
+ assert.NoError(t, evt.err)
+ }
+}
+
+func TestPump_SlowConsumerDoesNotBlock(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ // 模拟慢消费者
+ var events []openAIWSUpstreamPumpEvent
+ for evt := range ch {
+ events = append(events, evt)
+ time.Sleep(10 * time.Millisecond) // 慢消费
+ }
+ require.Len(t, events, 4)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水定时器机制
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainTimerCancelsPump(t *testing.T) {
+ t.Parallel()
+ // 模拟:客户端断连后,排水定时器到期取消泵
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 第二次读取会阻塞(模拟上游仍在生成但还没发出事件)
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+ // 读取第一个事件
+ evt := <-ch
+ assert.NoError(t, evt.err)
+
+ // 模拟排水定时器:50ms 后取消泵(正式代码中是 5 秒)
+ drainTimer := time.AfterFunc(50*time.Millisecond, pumpCancel)
+ defer drainTimer.Stop()
+
+ // 等待 channel 关闭
+ start := time.Now()
+ remaining := collectAll(ch)
+ elapsed := time.Since(start)
+
+ // 应在 50ms 附近退出,而非 30 秒
+ assert.Less(t, elapsed, 2*time.Second, "排水定时器应在约 50ms 后终止泵")
+
+ // 可能收到 context.Canceled 错误事件
+ for _, e := range remaining {
+ assert.Error(t, e.err)
+ }
+}
+
+func TestPump_DrainDeadlineCheckInMainLoop(t *testing.T) {
+ t.Parallel()
+ // 模拟主循环中的排水超时检查逻辑
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 加延迟模拟上游慢响应
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 80 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ clientDisconnected := false
+ drainDeadline := time.Time{}
+ var eventsBeforeDrain []openAIWSUpstreamPumpEvent
+ drainTriggered := false
+
+ for evt := range ch {
+ // 检查排水超时
+ if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+ pumpCancel()
+ drainTriggered = true
+ break
+ }
+ if evt.err != nil {
+ break
+ }
+ eventsBeforeDrain = append(eventsBeforeDrain, evt)
+
+ // 模拟:第一个事件后客户端断连,设置极短的排水截止时间
+ if !clientDisconnected && len(eventsBeforeDrain) == 1 {
+ clientDisconnected = true
+ drainDeadline = time.Now().Add(30 * time.Millisecond)
+ }
+ }
+
+ // 排水截止时间为 30ms,第二个事件延迟 80ms,所以应该触发排水超时
+ assert.True(t, drainTriggered, "排水超时应被触发")
+ assert.Len(t, eventsBeforeDrain, 1, "排水前应只有 1 个事件")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:与上游事件延迟的并发行为
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadDelayDoesNotBlockPreviousEvents(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 100 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ // 第一个事件应该立即可用
+ start := time.Now()
+ evt := <-ch
+ assert.NoError(t, evt.err)
+ assert.Less(t, time.Since(start), 50*time.Millisecond, "第一个事件应立即到达")
+
+ events := collectAll(ch)
+ require.Len(t, events, 2)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:空事件流
+// ---------------------------------------------------------------------------
+
+func TestPump_EmptyStreamContextCancel(t *testing.T) {
+ t.Parallel()
+ // 没有任何事件,连接阻塞,靠 context 取消
+ conn := newPumpTestConn() // 无事件
+ lease := &pumpTestLease{conn: conn}
+ ctx, ctxCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+ defer ctxCancel()
+ ch, pumpCancel := startPump(ctx, lease, 30*time.Second)
+ defer pumpCancel()
+
+ events := collectAll(ch)
+ // context 取消后,泵的 select 可能选择 pumpCtx.Done() 分支直接退出(0 个事件),
+ // 也可能先将错误事件发送到 channel 后退出(1 个事件),两种行为都正确。
+ assert.LessOrEqual(t, len(events), 1, "最多应收到 1 个事件")
+ for _, evt := range events {
+ assert.Error(t, evt.err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 测试:非终端/非错误事件不终止泵
+// ---------------------------------------------------------------------------
+
+func TestPump_NonTerminalEventsDoNotStopPump(t *testing.T) {
+ t.Parallel()
+ nonTerminalTypes := []string{
+ "response.created",
+ "response.in_progress",
+ "response.output_text.delta",
+ "response.content_part.added",
+ "response.output_item.added",
+ "response.reasoning_summary_text.delta",
+ }
+ connEvents := make([]pumpTestConnEvent, 0, len(nonTerminalTypes)+1)
+ for _, et := range nonTerminalTypes {
+ connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent(et)})
+ }
+ connEvents = append(connEvents, pumpTestConnEvent{data: pumpTestEvent("response.completed")})
+
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, len(nonTerminalTypes)+1, "所有非终端事件 + 终端事件都应被传递")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:多次 pumpCancel 调用安全(幂等)
+// ---------------------------------------------------------------------------
+
+func TestPump_MultipleCancelSafe(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ events := collectAll(ch)
+ require.Len(t, events, 1)
+
+ // 多次调用 pumpCancel 不应 panic
+ assert.NotPanics(t, func() {
+ pumpCancel()
+ pumpCancel()
+ pumpCancel()
+ })
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵与主循环集成——模拟完整的 relay 消费模式
+// ---------------------------------------------------------------------------
+
+func TestPump_IntegrationRelayPattern(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_abc123")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEventWithResponseID("response.completed", "resp_abc123")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ // 模拟主循环处理
+ var responseID string
+ eventCount := 0
+ tokenEventCount := 0
+ var terminalEventType string
+ clientWriteCount := 0
+
+ for evt := range ch {
+ if evt.err != nil {
+ t.Fatalf("unexpected error: %v", evt.err)
+ }
+ eventType, evtRespID := parseOpenAIWSEventType(evt.message)
+ if responseID == "" && evtRespID != "" {
+ responseID = evtRespID
+ }
+ eventCount++
+ if isOpenAIWSTokenEvent(eventType) {
+ tokenEventCount++
+ }
+ // 模拟写客户端
+ clientWriteCount++
+
+ if isOpenAIWSTerminalEvent(eventType) {
+ terminalEventType = eventType
+ break
+ }
+ }
+
+ assert.Equal(t, "resp_abc123", responseID)
+ assert.Equal(t, 5, eventCount)
+ assert.GreaterOrEqual(t, tokenEventCount, 3, "至少 3 个 delta 事件应被计为 token 事件")
+ assert.Equal(t, 5, clientWriteCount)
+ assert.Equal(t, "response.completed", terminalEventType)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵 goroutine 在 channel 满时 + context 取消的行为
+// ---------------------------------------------------------------------------
+
+func TestPump_ChannelFullThenCancel(t *testing.T) {
+ t.Parallel()
+ // 生成大量事件但不消费,验证 pumpCancel 仍然能终止泵
+ numEvents := openAIWSUpstreamPumpBufferSize * 3
+ connEvents := make([]pumpTestConnEvent, numEvents)
+ for i := range connEvents {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+ }
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ // 等待缓冲区被填满
+ time.Sleep(50 * time.Millisecond)
+
+ // 取消泵
+ pumpCancel()
+
+ // 清空 channel
+ events := collectAll(ch)
+ // 应收到 bufferSize 到 bufferSize+1 个事件(泵在 channel 满时可能阻塞在 select)
+ assert.LessOrEqual(t, len(events), numEvents, "不应收到超过总事件数的事件")
+ assert.GreaterOrEqual(t, len(events), 1, "至少应收到一些事件")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:读取超时机制
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadTimeoutTriggersError(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 第二次读取延迟超过超时
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 500 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ // 读取超时设为 50ms,远小于 500ms 延迟
+ ch, cancel := startPump(context.Background(), lease, 50*time.Millisecond)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 2)
+ assert.NoError(t, events[0].err)
+ assert.Error(t, events[1].err, "第二次读取应超时")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵在 response.done 事件后停止(另一种终端事件)
+// ---------------------------------------------------------------------------
+
+func TestPump_ResponseDoneStopsPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.done")},
+ pumpTestConnEvent{data: pumpTestEvent("should_not_reach")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 3)
+ lastType, _ := parseOpenAIWSEventType(events[2].message)
+ assert.Equal(t, "response.done", lastType)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵在读取到 error event 后不继续读取更多事件
+// ---------------------------------------------------------------------------
+
+func TestPump_ErrorEventStopsReading(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("error")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}, // 不应被读取
+ )
+ // 重写以追踪读取次数
+ origEvents := conn.events
+ conn.events = nil
+ var wrappedConn pumpTestConn
+ wrappedConn.closedCh = make(chan struct{})
+ wrappedConn.events = origEvents
+ wrappedLease := &pumpTestLease{conn: &wrappedConn}
+
+ ch, cancel := startPump(context.Background(), wrappedLease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 1, "error 事件后不应再读取更多事件")
+ evtType, _ := parseOpenAIWSEventType(events[0].message)
+ assert.Equal(t, "error", evtType)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:验证事件顺序保持不变
+// ---------------------------------------------------------------------------
+
+func TestPump_EventOrderPreserved(t *testing.T) {
+ t.Parallel()
+ expectedTypes := []string{
+ "response.created",
+ "response.in_progress",
+ "response.output_item.added",
+ "response.content_part.added",
+ "response.output_text.delta",
+ "response.output_text.delta",
+ "response.output_text.delta",
+ "response.output_text.done",
+ "response.content_part.done",
+ "response.output_item.done",
+ "response.completed",
+ }
+ connEvents := make([]pumpTestConnEvent, len(expectedTypes))
+ for i, et := range expectedTypes {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent(et)}
+ }
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, len(expectedTypes))
+ for i, evt := range events {
+ evtType, _ := parseOpenAIWSEventType(evt.message)
+ assert.Equal(t, expectedTypes[i], evtType, "事件 %d 类型不匹配", i)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 测试:无效 JSON 消息不影响泵运行
+// ---------------------------------------------------------------------------
+
+func TestPump_InvalidJSONDoesNotStopPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: []byte("not json")},
+ pumpTestConnEvent{data: []byte("{invalid")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 3, "无效 JSON 不应终止泵")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:并发安全——多个消费者不会 panic
+// ---------------------------------------------------------------------------
+
+func TestPump_ConcurrentConsumeAndCancel(t *testing.T) {
+ t.Parallel()
+ connEvents := make([]pumpTestConnEvent, 100)
+ for i := range connEvents {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: time.Millisecond}
+ }
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ // 同时消费和取消,不应 panic
+ var wg sync.WaitGroup
+ wg.Add(2)
+ go func() {
+ defer wg.Done()
+ for range ch {
+ // 消费
+ }
+ }()
+ go func() {
+ defer wg.Done()
+ time.Sleep(20 * time.Millisecond)
+ pumpCancel()
+ }()
+
+ done := make(chan struct{})
+ go func() {
+ wg.Wait()
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ // 成功
+ case <-time.After(5 * time.Second):
+ t.Fatal("超时:并发消费和取消场景死锁")
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水定时器与正常终端事件的竞争
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainTimerRaceWithTerminalEvent(t *testing.T) {
+ t.Parallel()
+ // 终端事件在排水定时器到期前到达
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ // 设置较长的排水定时器(200ms),终端事件应在 10ms 后到达
+ drainTimer := time.AfterFunc(200*time.Millisecond, pumpCancel)
+ defer drainTimer.Stop()
+
+ events := collectAll(ch)
+ // 终端事件应先到达
+ require.Len(t, events, 2)
+ lastType, _ := parseOpenAIWSEventType(events[1].message)
+ assert.Equal(t, "response.completed", lastType)
+ assert.NoError(t, events[1].err)
+
+ pumpCancel() // 清理
+}
+
+// ---------------------------------------------------------------------------
+// 测试:大量事件的吞吐量(确保泵不引入异常开销)
+// ---------------------------------------------------------------------------
+
+func TestPump_HighThroughput(t *testing.T) {
+ t.Parallel()
+ numEvents := 1000
+ connEvents := make([]pumpTestConnEvent, numEvents)
+ for i := range connEvents {
+ if i == numEvents-1 {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.completed")}
+ } else {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+ }
+ }
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ start := time.Now()
+ events := collectAll(ch)
+ elapsed := time.Since(start)
+
+ require.Len(t, events, numEvents)
+ assert.Less(t, elapsed, 2*time.Second, "1000 个事件应在 2 秒内完成")
+ for _, evt := range events {
+ assert.NoError(t, evt.err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 测试:空消息(零字节)不终止泵
+// ---------------------------------------------------------------------------
+
+func TestPump_EmptyMessageDoesNotStopPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: []byte{}},
+ pumpTestConnEvent{data: nil},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 3, "空消息不应终止泵")
+}
+
+// ===========================================================================
+// 以下为消息泵模式新增代码路径的补充测试
+// ===========================================================================
+
+// ---------------------------------------------------------------------------
+// 测试:泵 channel 关闭但无终端事件(上游异常断连)
+// ---------------------------------------------------------------------------
+
+func TestPump_UnexpectedCloseDetectedByConsumer(t *testing.T) {
+ t.Parallel()
+ // 模拟:上游只发了非终端事件就断连(ReadMessage 返回 EOF)
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{err: io.EOF}, // 上游断连
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ // 模拟主循环消费:检查是否收到了终端事件
+ receivedTerminal := false
+ var lastErr error
+ for evt := range ch {
+ if evt.err != nil {
+ lastErr = evt.err
+ break
+ }
+ evtType, _ := parseOpenAIWSEventType(evt.message)
+ if isOpenAIWSTerminalEvent(evtType) {
+ receivedTerminal = true
+ break
+ }
+ }
+ // 未收到终端事件,但收到了 EOF 错误——消费者应识别为上游异常断连
+ assert.False(t, receivedTerminal, "不应收到终端事件")
+ assert.ErrorIs(t, lastErr, io.EOF, "应收到 EOF 错误标识上游断连")
+}
+
+func TestPump_ChannelCloseWithoutTerminalOrError(t *testing.T) {
+ t.Parallel()
+ // 极端情况:泵被外部取消(pumpCancel),channel 关闭但既无终端事件也无错误事件。
+ // 模拟中间事件后泵被取消。
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ // 无更多事件,ReadMessage 将阻塞
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+ // 消费前两个事件
+ evt1 := <-ch
+ assert.NoError(t, evt1.err)
+ evt2 := <-ch
+ assert.NoError(t, evt2.err)
+
+ // 外部取消泵
+ pumpCancel()
+
+ // for-range 应退出,模拟 "泵 channel 关闭但未收到终端事件" 场景
+ receivedTerminal := false
+ for evt := range ch {
+ if evt.err == nil {
+ evtType, _ := parseOpenAIWSEventType(evt.message)
+ if isOpenAIWSTerminalEvent(evtType) {
+ receivedTerminal = true
+ }
+ }
+ }
+ assert.False(t, receivedTerminal, "泵被取消后不应再收到终端事件")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:lease.MarkBroken 场景验证
+// ---------------------------------------------------------------------------
+
+func TestPump_LeaseMarkedBrokenOnUnexpectedClose(t *testing.T) {
+ t.Parallel()
+ // 模拟主循环:泵关闭但无终端事件时应标记 lease 为 broken
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{err: io.ErrUnexpectedEOF},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ receivedTerminal := false
+ for evt := range ch {
+ if evt.err != nil {
+ // 模拟正式代码中的错误处理路径
+ lease.MarkBroken()
+ break
+ }
+ evtType, _ := parseOpenAIWSEventType(evt.message)
+ if isOpenAIWSTerminalEvent(evtType) {
+ receivedTerminal = true
+ break
+ }
+ }
+ // 如果 for-range 正常退出且未收到终端事件,也标记 broken
+ if !receivedTerminal {
+ lease.MarkBroken()
+ }
+
+ assert.True(t, lease.IsBroken(), "上游异常断连应标记 lease 为 broken")
+}
+
+func TestPump_LeaseNotBrokenOnNormalTerminal(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ for evt := range ch {
+ if evt.err != nil {
+ lease.MarkBroken()
+ break
+ }
+ }
+
+ assert.False(t, lease.IsBroken(), "正常终端事件不应标记 lease 为 broken")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水定时器只创建一次
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainTimerCreatedOnlyOnce(t *testing.T) {
+ t.Parallel()
+ // 模拟多次"客户端断连"信号,验证排水定时器只创建一次
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 10 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ drainTimerCount := 0
+ clientDisconnected := false
+ drainDeadline := time.Time{}
+ var drainTimer *time.Timer
+
+ for evt := range ch {
+ if evt.err != nil {
+ break
+ }
+ // 每个事件后都"检测到客户端断连"
+ if !clientDisconnected {
+ clientDisconnected = true
+ }
+ // 排水定时器只在第一次断连时创建
+ if clientDisconnected && drainDeadline.IsZero() {
+ drainDeadline = time.Now().Add(500 * time.Millisecond)
+ drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel)
+ drainTimerCount++
+ }
+ }
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+
+ assert.Equal(t, 1, drainTimerCount, "排水定时器应只创建一次")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水定时器在正常完成前被 Stop
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainTimerStoppedOnNormalCompletion(t *testing.T) {
+ t.Parallel()
+ // 终端事件在排水定时器到期前到达,验证定时器被正确停止
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ // 创建长时间排水定时器
+ drainTimer := time.AfterFunc(10*time.Second, pumpCancel)
+
+ var events []openAIWSUpstreamPumpEvent
+ for evt := range ch {
+ events = append(events, evt)
+ }
+
+ // 正常完成后停止排水定时器(模拟 defer drainTimer.Stop())
+ stopped := drainTimer.Stop()
+ pumpCancel() // 清理
+
+ assert.True(t, stopped, "定时器应尚未触发,Stop() 返回 true")
+ require.Len(t, events, 2)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水期间读取错误处理
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadErrorDuringDrainTreatedAsDrainTimeout(t *testing.T) {
+ t.Parallel()
+ // 新代码:客户端已断连时任何读取错误都按排水超时处理(不仅限 DeadlineExceeded)
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{err: io.ErrUnexpectedEOF, delay: 20 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ clientDisconnected := false
+ var drainError error
+
+ for evt := range ch {
+ if !clientDisconnected {
+ // 第一个事件后模拟客户端断连
+ clientDisconnected = true
+ continue
+ }
+ if evt.err != nil && clientDisconnected {
+ // 新代码路径:排水期间收到读取错误
+ drainError = evt.err
+ break
+ }
+ }
+
+ assert.Error(t, drainError, "排水期间应收到读取错误")
+ assert.ErrorIs(t, drainError, io.ErrUnexpectedEOF)
+}
+
+func TestPump_ReadErrorDuringDrain_EOF(t *testing.T) {
+ t.Parallel()
+ // EOF 在排水期间等同于上游关闭
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{err: io.EOF},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ clientDisconnected := false
+ drainErrorCount := 0
+
+ for evt := range ch {
+ if evt.err != nil {
+ if clientDisconnected {
+ drainErrorCount++
+ }
+ break
+ }
+ // 第一个事件后模拟客户端断连
+ if !clientDisconnected {
+ clientDisconnected = true
+ }
+ }
+
+ assert.Equal(t, 1, drainErrorCount, "排水期间 EOF 应被计为一次排水错误")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:排水截止时间检查——在事件间隙中过期
+// ---------------------------------------------------------------------------
+
+func TestPump_DrainDeadlineExpiresBetweenEvents(t *testing.T) {
+ t.Parallel()
+ // 排水截止时间在两个上游事件之间到期
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 60 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 60 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ clientDisconnected := false
+ drainDeadline := time.Time{}
+ drainExpired := false
+ eventsProcessed := 0
+
+ for evt := range ch {
+ // 排水超时检查(在处理事件前,模拟正式代码)
+ if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+ pumpCancel()
+ drainExpired = true
+ break
+ }
+ if evt.err != nil {
+ break
+ }
+ eventsProcessed++
+
+ // 第一个事件后断连,排水截止时间设为 30ms
+ if !clientDisconnected {
+ clientDisconnected = true
+ drainDeadline = time.Now().Add(30 * time.Millisecond)
+ }
+ }
+
+ // 第二个事件延迟 60ms > 排水截止 30ms,应触发排水超时
+ assert.True(t, drainExpired, "排水截止时间应在事件间隙中过期")
+ assert.Equal(t, 1, eventsProcessed, "过期前应只处理了 1 个事件")
+}
+
+func TestPump_DrainDeadlineNotYetExpiredAllowsProcessing(t *testing.T) {
+ t.Parallel()
+ // 排水截止时间足够长,允许处理所有事件
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+ defer pumpCancel()
+
+ clientDisconnected := false
+ drainDeadline := time.Time{}
+ drainExpired := false
+ eventsProcessed := 0
+
+ for evt := range ch {
+ if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+ pumpCancel()
+ drainExpired = true
+ break
+ }
+ if evt.err != nil {
+ break
+ }
+ eventsProcessed++
+ if !clientDisconnected {
+ clientDisconnected = true
+ drainDeadline = time.Now().Add(500 * time.Millisecond) // 足够长
+ }
+ }
+
+ assert.False(t, drainExpired, "排水截止时间未过期,不应触发排水超时")
+ assert.Equal(t, 3, eventsProcessed, "所有事件都应被处理")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:goroutine 清理和资源释放
+// ---------------------------------------------------------------------------
+
+func TestPump_DeferPumpCancelAndDrainTimerCleanup(t *testing.T) {
+ t.Parallel()
+ // 模拟正式代码的完整 defer 清理路径
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ pumpEventCh := make(chan openAIWSUpstreamPumpEvent, openAIWSUpstreamPumpBufferSize)
+ pumpCtx, pumpCancel := context.WithCancel(context.Background())
+ // 模拟 defer pumpCancel()
+ defer pumpCancel()
+
+ go func() {
+ defer close(pumpEventCh)
+ for {
+ msg, readErr := lease.ReadMessageWithContextTimeout(pumpCtx, 5*time.Second)
+ select {
+ case pumpEventCh <- openAIWSUpstreamPumpEvent{message: msg, err: readErr}:
+ case <-pumpCtx.Done():
+ return
+ }
+ if readErr != nil {
+ return
+ }
+ evtType, _ := parseOpenAIWSEventType(msg)
+ if isOpenAIWSTerminalEvent(evtType) || evtType == "error" {
+ return
+ }
+ }
+ }()
+
+ // 模拟排水定时器
+ var drainTimer *time.Timer
+ defer func() {
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+ }()
+ drainTimer = time.AfterFunc(10*time.Second, pumpCancel)
+
+ events := collectAll(pumpEventCh)
+ require.Len(t, events, 2)
+
+ // defer 清理后不应 panic
+ assert.NotPanics(t, func() {
+ pumpCancel()
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+ })
+}
+
+// ---------------------------------------------------------------------------
+// 测试:连接在泵运行期间被关闭
+// ---------------------------------------------------------------------------
+
+func TestPump_ConnectionClosedDuringPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 后续事件阻塞
+ )
+ lease := &pumpTestLease{conn: conn}
+ // 使用较短的读超时,因为 conn.Close() 不会解除阻塞的 ReadMessage(它等待 <-ctx.Done())
+ ch, pumpCancel := startPump(context.Background(), lease, 100*time.Millisecond)
+ defer pumpCancel()
+
+ // 读取第一个事件
+ evt := <-ch
+ assert.NoError(t, evt.err)
+
+ // 关闭连接——注意:ReadMessage 仍在等待 ctx.Done(),
+ // 但读超时为 100ms 会触发 context.DeadlineExceeded。
+ // 下次 ReadMessage 调用时会检测到 closed 状态。
+ _ = conn.Close()
+
+ // 泵应在读超时后检测到连接关闭
+ events := collectAll(ch)
+ require.GreaterOrEqual(t, len(events), 1, "应收到错误")
+ // 至少有一个事件包含错误
+ hasError := false
+ for _, e := range events {
+ if e.err != nil {
+ hasError = true
+ }
+ }
+ assert.True(t, hasError, "应收到连接关闭或超时错误")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:大消息(KB 级别 JSON)不影响泵传递
+// ---------------------------------------------------------------------------
+
+func TestPump_LargeMessages(t *testing.T) {
+ t.Parallel()
+ // 构造 ~10KB 的消息
+ largeContent := make([]byte, 10*1024)
+ for i := range largeContent {
+ largeContent[i] = 'x'
+ }
+ largeMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(largeContent) + `"}`)
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: largeMsg},
+ pumpTestConnEvent{data: largeMsg},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 4)
+ // 验证大消息完整传递
+ assert.Len(t, events[1].message, len(largeMsg))
+ assert.Len(t, events[2].message, len(largeMsg))
+}
+
+// ---------------------------------------------------------------------------
+// 测试:多轮泵会话(同一 lease 上依次创建多个泵)
+// ---------------------------------------------------------------------------
+
+func TestPump_SequentialSessions(t *testing.T) {
+ t.Parallel()
+ // 第一轮
+ conn1 := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease1 := &pumpTestLease{conn: conn1}
+ ch1, cancel1 := startPump(context.Background(), lease1, 5*time.Second)
+ events1 := collectAll(ch1)
+ cancel1()
+ require.Len(t, events1, 2)
+
+ // 第二轮(新连接、新泵)
+ conn2 := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease2 := &pumpTestLease{conn: conn2}
+ ch2, cancel2 := startPump(context.Background(), lease2, 5*time.Second)
+ events2 := collectAll(ch2)
+ cancel2()
+ require.Len(t, events2, 3)
+
+ // 两轮之间互不影响
+ assert.False(t, lease1.IsBroken())
+ assert.False(t, lease2.IsBroken())
+}
+
+// ---------------------------------------------------------------------------
+// 测试:完整 relay 模式集成——包含客户端断连和排水
+// ---------------------------------------------------------------------------
+
+func TestPump_IntegrationRelayWithClientDisconnectAndDrain(t *testing.T) {
+ t.Parallel()
+ // 模拟完整场景:上游慢速响应,客户端在中途断连,排水定时器到期终止
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain1")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+ // 上游后续事件延迟大于排水超时
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 200 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 200 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ clientDisconnected := false
+ drainDeadline := time.Time{}
+ var drainTimer *time.Timer
+ defer func() {
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+ }()
+
+ eventsProcessed := 0
+ drainTriggered := false
+
+ for evt := range ch {
+ // 排水超时检查
+ if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+ pumpCancel()
+ drainTriggered = true
+ lease.MarkBroken()
+ break
+ }
+ if evt.err != nil {
+ if clientDisconnected {
+ // 排水期间读取错误(pumpCancel 导致 context.Canceled)
+ drainTriggered = true
+ lease.MarkBroken()
+ }
+ break
+ }
+ eventsProcessed++
+
+ // 模拟:第 2 个事件后客户端断连
+ if eventsProcessed == 2 && !clientDisconnected {
+ clientDisconnected = true
+ drainDeadline = time.Now().Add(50 * time.Millisecond) // 50ms 排水超时
+ drainTimer = time.AfterFunc(50*time.Millisecond, pumpCancel)
+ }
+ }
+ // for-range 退出后,如果 channel 因 pumpCancel 关闭且排水截止已过期,
+ // 也视为排水超时触发(泵的 select 可能选择 pumpCtx.Done() 而不发送错误事件)。
+ if !drainTriggered && clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+ drainTriggered = true
+ lease.MarkBroken()
+ }
+
+ pumpCancel() // 最终清理
+
+ // 排水超时应触发(50ms 排水 vs 200ms 后续事件延迟)
+ assert.True(t, drainTriggered, "排水超时应被触发")
+ assert.GreaterOrEqual(t, eventsProcessed, 2, "至少应处理 2 个事件")
+ assert.LessOrEqual(t, eventsProcessed, 4, "不应处理所有 5 个事件")
+}
+
+func TestPump_IntegrationRelayWithSuccessfulDrain(t *testing.T) {
+ t.Parallel()
+ // 客户端断连后上游快速完成,在排水超时前正常结束
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEventWithResponseID("response.created", "resp_drain2")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 5 * time.Millisecond},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed"), delay: 5 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ clientDisconnected := false
+ drainDeadline := time.Time{}
+ var drainTimer *time.Timer
+ defer func() {
+ if drainTimer != nil {
+ drainTimer.Stop()
+ }
+ }()
+
+ eventsProcessed := 0
+ receivedTerminal := false
+ drainTriggered := false
+
+ for evt := range ch {
+ if clientDisconnected && !drainDeadline.IsZero() && time.Now().After(drainDeadline) {
+ pumpCancel()
+ drainTriggered = true
+ break
+ }
+ if evt.err != nil {
+ break
+ }
+ eventsProcessed++
+
+ evtType, _ := parseOpenAIWSEventType(evt.message)
+ if isOpenAIWSTerminalEvent(evtType) {
+ receivedTerminal = true
+ break
+ }
+
+ // 第一个事件后客户端断连
+ if eventsProcessed == 1 && !clientDisconnected {
+ clientDisconnected = true
+ drainDeadline = time.Now().Add(500 * time.Millisecond) // 足够长的排水超时
+ drainTimer = time.AfterFunc(500*time.Millisecond, pumpCancel)
+ }
+ }
+
+ pumpCancel()
+
+ assert.False(t, drainTriggered, "排水超时不应触发")
+ assert.True(t, receivedTerminal, "应正常收到终端事件")
+ assert.Equal(t, 4, eventsProcessed, "所有 4 个事件都应被处理")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵事件错误携带部分消息数据
+// ---------------------------------------------------------------------------
+
+func TestPump_ErrorEventWithPartialMessage(t *testing.T) {
+ t.Parallel()
+ // 模拟上游返回部分数据和错误
+ partialData := []byte(`{"type":"response.output_text`)
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: partialData, err: io.ErrUnexpectedEOF},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 2)
+ // 第二个事件同时携带 message 和 error
+ assert.Equal(t, partialData, events[1].message)
+ assert.ErrorIs(t, events[1].err, io.ErrUnexpectedEOF)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:零超时读取
+// ---------------------------------------------------------------------------
+
+func TestPump_ZeroReadTimeout(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 第二次读取需要时间,但超时为 0
+ pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta"), delay: 10 * time.Millisecond},
+ )
+ lease := &pumpTestLease{conn: conn}
+ // 使用极短超时(1 纳秒 ~ 立即超时)
+ ch, cancel := startPump(context.Background(), lease, time.Nanosecond)
+ defer cancel()
+
+ events := collectAll(ch)
+ // 至少第一个事件成功读取(无延迟),第二个大概率超时
+ require.GreaterOrEqual(t, len(events), 1)
+ // 查找是否有超时错误
+ hasTimeout := false
+ for _, evt := range events {
+ if evt.err != nil && errors.Is(evt.err, context.DeadlineExceeded) {
+ hasTimeout = true
+ }
+ }
+ assert.True(t, hasTimeout, "极短超时应产生 DeadlineExceeded 错误")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:并发多个排水定时器取消(防止重复调用 pumpCancel)
+// ---------------------------------------------------------------------------
+
+func TestPump_ConcurrentDrainTimerAndExternalCancel(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ // 阻塞
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+
+ // 读取第一个事件
+ evt := <-ch
+ assert.NoError(t, evt.err)
+
+ // 同时设置排水定时器和外部取消
+ done := make(chan struct{})
+ drainTimer := time.AfterFunc(30*time.Millisecond, pumpCancel)
+ defer drainTimer.Stop()
+
+ go func() {
+ time.Sleep(20 * time.Millisecond)
+ pumpCancel() // 外部取消稍早于定时器
+ close(done)
+ }()
+
+ // 不应死锁或 panic
+ events := collectAll(ch)
+ <-done
+
+ // 验证泵已终止
+ for _, e := range events {
+ if e.err != nil {
+ assert.Error(t, e.err)
+ }
+ }
+}
+
+func TestPump_DrainTimerMarkBrokenUnblocksIgnoreContextRead(t *testing.T) {
+ t.Parallel()
+
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ )
+ conn.ignoreCtx = true
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 30*time.Second)
+ defer pumpCancel()
+
+ first := <-ch
+ require.NoError(t, first.err)
+ require.NotEmpty(t, first.message)
+
+ done := make(chan struct{})
+ drainTimer := time.AfterFunc(30*time.Millisecond, func() {
+ lease.MarkBroken()
+ pumpCancel()
+ })
+ defer drainTimer.Stop()
+
+ go func() {
+ _ = collectAll(ch)
+ close(done)
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("pump should stop quickly when drain timer marks lease broken")
+ }
+
+ assert.True(t, lease.IsBroken(), "lease should be marked broken by drain timer")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:快速连续事件(突发模式)
+// ---------------------------------------------------------------------------
+
+func TestPump_BurstEvents(t *testing.T) {
+ t.Parallel()
+ // 50 个事件无延迟突发
+ numBurst := 50
+ connEvents := make([]pumpTestConnEvent, numBurst+1)
+ for i := 0; i < numBurst; i++ {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+ }
+ connEvents[numBurst] = pumpTestConnEvent{data: pumpTestEvent("response.completed")}
+
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, numBurst+1, "突发事件应全部被传递")
+
+ // 验证所有事件无错误
+ for i, evt := range events {
+ assert.NoError(t, evt.err, "事件 %d 不应有错误", i)
+ }
+ lastType, _ := parseOpenAIWSEventType(events[numBurst].message)
+ assert.True(t, isOpenAIWSTerminalEvent(lastType))
+}
+
+// ---------------------------------------------------------------------------
+// 测试:事件类型解析边界情况
+// ---------------------------------------------------------------------------
+
+func TestPump_EventTypeParsingEdgeCases(t *testing.T) {
+ t.Parallel()
+ // 各种边缘 JSON 格式
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: []byte(`{"type": " response.created "}`)}, // 带空格
+ pumpTestConnEvent{data: []byte(`{"type":"response.output_text.delta"}`)}, // 无空格
+ pumpTestConnEvent{data: []byte(`{"type":"","other":"field"}`)}, // 空类型
+ pumpTestConnEvent{data: []byte(`{"no_type_field": true}`)}, // 无 type 字段
+ pumpTestConnEvent{data: []byte(`{"type":"response.completed"}`)}, // 终端
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 5, "所有格式的事件都应被传递")
+ for _, evt := range events {
+ assert.NoError(t, evt.err)
+ }
+}
+
+// ---------------------------------------------------------------------------
+// 测试:function_call_output 等非标准事件类型不终止泵
+// ---------------------------------------------------------------------------
+
+func TestPump_FunctionCallOutputDoesNotStopPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: pumpTestEvent("response.function_call_arguments.delta")},
+ pumpTestConnEvent{data: pumpTestEvent("response.function_call_arguments.done")},
+ pumpTestConnEvent{data: pumpTestEvent("response.output_item.done")},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 5, "function_call 相关事件不应终止泵")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:pumpCancel 在 channel 已关闭后调用不 panic
+// ---------------------------------------------------------------------------
+
+func TestPump_CancelAfterChannelClosed(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, pumpCancel := startPump(context.Background(), lease, 5*time.Second)
+
+ // 等待 channel 关闭
+ events := collectAll(ch)
+ require.Len(t, events, 1)
+
+ // channel 已关闭后再取消不应 panic
+ assert.NotPanics(t, func() {
+ pumpCancel()
+ })
+
+ // 再次从已关闭 channel 读取应返回零值
+ evt, ok := <-ch
+ assert.False(t, ok, "channel 应已关闭")
+ assert.Nil(t, evt.message)
+ assert.NoError(t, evt.err)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:混合事件大小(小消息和大消息交替)
+// ---------------------------------------------------------------------------
+
+func TestPump_MixedMessageSizes(t *testing.T) {
+ t.Parallel()
+ smallMsg := pumpTestEvent("response.output_text.delta")
+ largeContent := make([]byte, 64*1024) // 64KB
+ for i := range largeContent {
+ largeContent[i] = 'A'
+ }
+ largeMsg := []byte(`{"type":"response.output_text.delta","delta":"` + string(largeContent) + `"}`)
+
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{data: smallMsg},
+ pumpTestConnEvent{data: largeMsg},
+ pumpTestConnEvent{data: smallMsg},
+ pumpTestConnEvent{data: largeMsg},
+ pumpTestConnEvent{data: pumpTestEvent("response.completed")},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 6)
+ assert.Len(t, events[2].message, len(largeMsg), "大消息应完整传递")
+ assert.Len(t, events[4].message, len(largeMsg), "大消息应完整传递")
+}
+
+// ---------------------------------------------------------------------------
+// 测试:泵在 errOpenAIWSConnClosed 错误后停止
+// ---------------------------------------------------------------------------
+
+func TestPump_ConnClosedErrorStopsPump(t *testing.T) {
+ t.Parallel()
+ conn := newPumpTestConn(
+ pumpTestConnEvent{data: pumpTestEvent("response.created")},
+ pumpTestConnEvent{err: errOpenAIWSConnClosed},
+ )
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ events := collectAll(ch)
+ require.Len(t, events, 2)
+ assert.ErrorIs(t, events[1].err, errOpenAIWSConnClosed)
+}
+
+// ---------------------------------------------------------------------------
+// 测试:同时读取和写入——验证读写解耦
+// ---------------------------------------------------------------------------
+
+func TestPump_ReadWriteDecoupling(t *testing.T) {
+ t.Parallel()
+ // 模拟:上游事件到达时,客户端写入有延迟(通过 channel 消费延迟模拟)
+ numEvents := 10
+ connEvents := make([]pumpTestConnEvent, numEvents)
+ for i := 0; i < numEvents-1; i++ {
+ connEvents[i] = pumpTestConnEvent{data: pumpTestEvent("response.output_text.delta")}
+ }
+ connEvents[numEvents-1] = pumpTestConnEvent{data: pumpTestEvent("response.completed")}
+
+ conn := newPumpTestConn(connEvents...)
+ lease := &pumpTestLease{conn: conn}
+ ch, cancel := startPump(context.Background(), lease, 5*time.Second)
+ defer cancel()
+
+ // 模拟慢写入:每个事件处理需要 5ms
+ start := time.Now()
+ var events []openAIWSUpstreamPumpEvent
+ for evt := range ch {
+ events = append(events, evt)
+ time.Sleep(5 * time.Millisecond) // 模拟写入延迟
+ }
+ elapsed := time.Since(start)
+
+ require.Len(t, events, numEvents)
+ // 如果没有并发(串行读写),总时间 >= numEvents * 5ms = 50ms
+ // 有缓冲并发时,上游读取可以提前完成,总时间 < 串行预估
+ // 此处验证所有事件都被传递即可
+ t.Logf("处理 %d 个事件耗时: %v (慢消费模式)", numEvents, elapsed)
+}
diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go
index b22da7522..590793c11 100644
--- a/backend/internal/service/redeem_service.go
+++ b/backend/internal/service/redeem_service.go
@@ -15,11 +15,12 @@ import (
)
var (
- ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
- ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
- ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
- ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
- ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
+ ErrRedeemCodeNotFound = infraerrors.NotFound("REDEEM_CODE_NOT_FOUND", "redeem code not found")
+ ErrRedeemCodeUsed = infraerrors.Conflict("REDEEM_CODE_USED", "redeem code already used")
+ ErrInsufficientBalance = infraerrors.BadRequest("INSUFFICIENT_BALANCE", "insufficient balance")
+ ErrRedeemRateLimited = infraerrors.TooManyRequests("REDEEM_RATE_LIMITED", "too many failed attempts, please try again later")
+ ErrRedeemCodeLocked = infraerrors.Conflict("REDEEM_CODE_LOCKED", "redeem code is being processed, please try again")
+ ErrBalanceCacheNotFound = errors.New("balance cache key not found")
)
const (
diff --git a/backend/internal/service/setting_bulk_edit_template.go b/backend/internal/service/setting_bulk_edit_template.go
new file mode 100644
index 000000000..dd28e1633
--- /dev/null
+++ b/backend/internal/service/setting_bulk_edit_template.go
@@ -0,0 +1,770 @@
+package service
+
+import (
+ "context"
+ "crypto/rand"
+ "encoding/hex"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "sort"
+ "strings"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+)
+
+const (
+ BulkEditTemplateShareScopePrivate = "private"
+ BulkEditTemplateShareScopeTeam = "team"
+ BulkEditTemplateShareScopeGroups = "groups"
+)
+
+var (
+ ErrBulkEditTemplateNotFound = infraerrors.NotFound("BULK_EDIT_TEMPLATE_NOT_FOUND", "bulk edit template not found")
+ ErrBulkEditTemplateVersionNotFound = infraerrors.NotFound(
+ "BULK_EDIT_TEMPLATE_VERSION_NOT_FOUND",
+ "bulk edit template version not found",
+ )
+ ErrBulkEditTemplateForbidden = infraerrors.Forbidden(
+ "BULK_EDIT_TEMPLATE_FORBIDDEN",
+ "no permission to modify this bulk edit template",
+ )
+ bulkEditTemplateRandRead = rand.Read
+)
+
+type BulkEditTemplate struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ ScopePlatform string `json:"scope_platform"`
+ ScopeType string `json:"scope_type"`
+ ShareScope string `json:"share_scope"`
+ GroupIDs []int64 `json:"group_ids"`
+ State map[string]any `json:"state"`
+ CreatedBy int64 `json:"created_by"`
+ UpdatedBy int64 `json:"updated_by"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+type BulkEditTemplateQuery struct {
+ ScopePlatform string
+ ScopeType string
+ ScopeGroupIDs []int64
+ RequesterUserID int64
+}
+
+type BulkEditTemplateVersion struct {
+ VersionID string `json:"version_id"`
+ ShareScope string `json:"share_scope"`
+ GroupIDs []int64 `json:"group_ids"`
+ State map[string]any `json:"state"`
+ UpdatedBy int64 `json:"updated_by"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+type BulkEditTemplateVersionQuery struct {
+ TemplateID string
+ ScopeGroupIDs []int64
+ RequesterUserID int64
+}
+
+type BulkEditTemplateUpsertInput struct {
+ ID string
+ Name string
+ ScopePlatform string
+ ScopeType string
+ ShareScope string
+ GroupIDs []int64
+ State map[string]any
+ RequesterUserID int64
+}
+
+type BulkEditTemplateRollbackInput struct {
+ TemplateID string
+ VersionID string
+ ScopeGroupIDs []int64
+ RequesterUserID int64
+}
+
+type bulkEditTemplateLibraryStore struct {
+ Items []bulkEditTemplateStoreItem `json:"items"`
+}
+
+type bulkEditTemplateVersionStoreItem struct {
+ VersionID string `json:"version_id"`
+ ShareScope string `json:"share_scope"`
+ GroupIDs []int64 `json:"group_ids"`
+ State json.RawMessage `json:"state"`
+ UpdatedBy int64 `json:"updated_by"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+type bulkEditTemplateStoreItem struct {
+ ID string `json:"id"`
+ Name string `json:"name"`
+ ScopePlatform string `json:"scope_platform"`
+ ScopeType string `json:"scope_type"`
+ ShareScope string `json:"share_scope"`
+ GroupIDs []int64 `json:"group_ids"`
+ State json.RawMessage `json:"state"`
+ Versions []bulkEditTemplateVersionStoreItem `json:"versions"`
+ CreatedBy int64 `json:"created_by"`
+ UpdatedBy int64 `json:"updated_by"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
+
+func (s *SettingService) ListBulkEditTemplates(ctx context.Context, query BulkEditTemplateQuery) ([]BulkEditTemplate, error) {
+ store, err := s.loadBulkEditTemplateLibrary(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ scopePlatform := strings.TrimSpace(strings.ToLower(query.ScopePlatform))
+ scopeType := strings.TrimSpace(strings.ToLower(query.ScopeType))
+ scopeGroupIDs := normalizeBulkEditTemplateGroupIDs(query.ScopeGroupIDs)
+ scopeGroupSet := make(map[int64]struct{}, len(scopeGroupIDs))
+ for _, groupID := range scopeGroupIDs {
+ scopeGroupSet[groupID] = struct{}{}
+ }
+
+ out := make([]BulkEditTemplate, 0, len(store.Items))
+ for idx := range store.Items {
+ item := store.Items[idx]
+ if scopePlatform != "" && item.ScopePlatform != scopePlatform {
+ continue
+ }
+ if scopeType != "" && item.ScopeType != scopeType {
+ continue
+ }
+ if !isBulkEditTemplateVisible(item, query.RequesterUserID, scopeGroupSet) {
+ continue
+ }
+ out = append(out, toBulkEditTemplate(item))
+ }
+
+ sort.Slice(out, func(i, j int) bool {
+ if out[i].UpdatedAt == out[j].UpdatedAt {
+ return out[i].ID < out[j].ID
+ }
+ return out[i].UpdatedAt > out[j].UpdatedAt
+ })
+
+ return out, nil
+}
+
+func (s *SettingService) UpsertBulkEditTemplate(ctx context.Context, input BulkEditTemplateUpsertInput) (*BulkEditTemplate, error) {
+ name := strings.TrimSpace(input.Name)
+ if name == "" {
+ return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template name is required")
+ }
+ if input.RequesterUserID <= 0 {
+ return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+ }
+
+ scopePlatform := strings.TrimSpace(strings.ToLower(input.ScopePlatform))
+ scopeType := strings.TrimSpace(strings.ToLower(input.ScopeType))
+ if scopePlatform == "" || scopeType == "" {
+ return nil, infraerrors.BadRequest(
+ "BULK_EDIT_TEMPLATE_INVALID_INPUT",
+ "scope_platform and scope_type are required",
+ )
+ }
+
+ shareScope, shareScopeErr := validateBulkEditTemplateShareScope(input.ShareScope)
+ if shareScopeErr != nil {
+ return nil, shareScopeErr
+ }
+
+ groupIDs := normalizeBulkEditTemplateGroupIDs(input.GroupIDs)
+ if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 {
+ return nil, infraerrors.BadRequest(
+ "BULK_EDIT_TEMPLATE_INVALID_INPUT",
+ "group_ids is required when share_scope=groups",
+ )
+ }
+
+ stateRaw, err := json.Marshal(input.State)
+ if err != nil {
+ return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid template state")
+ }
+ if len(stateRaw) == 0 || string(stateRaw) == "null" {
+ stateRaw = json.RawMessage("{}")
+ }
+
+ store, err := s.loadBulkEditTemplateLibrary(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ templateID := strings.TrimSpace(input.ID)
+ matchIndex := -1
+ if templateID != "" {
+ for idx := range store.Items {
+ if store.Items[idx].ID == templateID {
+ matchIndex = idx
+ break
+ }
+ }
+ if matchIndex < 0 {
+ return nil, ErrBulkEditTemplateNotFound
+ }
+ }
+
+ if matchIndex < 0 && templateID == "" {
+ for idx := range store.Items {
+ item := store.Items[idx]
+ if item.ScopePlatform != scopePlatform || item.ScopeType != scopeType {
+ continue
+ }
+ if !strings.EqualFold(strings.TrimSpace(item.Name), name) {
+ continue
+ }
+ if !canModifyBulkEditTemplate(item, input.RequesterUserID) {
+ continue
+ }
+ matchIndex = idx
+ break
+ }
+ }
+
+ nowMS := time.Now().UnixMilli()
+ if matchIndex >= 0 {
+ item := store.Items[matchIndex]
+ if !canModifyBulkEditTemplate(item, input.RequesterUserID) {
+ return nil, ErrBulkEditTemplateForbidden
+ }
+
+ previousVersion := snapshotBulkEditTemplateVersion(item)
+ item.Versions = append(item.Versions, previousVersion)
+ item.Name = name
+ item.ScopePlatform = scopePlatform
+ item.ScopeType = scopeType
+ item.ShareScope = shareScope
+ item.GroupIDs = groupIDs
+ item.State = cloneBulkEditTemplateStateRaw(stateRaw)
+ if item.CreatedBy <= 0 {
+ item.CreatedBy = input.RequesterUserID
+ }
+ if item.CreatedAt <= 0 {
+ item.CreatedAt = nowMS
+ }
+ item.UpdatedBy = input.RequesterUserID
+ item.UpdatedAt = nowMS
+ store.Items[matchIndex] = item
+
+ if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil {
+ return nil, err
+ }
+ output := toBulkEditTemplate(item)
+ return &output, nil
+ }
+
+ if templateID == "" {
+ templateID = generateBulkEditTemplateID()
+ }
+
+ created := bulkEditTemplateStoreItem{
+ ID: templateID,
+ Name: name,
+ ScopePlatform: scopePlatform,
+ ScopeType: scopeType,
+ ShareScope: shareScope,
+ GroupIDs: groupIDs,
+ State: cloneBulkEditTemplateStateRaw(stateRaw),
+ Versions: []bulkEditTemplateVersionStoreItem{},
+ CreatedBy: input.RequesterUserID,
+ UpdatedBy: input.RequesterUserID,
+ CreatedAt: nowMS,
+ UpdatedAt: nowMS,
+ }
+ store.Items = append(store.Items, created)
+
+ if err := s.persistBulkEditTemplateLibrary(ctx, store); err != nil {
+ return nil, err
+ }
+
+ output := toBulkEditTemplate(created)
+ return &output, nil
+}
+
+func (s *SettingService) DeleteBulkEditTemplate(ctx context.Context, templateID string, requesterUserID int64) error {
+ id := strings.TrimSpace(templateID)
+ if id == "" {
+ return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required")
+ }
+ if requesterUserID <= 0 {
+ return infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+ }
+
+ store, err := s.loadBulkEditTemplateLibrary(ctx)
+ if err != nil {
+ return err
+ }
+
+ idx := -1
+ for index := range store.Items {
+ if store.Items[index].ID == id {
+ idx = index
+ break
+ }
+ }
+ if idx < 0 {
+ return ErrBulkEditTemplateNotFound
+ }
+
+ target := store.Items[idx]
+ if target.ShareScope == BulkEditTemplateShareScopePrivate && target.CreatedBy > 0 && target.CreatedBy != requesterUserID {
+ return ErrBulkEditTemplateForbidden
+ }
+
+ store.Items = append(store.Items[:idx], store.Items[idx+1:]...)
+ return s.persistBulkEditTemplateLibrary(ctx, store)
+}
+
+func (s *SettingService) ListBulkEditTemplateVersions(
+ ctx context.Context,
+ query BulkEditTemplateVersionQuery,
+) ([]BulkEditTemplateVersion, error) {
+ templateID := strings.TrimSpace(query.TemplateID)
+ if templateID == "" {
+ return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template id is required")
+ }
+ if query.RequesterUserID <= 0 {
+ return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+ }
+
+ store, err := s.loadBulkEditTemplateLibrary(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ scopeGroupSet := toBulkEditTemplateScopeGroupSet(query.ScopeGroupIDs)
+ target := findBulkEditTemplateStoreItemByID(store.Items, templateID)
+ if target == nil {
+ return nil, ErrBulkEditTemplateNotFound
+ }
+ if !isBulkEditTemplateVisible(*target, query.RequesterUserID, scopeGroupSet) {
+ return nil, ErrBulkEditTemplateForbidden
+ }
+
+ versions := make([]BulkEditTemplateVersion, 0, len(target.Versions))
+ for idx := range target.Versions {
+ versions = append(versions, toBulkEditTemplateVersion(target.Versions[idx]))
+ }
+
+ sort.Slice(versions, func(i, j int) bool {
+ if versions[i].UpdatedAt == versions[j].UpdatedAt {
+ return versions[i].VersionID < versions[j].VersionID
+ }
+ return versions[i].UpdatedAt > versions[j].UpdatedAt
+ })
+
+ return versions, nil
+}
+
+func (s *SettingService) RollbackBulkEditTemplate(
+ ctx context.Context,
+ input BulkEditTemplateRollbackInput,
+) (*BulkEditTemplate, error) {
+ templateID := strings.TrimSpace(input.TemplateID)
+ versionID := strings.TrimSpace(input.VersionID)
+ if templateID == "" || versionID == "" {
+ return nil, infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template_id and version_id are required")
+ }
+ if input.RequesterUserID <= 0 {
+ return nil, infraerrors.Unauthorized("UNAUTHORIZED", "unauthorized")
+ }
+
+ store, err := s.loadBulkEditTemplateLibrary(ctx)
+ if err != nil {
+ return nil, err
+ }
+
+ scopeGroupSet := toBulkEditTemplateScopeGroupSet(input.ScopeGroupIDs)
+ templateIndex := findBulkEditTemplateStoreItemIndexByID(store.Items, templateID)
+ if templateIndex < 0 {
+ return nil, ErrBulkEditTemplateNotFound
+ }
+
+ item := store.Items[templateIndex]
+ if !isBulkEditTemplateVisible(item, input.RequesterUserID, scopeGroupSet) {
+ return nil, ErrBulkEditTemplateForbidden
+ }
+
+ versionIndex := findBulkEditTemplateVersionIndexByID(item.Versions, versionID)
+ if versionIndex < 0 {
+ return nil, ErrBulkEditTemplateVersionNotFound
+ }
+
+ targetVersion := item.Versions[versionIndex]
+ previousVersion := snapshotBulkEditTemplateVersion(item)
+ item.Versions = append(item.Versions, previousVersion)
+ item.ShareScope = targetVersion.ShareScope
+ item.GroupIDs = append([]int64(nil), targetVersion.GroupIDs...)
+ item.State = cloneBulkEditTemplateStateRaw(targetVersion.State)
+ item.UpdatedBy = input.RequesterUserID
+ item.UpdatedAt = time.Now().UnixMilli()
+
+ store.Items[templateIndex] = item
+ if persistErr := s.persistBulkEditTemplateLibrary(ctx, store); persistErr != nil {
+ return nil, persistErr
+ }
+
+ output := toBulkEditTemplate(item)
+ return &output, nil
+}
+
+func (s *SettingService) loadBulkEditTemplateLibrary(ctx context.Context) (*bulkEditTemplateLibraryStore, error) {
+ raw, err := s.settingRepo.GetValue(ctx, SettingKeyBulkEditTemplateLibrary)
+ if err != nil {
+ if errors.Is(err, ErrSettingNotFound) {
+ return &bulkEditTemplateLibraryStore{}, nil
+ }
+ return nil, fmt.Errorf("get bulk edit template library: %w", err)
+ }
+
+ raw = strings.TrimSpace(raw)
+ if raw == "" {
+ return &bulkEditTemplateLibraryStore{}, nil
+ }
+
+ store := bulkEditTemplateLibraryStore{}
+ if err := json.Unmarshal([]byte(raw), &store); err != nil {
+ return nil, fmt.Errorf("parse bulk edit template library: %w", err)
+ }
+
+ normalized := normalizeBulkEditTemplateLibraryStore(store)
+ return &normalized, nil
+}
+
+func (s *SettingService) persistBulkEditTemplateLibrary(ctx context.Context, store *bulkEditTemplateLibraryStore) error {
+ if store == nil {
+ return infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "template library cannot be nil")
+ }
+
+ normalized := normalizeBulkEditTemplateLibraryStore(*store)
+ data, err := json.Marshal(normalized)
+ if err != nil {
+ return fmt.Errorf("marshal bulk edit template library: %w", err)
+ }
+
+ return s.settingRepo.Set(ctx, SettingKeyBulkEditTemplateLibrary, string(data))
+}
+
+func validateBulkEditTemplateShareScope(scope string) (string, error) {
+ normalized := strings.TrimSpace(strings.ToLower(scope))
+ if normalized == "" {
+ return BulkEditTemplateShareScopePrivate, nil
+ }
+ switch normalized {
+ case BulkEditTemplateShareScopePrivate,
+ BulkEditTemplateShareScopeTeam,
+ BulkEditTemplateShareScopeGroups:
+ return normalized, nil
+ default:
+ return "", infraerrors.BadRequest("BULK_EDIT_TEMPLATE_INVALID_INPUT", "invalid share_scope")
+ }
+}
+
+func normalizeBulkEditTemplateLibraryStore(store bulkEditTemplateLibraryStore) bulkEditTemplateLibraryStore {
+ if len(store.Items) == 0 {
+ return bulkEditTemplateLibraryStore{Items: []bulkEditTemplateStoreItem{}}
+ }
+
+ nowMS := time.Now().UnixMilli()
+ items := make([]bulkEditTemplateStoreItem, 0, len(store.Items))
+ seenID := make(map[string]struct{}, len(store.Items))
+
+ for _, raw := range store.Items {
+ name := strings.TrimSpace(raw.Name)
+ scopePlatform := strings.TrimSpace(strings.ToLower(raw.ScopePlatform))
+ scopeType := strings.TrimSpace(strings.ToLower(raw.ScopeType))
+ if name == "" || scopePlatform == "" || scopeType == "" {
+ continue
+ }
+
+ shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope)
+ groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs)
+ if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 {
+ shareScope = BulkEditTemplateShareScopePrivate
+ }
+
+ templateID := strings.TrimSpace(raw.ID)
+ if templateID == "" {
+ templateID = generateBulkEditTemplateID()
+ }
+ if _, exists := seenID[templateID]; exists {
+ continue
+ }
+ seenID[templateID] = struct{}{}
+
+ state := raw.State
+ if len(state) == 0 || string(state) == "null" {
+ state = json.RawMessage("{}")
+ }
+
+ createdAt := raw.CreatedAt
+ if createdAt <= 0 {
+ createdAt = nowMS
+ }
+ updatedAt := raw.UpdatedAt
+ if updatedAt <= 0 {
+ updatedAt = createdAt
+ }
+
+ items = append(items, bulkEditTemplateStoreItem{
+ ID: templateID,
+ Name: name,
+ ScopePlatform: scopePlatform,
+ ScopeType: scopeType,
+ ShareScope: shareScope,
+ GroupIDs: groupIDs,
+ State: cloneBulkEditTemplateStateRaw(state),
+ Versions: normalizeBulkEditTemplateVersionStoreItems(raw.Versions),
+ CreatedBy: raw.CreatedBy,
+ UpdatedBy: raw.UpdatedBy,
+ CreatedAt: createdAt,
+ UpdatedAt: updatedAt,
+ })
+ }
+
+ return bulkEditTemplateLibraryStore{Items: items}
+}
+
+func toBulkEditTemplate(item bulkEditTemplateStoreItem) BulkEditTemplate {
+ state := map[string]any{}
+ if err := json.Unmarshal(item.State, &state); err != nil || state == nil {
+ state = map[string]any{}
+ }
+
+ return BulkEditTemplate{
+ ID: item.ID,
+ Name: item.Name,
+ ScopePlatform: item.ScopePlatform,
+ ScopeType: item.ScopeType,
+ ShareScope: item.ShareScope,
+ GroupIDs: append([]int64(nil), item.GroupIDs...),
+ State: state,
+ CreatedBy: item.CreatedBy,
+ UpdatedBy: item.UpdatedBy,
+ CreatedAt: item.CreatedAt,
+ UpdatedAt: item.UpdatedAt,
+ }
+}
+
+func toBulkEditTemplateVersion(item bulkEditTemplateVersionStoreItem) BulkEditTemplateVersion {
+ state := map[string]any{}
+ if err := json.Unmarshal(item.State, &state); err != nil || state == nil {
+ state = map[string]any{}
+ }
+
+ return BulkEditTemplateVersion{
+ VersionID: item.VersionID,
+ ShareScope: item.ShareScope,
+ GroupIDs: append([]int64(nil), item.GroupIDs...),
+ State: state,
+ UpdatedBy: item.UpdatedBy,
+ UpdatedAt: item.UpdatedAt,
+ }
+}
+
+func normalizeBulkEditTemplateVersionStoreItems(
+ rawVersions []bulkEditTemplateVersionStoreItem,
+) []bulkEditTemplateVersionStoreItem {
+ if len(rawVersions) == 0 {
+ return []bulkEditTemplateVersionStoreItem{}
+ }
+
+ nowMS := time.Now().UnixMilli()
+ seen := make(map[string]struct{}, len(rawVersions))
+ out := make([]bulkEditTemplateVersionStoreItem, 0, len(rawVersions))
+ for _, raw := range rawVersions {
+ versionID := strings.TrimSpace(raw.VersionID)
+ if versionID == "" {
+ versionID = generateBulkEditTemplateVersionID()
+ }
+ if _, exists := seen[versionID]; exists {
+ continue
+ }
+ seen[versionID] = struct{}{}
+
+ shareScope := normalizeBulkEditTemplateShareScopeOrDefault(raw.ShareScope)
+ groupIDs := normalizeBulkEditTemplateGroupIDs(raw.GroupIDs)
+ if shareScope == BulkEditTemplateShareScopeGroups && len(groupIDs) == 0 {
+ shareScope = BulkEditTemplateShareScopePrivate
+ }
+
+ updatedAt := raw.UpdatedAt
+ if updatedAt <= 0 {
+ updatedAt = nowMS
+ }
+
+ out = append(out, bulkEditTemplateVersionStoreItem{
+ VersionID: versionID,
+ ShareScope: shareScope,
+ GroupIDs: groupIDs,
+ State: cloneBulkEditTemplateStateRaw(raw.State),
+ UpdatedBy: raw.UpdatedBy,
+ UpdatedAt: updatedAt,
+ })
+ }
+
+ sort.Slice(out, func(i, j int) bool {
+ if out[i].UpdatedAt == out[j].UpdatedAt {
+ return out[i].VersionID < out[j].VersionID
+ }
+ return out[i].UpdatedAt > out[j].UpdatedAt
+ })
+ return out
+}
+
+func snapshotBulkEditTemplateVersion(item bulkEditTemplateStoreItem) bulkEditTemplateVersionStoreItem {
+ updatedAt := item.UpdatedAt
+ if updatedAt <= 0 {
+ updatedAt = time.Now().UnixMilli()
+ }
+ return bulkEditTemplateVersionStoreItem{
+ VersionID: generateBulkEditTemplateVersionID(),
+ ShareScope: normalizeBulkEditTemplateShareScopeOrDefault(item.ShareScope),
+ GroupIDs: normalizeBulkEditTemplateGroupIDs(item.GroupIDs),
+ State: cloneBulkEditTemplateStateRaw(item.State),
+ UpdatedBy: item.UpdatedBy,
+ UpdatedAt: updatedAt,
+ }
+}
+
+func cloneBulkEditTemplateStateRaw(raw json.RawMessage) json.RawMessage {
+ if len(raw) == 0 || string(raw) == "null" {
+ return json.RawMessage("{}")
+ }
+ cloned := make(json.RawMessage, len(raw))
+ copy(cloned, raw)
+ return cloned
+}
+
+func toBulkEditTemplateScopeGroupSet(raw []int64) map[int64]struct{} {
+ groupIDs := normalizeBulkEditTemplateGroupIDs(raw)
+ scopeGroupSet := make(map[int64]struct{}, len(groupIDs))
+ for _, groupID := range groupIDs {
+ scopeGroupSet[groupID] = struct{}{}
+ }
+ return scopeGroupSet
+}
+
+func findBulkEditTemplateStoreItemByID(
+ items []bulkEditTemplateStoreItem,
+ templateID string,
+) *bulkEditTemplateStoreItem {
+ for idx := range items {
+ if items[idx].ID == templateID {
+ return &items[idx]
+ }
+ }
+ return nil
+}
+
+func findBulkEditTemplateStoreItemIndexByID(items []bulkEditTemplateStoreItem, templateID string) int {
+ for idx := range items {
+ if items[idx].ID == templateID {
+ return idx
+ }
+ }
+ return -1
+}
+
+func findBulkEditTemplateVersionIndexByID(
+ versions []bulkEditTemplateVersionStoreItem,
+ versionID string,
+) int {
+ for idx := range versions {
+ if versions[idx].VersionID == versionID {
+ return idx
+ }
+ }
+ return -1
+}
+
+func canModifyBulkEditTemplate(item bulkEditTemplateStoreItem, requesterUserID int64) bool {
+ if requesterUserID <= 0 {
+ return false
+ }
+ if item.ShareScope != BulkEditTemplateShareScopePrivate {
+ return true
+ }
+ if item.CreatedBy <= 0 {
+ return true
+ }
+ return item.CreatedBy == requesterUserID
+}
+
+func isBulkEditTemplateVisible(
+ item bulkEditTemplateStoreItem,
+ requesterUserID int64,
+ scopeGroupSet map[int64]struct{},
+) bool {
+ switch item.ShareScope {
+ case BulkEditTemplateShareScopeTeam:
+ return true
+ case BulkEditTemplateShareScopeGroups:
+ if len(scopeGroupSet) == 0 || len(item.GroupIDs) == 0 {
+ return false
+ }
+ for _, groupID := range item.GroupIDs {
+ if _, ok := scopeGroupSet[groupID]; ok {
+ return true
+ }
+ }
+ return false
+ default:
+ return requesterUserID > 0 && item.CreatedBy == requesterUserID
+ }
+}
+
+func normalizeBulkEditTemplateShareScopeOrDefault(scope string) string {
+ normalized, err := validateBulkEditTemplateShareScope(scope)
+ if err != nil {
+ return BulkEditTemplateShareScopePrivate
+ }
+ return normalized
+}
+
+func normalizeBulkEditTemplateGroupIDs(raw []int64) []int64 {
+ if len(raw) == 0 {
+ return []int64{}
+ }
+
+ seen := make(map[int64]struct{}, len(raw))
+ groupIDs := make([]int64, 0, len(raw))
+ for _, groupID := range raw {
+ if groupID <= 0 {
+ continue
+ }
+ if _, exists := seen[groupID]; exists {
+ continue
+ }
+ seen[groupID] = struct{}{}
+ groupIDs = append(groupIDs, groupID)
+ }
+ sort.Slice(groupIDs, func(i, j int) bool {
+ return groupIDs[i] < groupIDs[j]
+ })
+ return groupIDs
+}
+
+func generateBulkEditTemplateID() string {
+ buf := make([]byte, 12)
+ if _, err := bulkEditTemplateRandRead(buf); err == nil {
+ return "btpl-" + hex.EncodeToString(buf)
+ }
+ return fmt.Sprintf("btpl-%d", time.Now().UnixNano())
+}
+
+func generateBulkEditTemplateVersionID() string {
+ buf := make([]byte, 12)
+ if _, err := bulkEditTemplateRandRead(buf); err == nil {
+ return "btplv-" + hex.EncodeToString(buf)
+ }
+ return fmt.Sprintf("btplv-%d", time.Now().UnixNano())
+}
diff --git a/backend/internal/service/setting_bulk_edit_template_test.go b/backend/internal/service/setting_bulk_edit_template_test.go
new file mode 100644
index 000000000..c3a0351fa
--- /dev/null
+++ b/backend/internal/service/setting_bulk_edit_template_test.go
@@ -0,0 +1,860 @@
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "testing"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+type bulkTemplateSettingRepoStub struct {
+ values map[string]string
+}
+
+func newBulkTemplateSettingRepoStub() *bulkTemplateSettingRepoStub {
+ return &bulkTemplateSettingRepoStub{values: map[string]string{}}
+}
+
+func (s *bulkTemplateSettingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ value, err := s.GetValue(ctx, key)
+ if err != nil {
+ return nil, err
+ }
+ return &Setting{Key: key, Value: value}, nil
+}
+
+func (s *bulkTemplateSettingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ value, ok := s.values[key]
+ if !ok {
+ return "", ErrSettingNotFound
+ }
+ return value, nil
+}
+
+func (s *bulkTemplateSettingRepoStub) Set(ctx context.Context, key, value string) error {
+ s.values[key] = value
+ return nil
+}
+
+func (s *bulkTemplateSettingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ out := make(map[string]string, len(keys))
+ for _, key := range keys {
+ if value, ok := s.values[key]; ok {
+ out[key] = value
+ }
+ }
+ return out, nil
+}
+
+func (s *bulkTemplateSettingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ for key, value := range settings {
+ s.values[key] = value
+ }
+ return nil
+}
+
+func (s *bulkTemplateSettingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ out := make(map[string]string, len(s.values))
+ for key, value := range s.values {
+ out[key] = value
+ }
+ return out, nil
+}
+
+func (s *bulkTemplateSettingRepoStub) Delete(ctx context.Context, key string) error {
+ delete(s.values, key)
+ return nil
+}
+
+type bulkTemplateFailingRepoStub struct{}
+
+func (s *bulkTemplateFailingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ return nil, errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ return "", errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) Set(ctx context.Context, key, value string) error {
+ return errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ return nil, errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ return errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ return nil, errors.New("boom")
+}
+func (s *bulkTemplateFailingRepoStub) Delete(ctx context.Context, key string) error {
+ return errors.New("boom")
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertAndPrivateVisibility(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "OpenAI OAuth Baseline",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"enableOpenAIWSMode": true},
+ RequesterUserID: 11,
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, created.ID)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, created.ShareScope)
+
+ listByOwner, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 11,
+ })
+ require.NoError(t, err)
+ require.Len(t, listByOwner, 1)
+ require.Equal(t, created.ID, listByOwner[0].ID)
+ require.Equal(t, true, listByOwner[0].State["enableOpenAIWSMode"])
+
+ listByOther, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 22,
+ })
+ require.NoError(t, err)
+ require.Len(t, listByOther, 0)
+}
+
+func TestSettingServiceBulkEditTemplate_GroupsVisibilityByScopeGroupIDs(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Shared By Group",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{10, 20},
+ State: map[string]any{"enableBaseUrl": true},
+ RequesterUserID: 9,
+ })
+ require.NoError(t, err)
+
+ invisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ScopeGroupIDs: []int64{99},
+ RequesterUserID: 8,
+ })
+ require.NoError(t, err)
+ require.Len(t, invisible, 0)
+
+ visible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ScopeGroupIDs: []int64{20, 100},
+ RequesterUserID: 8,
+ })
+ require.NoError(t, err)
+ require.Len(t, visible, 1)
+ require.Equal(t, BulkEditTemplateShareScopeGroups, visible[0].ShareScope)
+ require.Equal(t, []int64{10, 20}, visible[0].GroupIDs)
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertByNameReplacesSameScopeRecord(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ first, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Team Baseline",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 7,
+ })
+ require.NoError(t, err)
+
+ second, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "team baseline",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 9},
+ RequesterUserID: 7,
+ })
+ require.NoError(t, err)
+ require.Equal(t, first.ID, second.ID)
+
+ items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 3,
+ })
+ require.NoError(t, err)
+ require.Len(t, items, 1)
+ require.EqualValues(t, 9, items[0].State["priority"])
+}
+
+func TestSettingServiceBulkEditTemplate_DeletePermissionAndNotFound(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Private Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"status": "active"},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 2)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err))
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), created.ID, 1)
+ require.NoError(t, err)
+
+ ownerList, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.Len(t, ownerList, 0)
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), "missing-id", 1)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+}
+
+func TestSettingServiceBulkEditTemplate_ValidatesInput(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Groups No IDs",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Invalid Scope",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: "invalid",
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+}
+
+func TestSettingServiceBulkEditTemplate_CoversInternalHelpers(t *testing.T) {
+ store := normalizeBulkEditTemplateLibraryStore(bulkEditTemplateLibraryStore{
+ Items: []bulkEditTemplateStoreItem{
+ {
+ ID: "same-id",
+ Name: " One ",
+ ScopePlatform: "OPENAI",
+ ScopeType: "OAUTH",
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{},
+ State: nil,
+ },
+ {
+ ID: "same-id",
+ Name: "Duplicate ID",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ },
+ {
+ ID: "",
+ Name: "Two",
+ ScopePlatform: "openai",
+ ScopeType: "apikey",
+ ShareScope: "invalid",
+ GroupIDs: []int64{5, 5, 1},
+ State: []byte(`{"ok":true}`),
+ },
+ {
+ ID: "invalid-entry",
+ Name: "",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ },
+ },
+ })
+ require.Len(t, store.Items, 2)
+ require.Equal(t, "same-id", store.Items[0].ID)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, store.Items[0].ShareScope)
+ require.Equal(t, []int64{1, 5}, store.Items[1].GroupIDs)
+ require.NotEmpty(t, store.Items[1].ID)
+
+ require.Equal(t, BulkEditTemplateShareScopeTeam, normalizeBulkEditTemplateShareScopeOrDefault("team"))
+ require.Equal(t, BulkEditTemplateShareScopePrivate, normalizeBulkEditTemplateShareScopeOrDefault("bad"))
+ require.Equal(t, []int64{}, normalizeBulkEditTemplateGroupIDs(nil))
+
+ scope, err := validateBulkEditTemplateShareScope("")
+ require.NoError(t, err)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, scope)
+ _, err = validateBulkEditTemplateShareScope("bad")
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ require.True(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeTeam},
+ 1,
+ map[int64]struct{}{},
+ ))
+ require.False(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}},
+ 1,
+ map[int64]struct{}{},
+ ))
+ require.True(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopeGroups, GroupIDs: []int64{2}},
+ 1,
+ map[int64]struct{}{2: {}},
+ ))
+ require.True(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9},
+ 9,
+ nil,
+ ))
+ require.False(t, isBulkEditTemplateVisible(
+ bulkEditTemplateStoreItem{ShareScope: BulkEditTemplateShareScopePrivate, CreatedBy: 9},
+ 1,
+ nil,
+ ))
+
+ converted := toBulkEditTemplate(bulkEditTemplateStoreItem{
+ ID: "id-1",
+ Name: "Demo",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ GroupIDs: []int64{3},
+ State: []byte(`invalid-json`),
+ })
+ require.Equal(t, map[string]any{}, converted.State)
+}
+
+func TestSettingServiceBulkEditTemplate_LoadPersistBranches(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ err := svc.persistBulkEditTemplateLibrary(context.Background(), nil)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json"
+ store, err := svc.loadBulkEditTemplateLibrary(context.Background())
+ require.Error(t, err)
+ require.Nil(t, store)
+
+ delete(repo.values, SettingKeyBulkEditTemplateLibrary)
+ store, err = svc.loadBulkEditTemplateLibrary(context.Background())
+ require.NoError(t, err)
+ require.NotNil(t, store)
+ require.Empty(t, store.Items)
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertByMismatchedID(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.NotEmpty(t, created.ID)
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: "another-id",
+ Name: "Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+
+ err = svc.DeleteBulkEditTemplate(context.Background(), "", 1)
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+}
+
+func TestSettingServiceBulkEditTemplate_PrivateTemplateIsolationAcrossUsers(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ ownerTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Private Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 101,
+ })
+ require.NoError(t, err)
+
+ otherTemplate, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Private Scoped Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 9},
+ RequesterUserID: 202,
+ })
+ require.NoError(t, err)
+ require.NotEqual(t, ownerTemplate.ID, otherTemplate.ID, "不同用户的私有同名模板不应互相覆盖")
+
+ ownerVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 101,
+ })
+ require.NoError(t, err)
+ require.Len(t, ownerVisible, 1)
+ require.Equal(t, ownerTemplate.ID, ownerVisible[0].ID)
+ require.EqualValues(t, 1, ownerVisible[0].State["priority"])
+
+ otherVisible, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ RequesterUserID: 202,
+ })
+ require.NoError(t, err)
+ require.Len(t, otherVisible, 1)
+ require.Equal(t, otherTemplate.ID, otherVisible[0].ID)
+ require.EqualValues(t, 9, otherVisible[0].State["priority"])
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: ownerTemplate.ID,
+ Name: ownerTemplate.Name,
+ ScopePlatform: ownerTemplate.ScopePlatform,
+ ScopeType: ownerTemplate.ScopeType,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 99},
+ RequesterUserID: 202,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err), "非 owner 不允许通过 template ID 修改私有模板")
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertFailsWhenStoredLibraryCorrupted(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ repo.values[SettingKeyBulkEditTemplateLibrary] = "{bad-json"
+ svc := NewSettingService(repo, nil)
+
+ _, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Should Fail",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"ok": true},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+}
+
+func TestSettingServiceBulkEditTemplate_ListFilteringAndSorting(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ store := bulkEditTemplateLibraryStore{
+ Items: []bulkEditTemplateStoreItem{
+ {
+ ID: "b",
+ Name: "Second",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: []byte(`{}`),
+ CreatedBy: 1,
+ UpdatedBy: 1,
+ CreatedAt: 1,
+ UpdatedAt: 100,
+ },
+ {
+ ID: "a",
+ Name: "First",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: []byte(`{}`),
+ CreatedBy: 1,
+ UpdatedBy: 1,
+ CreatedAt: 1,
+ UpdatedAt: 100,
+ },
+ {
+ ID: "skip-type",
+ Name: "Skip Type",
+ ScopePlatform: "openai",
+ ScopeType: "apikey",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: []byte(`{}`),
+ CreatedBy: 1,
+ UpdatedBy: 1,
+ CreatedAt: 1,
+ UpdatedAt: 999,
+ },
+ {
+ ID: "skip-private",
+ Name: "Skip Private",
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: []byte(`{}`),
+ CreatedBy: 99,
+ UpdatedBy: 99,
+ CreatedAt: 1,
+ UpdatedAt: 1000,
+ },
+ },
+ }
+ require.NoError(t, svc.persistBulkEditTemplateLibrary(context.Background(), &store))
+
+ items, err := svc.ListBulkEditTemplates(context.Background(), BulkEditTemplateQuery{
+ ScopePlatform: "openai",
+ ScopeType: "oauth",
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.Len(t, items, 2)
+ require.Equal(t, []string{"a", "b"}, []string{items[0].ID, items[1].ID})
+}
+
+func TestSettingServiceBulkEditTemplate_UpsertByIDAndMarshalError(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "By ID",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: created.ID,
+ Name: "By ID",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 2},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+ require.Equal(t, created.ID, updated.ID)
+ require.EqualValues(t, 2, updated.State["priority"])
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Marshal Error",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"bad": make(chan int)},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Unauthorized",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 0,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsUnauthorized(err))
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Missing Scope",
+ ScopePlatform: "",
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{},
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+}
+
+func TestSettingServiceBulkEditTemplate_LoadErrorFromRepository(t *testing.T) {
+ svc := NewSettingService(&bulkTemplateFailingRepoStub{}, nil)
+ store, err := svc.loadBulkEditTemplateLibrary(context.Background())
+ require.Error(t, err)
+ require.Nil(t, store)
+}
+
+func TestGenerateBulkEditTemplateID_Fallback(t *testing.T) {
+ original := bulkEditTemplateRandRead
+ bulkEditTemplateRandRead = func(_ []byte) (int, error) {
+ return 0, errors.New("rand fail")
+ }
+ defer func() {
+ bulkEditTemplateRandRead = original
+ }()
+
+ id := generateBulkEditTemplateID()
+ require.NotEmpty(t, id)
+ require.Contains(t, id, "btpl-")
+
+ versionID := generateBulkEditTemplateVersionID()
+ require.NotEmpty(t, versionID)
+ require.Contains(t, versionID, "btplv-")
+}
+
+func TestSettingServiceBulkEditTemplate_VersionLifecycleAndRollback(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Versioned Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ State: map[string]any{"priority": 1},
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+
+ updated, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: created.ID,
+ Name: "Versioned Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ State: map[string]any{"priority": 9},
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Equal(t, BulkEditTemplateShareScopeTeam, updated.ShareScope)
+
+ versions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Len(t, versions, 1)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, versions[0].ShareScope)
+ require.EqualValues(t, 1, versions[0].State["priority"])
+
+ rollbacked, err := svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: versions[0].VersionID,
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Equal(t, BulkEditTemplateShareScopePrivate, rollbacked.ShareScope)
+ require.EqualValues(t, 1, rollbacked.State["priority"])
+
+ versionsAfterRollback, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ RequesterUserID: 88,
+ })
+ require.NoError(t, err)
+ require.Len(t, versionsAfterRollback, 2)
+}
+
+func TestSettingServiceBulkEditTemplate_VersionVisibilityAndErrors(t *testing.T) {
+ repo := newBulkTemplateSettingRepoStub()
+ svc := NewSettingService(repo, nil)
+
+ created, err := svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ Name: "Group Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{7},
+ State: map[string]any{"enableBaseUrl": true},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.UpsertBulkEditTemplate(context.Background(), BulkEditTemplateUpsertInput{
+ ID: created.ID,
+ Name: "Group Template",
+ ScopePlatform: PlatformOpenAI,
+ ScopeType: AccountTypeOAuth,
+ ShareScope: BulkEditTemplateShareScopeGroups,
+ GroupIDs: []int64{7, 9},
+ State: map[string]any{"enableBaseUrl": false},
+ RequesterUserID: 1,
+ })
+ require.NoError(t, err)
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ ScopeGroupIDs: []int64{8},
+ RequesterUserID: 2,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err))
+
+ visibleVersions, err := svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ ScopeGroupIDs: []int64{7},
+ RequesterUserID: 2,
+ })
+ require.NoError(t, err)
+ require.Len(t, visibleVersions, 1)
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: "missing",
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: created.ID,
+ RequesterUserID: 0,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsUnauthorized(err))
+
+ _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: "missing-version",
+ ScopeGroupIDs: []int64{7},
+ RequesterUserID: 2,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsNotFound(err))
+
+ _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: visibleVersions[0].VersionID,
+ ScopeGroupIDs: []int64{8},
+ RequesterUserID: 2,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsForbidden(err))
+
+ _, err = svc.ListBulkEditTemplateVersions(context.Background(), BulkEditTemplateVersionQuery{
+ TemplateID: " ",
+ RequesterUserID: 1,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsBadRequest(err))
+
+ _, err = svc.RollbackBulkEditTemplate(context.Background(), BulkEditTemplateRollbackInput{
+ TemplateID: created.ID,
+ VersionID: visibleVersions[0].VersionID,
+ RequesterUserID: 0,
+ })
+ require.Error(t, err)
+ require.True(t, infraerrors.IsUnauthorized(err))
+}
+
+func TestSettingServiceBulkEditTemplate_VersionHelpers(t *testing.T) {
+ normalized := normalizeBulkEditTemplateVersionStoreItems([]bulkEditTemplateVersionStoreItem{
+ {
+ VersionID: "",
+ ShareScope: "groups",
+ GroupIDs: []int64{},
+ State: nil,
+ UpdatedBy: 1,
+ UpdatedAt: 0,
+ },
+ {
+ VersionID: "v-1",
+ ShareScope: "team",
+ GroupIDs: []int64{4, 4, 2},
+ State: []byte(`{"ok":true}`),
+ UpdatedBy: 2,
+ UpdatedAt: 20,
+ },
+ {
+ VersionID: "v-1",
+ ShareScope: "team",
+ State: []byte(`{}`),
+ UpdatedAt: 30,
+ },
+ })
+ require.Len(t, normalized, 2)
+ privateCount := 0
+ teamCount := 0
+ for _, item := range normalized {
+ if item.ShareScope == BulkEditTemplateShareScopePrivate {
+ privateCount++
+ }
+ if item.ShareScope == BulkEditTemplateShareScopeTeam {
+ teamCount++
+ require.Equal(t, []int64{2, 4}, item.GroupIDs)
+ }
+ }
+ require.Equal(t, 1, privateCount)
+ require.Equal(t, 1, teamCount)
+
+ item := bulkEditTemplateStoreItem{
+ ID: "tpl-1",
+ ShareScope: BulkEditTemplateShareScopeTeam,
+ GroupIDs: []int64{3},
+ State: []byte(`{"priority":3}`),
+ UpdatedBy: 10,
+ UpdatedAt: 123,
+ }
+ version := snapshotBulkEditTemplateVersion(item)
+ require.NotEmpty(t, version.VersionID)
+ require.Equal(t, BulkEditTemplateShareScopeTeam, version.ShareScope)
+ require.EqualValues(t, 123, version.UpdatedAt)
+
+ versionDTO := toBulkEditTemplateVersion(bulkEditTemplateVersionStoreItem{
+ VersionID: "ver-1",
+ ShareScope: BulkEditTemplateShareScopePrivate,
+ GroupIDs: []int64{9},
+ State: []byte(`invalid`),
+ UpdatedBy: 1,
+ UpdatedAt: 2,
+ })
+ require.Equal(t, map[string]any{}, versionDTO.State)
+
+ require.Equal(t, -1, findBulkEditTemplateVersionIndexByID(nil, "x"))
+ require.Equal(t, -1, findBulkEditTemplateStoreItemIndexByID(nil, "x"))
+ require.Nil(t, findBulkEditTemplateStoreItemByID(nil, "x"))
+
+ scopeSet := toBulkEditTemplateScopeGroupSet([]int64{4, 4, 2, -1})
+ _, has2 := scopeSet[2]
+ _, has4 := scopeSet[4]
+ require.True(t, has2)
+ require.True(t, has4)
+
+ cloned := cloneBulkEditTemplateStateRaw(json.RawMessage(`{"x":1}`))
+ require.Equal(t, `{"x":1}`, string(cloned))
+ require.Equal(t, `{}`, string(cloneBulkEditTemplateStateRaw(nil)))
+}
diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go
index 64871b9a6..8593e77ed 100644
--- a/backend/internal/service/setting_service.go
+++ b/backend/internal/service/setting_service.go
@@ -10,7 +10,6 @@ import (
"log/slog"
"strconv"
"strings"
- "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -19,18 +18,10 @@ import (
)
var (
- ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
- ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
- ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
- ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
- ErrDefaultSubGroupInvalid = infraerrors.BadRequest(
- "DEFAULT_SUBSCRIPTION_GROUP_INVALID",
- "default subscription group must exist and be subscription type",
- )
- ErrDefaultSubGroupDuplicate = infraerrors.BadRequest(
- "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE",
- "default subscription group cannot be duplicated",
- )
+ ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled")
+ ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found")
+ ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found")
+ ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists")
)
type SettingRepository interface {
@@ -71,12 +62,11 @@ type DefaultSubscriptionGroupReader interface {
// SettingService 系统设置服务
type SettingService struct {
- settingRepo SettingRepository
- defaultSubGroupReader DefaultSubscriptionGroupReader
- cfg *config.Config
- onUpdate func() // Callback when settings are updated (for cache invalidation)
- onS3Update func() // Callback when Sora S3 settings are updated
- version string // Application version
+ settingRepo SettingRepository
+ cfg *config.Config
+ onUpdate func() // Callback when settings are updated (for cache invalidation)
+ onS3Update func() // Callback when Sora S3 settings are updated
+ version string // Application version
}
// NewSettingService 创建系统设置服务实例
diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go
index 46f322c82..820945f02 100644
--- a/backend/internal/service/sora_generation_service_test.go
+++ b/backend/internal/service/sora_generation_service_test.go
@@ -165,9 +165,6 @@ func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
diff --git a/backend/internal/service/token_refresh_parallel_test.go b/backend/internal/service/token_refresh_parallel_test.go
new file mode 100644
index 000000000..c844ef934
--- /dev/null
+++ b/backend/internal/service/token_refresh_parallel_test.go
@@ -0,0 +1,439 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+// --- 并行刷新专用 stub ---
+
+// concurrentTokenRefresherStub 记录并发度和调用次数
+type concurrentTokenRefresherStub struct {
+ canRefreshFn func(*Account) bool
+ needsRefreshFn func(*Account, time.Duration) bool
+ refreshDelay time.Duration
+ refreshErr error
+ credentials map[string]any
+ refreshCalls atomic.Int64
+ maxConcurrent atomic.Int64
+ currentActive atomic.Int64
+}
+
+func (r *concurrentTokenRefresherStub) CanRefresh(account *Account) bool {
+ if r.canRefreshFn != nil {
+ return r.canRefreshFn(account)
+ }
+ return true
+}
+
+func (r *concurrentTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool {
+ if r.needsRefreshFn != nil {
+ return r.needsRefreshFn(account, window)
+ }
+ return true
+}
+
+func (r *concurrentTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ r.refreshCalls.Add(1)
+ active := r.currentActive.Add(1)
+ // 记录峰值并发
+ for {
+ old := r.maxConcurrent.Load()
+ if active <= old || r.maxConcurrent.CompareAndSwap(old, active) {
+ break
+ }
+ }
+ if r.refreshDelay > 0 {
+ time.Sleep(r.refreshDelay)
+ }
+ r.currentActive.Add(-1)
+ if r.refreshErr != nil {
+ return nil, r.refreshErr
+ }
+ // 每次返回新 map,避免多 goroutine 共享同一 map 实例引发竞态
+ creds := make(map[string]any, len(r.credentials))
+ for k, v := range r.credentials {
+ creds[k] = v
+ }
+ return creds, nil
+}
+
+// concurrentTokenRefreshAccountRepo 线程安全的 account repo stub
+type concurrentTokenRefreshAccountRepo struct {
+ mockAccountRepoForGemini
+ mu sync.Mutex
+ updateCalls int
+ setErrorCalls int
+ activeAccounts []Account
+ updateErr error
+}
+
+func (r *concurrentTokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.updateCalls++
+ return r.updateErr
+}
+
+func (r *concurrentTokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorMsg string) error {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+ r.setErrorCalls++
+ return nil
+}
+
+func (r *concurrentTokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) {
+ out := make([]Account, len(r.activeAccounts))
+ copy(out, r.activeAccounts)
+ return out, nil
+}
+
+// --- 测试用例 ---
+
+func TestProcessRefresh_ParallelExecution(t *testing.T) {
+ accounts := make([]Account, 20)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(100 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshDelay: 20 * time.Millisecond,
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ start := time.Now()
+ svc.processRefresh()
+ elapsed := time.Since(start)
+
+ // 20 个账号每个 20ms,串行至少 400ms;并行(maxConcurrency=10)约 40-60ms
+ require.Equal(t, int64(20), refresher.refreshCalls.Load())
+ require.Less(t, elapsed, 300*time.Millisecond, "并行刷新应显著快于串行")
+ require.Greater(t, refresher.maxConcurrent.Load(), int64(1), "应有多个账号并发刷新")
+ require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过信号量限制")
+
+ repo.mu.Lock()
+ require.Equal(t, 20, repo.updateCalls)
+ repo.mu.Unlock()
+}
+
+func TestProcessRefresh_SemaphoreLimitsConcurrency(t *testing.T) {
+ accounts := make([]Account, 15)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(200 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshDelay: 50 * time.Millisecond,
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ require.Equal(t, int64(15), refresher.refreshCalls.Load())
+ require.LessOrEqual(t, refresher.maxConcurrent.Load(), int64(10), "并发不应超过 maxConcurrency=10")
+}
+
+func TestProcessRefresh_StopInterruptsPhase2(t *testing.T) {
+ accounts := make([]Account, 30)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(300 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshDelay: 100 * time.Millisecond,
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ done := make(chan struct{})
+ go func() {
+ svc.processRefresh()
+ close(done)
+ }()
+
+ // 短暂等待让部分 goroutine 启动
+ time.Sleep(30 * time.Millisecond)
+ svc.Stop()
+
+ select {
+ case <-done:
+ // ok
+ case <-time.After(3 * time.Second):
+ t.Fatal("processRefresh 应在收到 stop 信号后及时退出")
+ }
+
+ // 因中断,不应刷新全部 30 个账号
+ require.Less(t, refresher.refreshCalls.Load(), int64(30), "stop 应中断后续任务提交")
+}
+
+func TestProcessRefresh_EmptyAccounts(t *testing.T) {
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: nil}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
+ refresher := &concurrentTokenRefresherStub{}
+ svc.refreshers = []TokenRefresher{refresher}
+
+ // 不应 panic
+ require.NotPanics(t, func() {
+ svc.processRefresh()
+ })
+ require.Equal(t, int64(0), refresher.refreshCalls.Load())
+}
+
+func TestProcessRefresh_NoAccountsNeedRefresh(t *testing.T) {
+ accounts := []Account{
+ {ID: 401, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 402, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ needsRefreshFn: func(a *Account, d time.Duration) bool { return false },
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ require.Equal(t, int64(0), refresher.refreshCalls.Load())
+}
+
+func TestProcessRefresh_MixedSuccessAndFailure(t *testing.T) {
+ accounts := []Account{
+ {ID: 501, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 502, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 503, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 504, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+
+ // 偶数 ID 成功,奇数 ID 失败
+ refresher := &concurrentTokenRefresherStub{
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ failRefresher := &concurrentTokenRefresherStub{
+ refreshErr: errors.New("refresh failed"),
+ }
+
+ // 使用 selectiveRefresher 按 ID 分流
+ selectiveRefresher := &selectiveTokenRefresherStub{
+ successRefresher: refresher,
+ failRefresher: failRefresher,
+ failIDs: map[int64]bool{501: true, 503: true},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{selectiveRefresher}
+
+ svc.processRefresh()
+
+ totalCalls := refresher.refreshCalls.Load() + failRefresher.refreshCalls.Load()
+ require.Equal(t, int64(4), totalCalls)
+ require.Equal(t, int64(2), refresher.refreshCalls.Load())
+ require.Equal(t, int64(2), failRefresher.refreshCalls.Load())
+}
+
+// selectiveTokenRefresherStub 按账号 ID 分流到不同的 refresher
+type selectiveTokenRefresherStub struct {
+ successRefresher *concurrentTokenRefresherStub
+ failRefresher *concurrentTokenRefresherStub
+ failIDs map[int64]bool
+}
+
+func (r *selectiveTokenRefresherStub) CanRefresh(account *Account) bool {
+ return true
+}
+
+func (r *selectiveTokenRefresherStub) NeedsRefresh(account *Account, window time.Duration) bool {
+ return true
+}
+
+func (r *selectiveTokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ if r.failIDs[account.ID] {
+ return r.failRefresher.Refresh(ctx, account)
+ }
+ return r.successRefresher.Refresh(ctx, account)
+}
+
+func TestProcessRefresh_SingleAccount(t *testing.T) {
+ accounts := []Account{
+ {ID: 601, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ require.Equal(t, int64(1), refresher.refreshCalls.Load())
+ require.Equal(t, int64(1), refresher.maxConcurrent.Load())
+}
+
+func TestProcessRefresh_AllFailed(t *testing.T) {
+ accounts := make([]Account, 5)
+ for i := range accounts {
+ accounts[i] = Account{
+ ID: int64(700 + i),
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Status: StatusActive,
+ }
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ refreshErr: errors.New("all fail"),
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ // 不应 panic
+ require.NotPanics(t, func() {
+ svc.processRefresh()
+ })
+ require.Equal(t, int64(5), refresher.refreshCalls.Load())
+
+ repo.mu.Lock()
+ require.Equal(t, 5, repo.setErrorCalls)
+ require.Equal(t, 0, repo.updateCalls)
+ repo.mu.Unlock()
+}
+
+func TestProcessRefresh_CanRefreshFilters(t *testing.T) {
+ accounts := []Account{
+ {ID: 801, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ {ID: 802, Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Status: StatusActive},
+ {ID: 803, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive},
+ }
+ repo := &concurrentTokenRefreshAccountRepo{activeAccounts: accounts}
+ lockStub := &tokenRefreshSchedulerLockStub{}
+ refresher := &concurrentTokenRefresherStub{
+ canRefreshFn: func(a *Account) bool { return a.Type == AccountTypeOAuth },
+ credentials: map[string]any{"access_token": "tok"},
+ }
+
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 5,
+ RefreshBeforeExpiryHours: 1,
+ },
+ }
+ svc := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ svc.refreshers = []TokenRefresher{refresher}
+
+ svc.processRefresh()
+
+ // 只有 OAuth 账号(ID 801, 803)应被刷新
+ require.Equal(t, int64(2), refresher.refreshCalls.Load())
+}
diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go
index a37e0d0ac..ae373a831 100644
--- a/backend/internal/service/token_refresh_service.go
+++ b/backend/internal/service/token_refresh_service.go
@@ -6,11 +6,19 @@ import (
"log/slog"
"strings"
"sync"
+ "sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
)
+const (
+ tokenRefreshDistributedLockPlatform = "token_refresh"
+ tokenRefreshDistributedLockMinTTL = 30 * time.Second
+ tokenRefreshDistributedLockMaxTTL = 10 * time.Minute
+ tokenRefreshDistributedLockTimeout = 2 * time.Second
+)
+
// TokenRefreshService OAuth token自动刷新服务
// 定期检查并刷新即将过期的token
type TokenRefreshService struct {
@@ -20,8 +28,9 @@ type TokenRefreshService struct {
cacheInvalidator TokenCacheInvalidator
schedulerCache SchedulerCache // 用于同步更新调度器缓存,解决 token 刷新后缓存不一致问题
- stopCh chan struct{}
- wg sync.WaitGroup
+ stopCh chan struct{}
+ stopOnce sync.Once
+ wg sync.WaitGroup
}
// NewTokenRefreshService 创建token刷新服务
@@ -87,7 +96,12 @@ func (s *TokenRefreshService) Start() {
// Stop 停止刷新服务
func (s *TokenRefreshService) Stop() {
- close(s.stopCh)
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ close(s.stopCh)
+ })
s.wg.Wait()
slog.Info("token_refresh.service_stopped")
}
@@ -118,9 +132,24 @@ func (s *TokenRefreshService) refreshLoop() {
}
}
+// refreshTask 封装一个待刷新的账号及其对应的刷新器
+type refreshTask struct {
+ account *Account
+ refresher TokenRefresher
+}
+
// processRefresh 执行一次刷新检查
+// 分两阶段:先串行收集需刷新的账号,再并行执行刷新(信号量限制并发数)
func (s *TokenRefreshService) processRefresh() {
- ctx := context.Background()
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+ go func() {
+ select {
+ case <-s.stopCh:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
// 计算刷新窗口
refreshWindow := time.Duration(s.cfg.RefreshBeforeExpiryHours * float64(time.Hour))
@@ -128,68 +157,188 @@ func (s *TokenRefreshService) processRefresh() {
// 获取所有active状态的账号
accounts, err := s.listActiveAccounts(ctx)
if err != nil {
+ if ctx.Err() != nil {
+ slog.Info("token_refresh.cycle_interrupted_by_stop")
+ return
+ }
slog.Error("token_refresh.list_accounts_failed", "error", err)
return
}
totalAccounts := len(accounts)
- oauthAccounts := 0 // 可刷新的OAuth账号数
- needsRefresh := 0 // 需要刷新的账号数
- refreshed, failed := 0, 0
+
+ // Phase 1: 收集需要刷新的账号(轻量操作,仍串行)
+ var tasks []refreshTask
+ oauthAccounts := 0
for i := range accounts {
account := &accounts[i]
-
- // 遍历所有刷新器,找到能处理此账号的
for _, refresher := range s.refreshers {
if !refresher.CanRefresh(account) {
continue
}
-
oauthAccounts++
-
- // 检查是否需要刷新
- if !refresher.NeedsRefresh(account, refreshWindow) {
- break // 不需要刷新,跳过
+ if refresher.NeedsRefresh(account, refreshWindow) {
+ if !s.tryAcquireDistributedRefreshLock(ctx, account) {
+ break
+ }
+ tasks = append(tasks, refreshTask{account: account, refresher: refresher})
}
+ break // 每个账号只由一个refresher处理
+ }
+ }
- needsRefresh++
+ needsRefresh := len(tasks)
+ if needsRefresh == 0 {
+ slog.Debug("token_refresh.cycle_completed",
+ "total", totalAccounts, "oauth", oauthAccounts,
+ "needs_refresh", 0, "refreshed", 0, "failed", 0)
+ return
+ }
+
+ // Phase 2: 并行刷新(带信号量限制)
+ const maxConcurrency = 10
+ sem := make(chan struct{}, maxConcurrency)
+ var refreshed, failed atomic.Int64
+ var wg sync.WaitGroup
+ interrupted := false
+
+submitLoop:
+ for _, task := range tasks {
+ // 检查停止信号
+ select {
+ case <-s.stopCh:
+ slog.Info("token_refresh.cycle_interrupted_by_stop")
+ interrupted = true
+ break submitLoop
+ default:
+ }
+
+ select {
+ case sem <- struct{}{}: // 获取信号量
+ case <-s.stopCh:
+ slog.Info("token_refresh.cycle_interrupted_by_stop")
+ interrupted = true
+ break submitLoop
+ case <-ctx.Done():
+ interrupted = true
+ break submitLoop
+ }
- // 执行刷新
- if err := s.refreshWithRetry(ctx, account, refresher); err != nil {
+ wg.Add(1)
+ go func(t refreshTask) {
+ defer func() {
+ <-sem // 释放信号量
+ wg.Done()
+ }()
+
+ if err := s.refreshWithRetry(ctx, t.account, t.refresher); err != nil {
slog.Warn("token_refresh.account_refresh_failed",
- "account_id", account.ID,
- "account_name", account.Name,
+ "account_id", t.account.ID,
+ "account_name", t.account.Name,
"error", err,
)
- failed++
+ failed.Add(1)
} else {
slog.Info("token_refresh.account_refreshed",
- "account_id", account.ID,
- "account_name", account.Name,
+ "account_id", t.account.ID,
+ "account_name", t.account.Name,
)
- refreshed++
+ refreshed.Add(1)
}
-
- // 每个账号只由一个refresher处理
- break
- }
+ }(task)
}
- // 无刷新活动时降级为 Debug,有实际刷新活动时保持 Info
- if needsRefresh == 0 && failed == 0 {
+ wg.Wait()
+
+ r, f := int(refreshed.Load()), int(failed.Load())
+ if interrupted {
+ slog.Info("token_refresh.cycle_wait_completed_after_stop",
+ "needs_refresh", needsRefresh, "refreshed", r, "failed", f)
+ }
+ if needsRefresh == 0 && f == 0 {
slog.Debug("token_refresh.cycle_completed",
"total", totalAccounts, "oauth", oauthAccounts,
- "needs_refresh", needsRefresh, "refreshed", refreshed, "failed", failed)
+ "needs_refresh", needsRefresh, "refreshed", r, "failed", f)
} else {
slog.Info("token_refresh.cycle_completed",
- "total", totalAccounts,
- "oauth", oauthAccounts,
- "needs_refresh", needsRefresh,
- "refreshed", refreshed,
- "failed", failed,
+ "total", totalAccounts, "oauth", oauthAccounts,
+ "needs_refresh", needsRefresh, "refreshed", r, "failed", f)
+ }
+}
+
+func (s *TokenRefreshService) tryAcquireDistributedRefreshLock(ctx context.Context, account *Account) bool {
+ if s == nil || account == nil || account.ID <= 0 || s.schedulerCache == nil {
+ return true
+ }
+ lockTTL := s.tokenRefreshDistributedLockTTL()
+ if lockTTL <= 0 {
+ return true
+ }
+ lockCtx := ctx
+ if lockCtx == nil {
+ lockCtx = context.Background()
+ }
+ lockCtx, cancel := context.WithTimeout(lockCtx, tokenRefreshDistributedLockTimeout)
+ defer cancel()
+
+ lockBucket := SchedulerBucket{
+ GroupID: account.ID,
+ Platform: tokenRefreshDistributedLockPlatform,
+ Mode: normalizeTokenRefreshLockMode(account.Platform),
+ }
+ locked, err := s.schedulerCache.TryLockBucket(lockCtx, lockBucket, lockTTL)
+ if err != nil {
+ if ctx != nil && ctx.Err() != nil {
+ slog.Info("token_refresh.distributed_lock_canceled",
+ "account_id", account.ID,
+ "platform", account.Platform,
+ "error", err,
+ )
+ return false
+ }
+ slog.Warn("token_refresh.distributed_lock_failed",
+ "account_id", account.ID,
+ "platform", account.Platform,
+ "error", err,
+ "fail_open", true,
+ )
+ return true
+ }
+ if !locked {
+ slog.Debug("token_refresh.distributed_lock_held",
+ "account_id", account.ID,
+ "platform", account.Platform,
)
+ return false
+ }
+ return true
+}
+
+func normalizeTokenRefreshLockMode(platform string) string {
+ mode := strings.TrimSpace(platform)
+ if mode == "" {
+ return "unknown"
+ }
+ return mode
+}
+
+func (s *TokenRefreshService) tokenRefreshDistributedLockTTL() time.Duration {
+ if s == nil || s.cfg == nil {
+ return tokenRefreshDistributedLockMinTTL
}
+ checkInterval := time.Duration(s.cfg.CheckIntervalMinutes) * time.Minute
+ if checkInterval <= 0 {
+ return tokenRefreshDistributedLockMinTTL
+ }
+ ttl := checkInterval / 2
+ if ttl < tokenRefreshDistributedLockMinTTL {
+ ttl = tokenRefreshDistributedLockMinTTL
+ }
+ if ttl > tokenRefreshDistributedLockMaxTTL {
+ ttl = tokenRefreshDistributedLockMaxTTL
+ }
+ return ttl
}
// listActiveAccounts 获取所有active状态的账号
@@ -281,7 +430,9 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
if attempt < s.cfg.MaxRetries {
// 指数退避:2^(attempt-1) * baseSeconds
backoff := time.Duration(s.cfg.RetryBackoffSeconds) * time.Second * time.Duration(1<<(attempt-1))
- time.Sleep(backoff)
+ if err := s.waitRetryBackoff(ctx, backoff); err != nil {
+ return err
+ }
}
}
@@ -306,6 +457,23 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc
return lastErr
}
+func (s *TokenRefreshService) waitRetryBackoff(ctx context.Context, backoff time.Duration) error {
+ if backoff <= 0 {
+ return nil
+ }
+ timer := time.NewTimer(backoff)
+ defer timer.Stop()
+
+ select {
+ case <-timer.C:
+ return nil
+ case <-s.stopCh:
+ return context.Canceled
+ case <-ctx.Done():
+ return ctx.Err()
+ }
+}
+
// isNonRetryableRefreshError 判断是否为不可重试的刷新错误
// 这些错误通常表示凭证已失效或配置确实缺失,需要用户重新授权
// 注意:missing_project_id 错误只在真正缺失(从未获取过)时返回,临时获取失败不会返回此错误
diff --git a/backend/internal/service/token_refresh_service_test.go b/backend/internal/service/token_refresh_service_test.go
index 8e16c6f5d..142a5c169 100644
--- a/backend/internal/service/token_refresh_service_test.go
+++ b/backend/internal/service/token_refresh_service_test.go
@@ -14,10 +14,11 @@ import (
type tokenRefreshAccountRepo struct {
mockAccountRepoForGemini
- updateCalls int
- setErrorCalls int
- lastAccount *Account
- updateErr error
+ updateCalls int
+ setErrorCalls int
+ lastAccount *Account
+ updateErr error
+ activeAccounts []Account
}
func (r *tokenRefreshAccountRepo) Update(ctx context.Context, account *Account) error {
@@ -31,6 +32,15 @@ func (r *tokenRefreshAccountRepo) SetError(ctx context.Context, id int64, errorM
return nil
}
+func (r *tokenRefreshAccountRepo) ListActive(ctx context.Context) ([]Account, error) {
+ if len(r.activeAccounts) == 0 {
+ return nil, nil
+ }
+ out := make([]Account, 0, len(r.activeAccounts))
+ out = append(out, r.activeAccounts...)
+ return out, nil
+}
+
type tokenCacheInvalidatorStub struct {
calls int
err error
@@ -42,8 +52,9 @@ func (s *tokenCacheInvalidatorStub) InvalidateToken(ctx context.Context, account
}
type tokenRefresherStub struct {
- credentials map[string]any
- err error
+ credentials map[string]any
+ err error
+ refreshCalls int
}
func (r *tokenRefresherStub) CanRefresh(account *Account) bool {
@@ -55,12 +66,41 @@ func (r *tokenRefresherStub) NeedsRefresh(account *Account, refreshWindowDuratio
}
func (r *tokenRefresherStub) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
+ r.refreshCalls++
if r.err != nil {
return nil, r.err
}
return r.credentials, nil
}
+type tokenRefreshSchedulerLockStub struct {
+ SchedulerCache
+ lockByAccount map[int64]bool
+ err error
+ calls []SchedulerBucket
+}
+
+func (s *tokenRefreshSchedulerLockStub) TryLockBucket(ctx context.Context, bucket SchedulerBucket, ttl time.Duration) (bool, error) {
+ _ = ctx
+ _ = ttl
+ s.calls = append(s.calls, bucket)
+ if s.err != nil {
+ return false, s.err
+ }
+ if s.lockByAccount != nil {
+ if ok, exists := s.lockByAccount[bucket.GroupID]; exists {
+ return ok, nil
+ }
+ }
+ return true, nil
+}
+
+func (s *tokenRefreshSchedulerLockStub) SetAccount(ctx context.Context, account *Account) error {
+ _ = ctx
+ _ = account
+ return nil
+}
+
func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{}
@@ -89,6 +129,96 @@ func TestTokenRefreshService_RefreshWithRetry_InvalidatesCache(t *testing.T) {
require.Equal(t, "new-token", account.GetCredential("access_token"))
}
+func TestTokenRefreshService_ProcessRefresh_SkipsWhenDistributedLockHeld(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{
+ activeAccounts: []Account{
+ {ID: 31, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true},
+ },
+ }
+ lockStub := &tokenRefreshSchedulerLockStub{
+ lockByAccount: map[int64]bool{31: false},
+ }
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 2,
+ RefreshBeforeExpiryHours: 1,
+ SyncLinkedSoraAccounts: false,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}}
+ service.refreshers = []TokenRefresher{refresher}
+
+ service.processRefresh()
+
+ require.Len(t, lockStub.calls, 1)
+ require.Equal(t, int64(31), lockStub.calls[0].GroupID)
+ require.Equal(t, tokenRefreshDistributedLockPlatform, lockStub.calls[0].Platform)
+ require.Equal(t, PlatformOpenAI, lockStub.calls[0].Mode)
+ require.Equal(t, 0, refresher.refreshCalls, "lock held by another instance should skip refresh")
+ require.Equal(t, 0, repo.updateCalls)
+}
+
+func TestTokenRefreshService_ProcessRefresh_RefreshesWhenDistributedLockAcquired(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{
+ activeAccounts: []Account{
+ {ID: 32, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true},
+ },
+ }
+ lockStub := &tokenRefreshSchedulerLockStub{
+ lockByAccount: map[int64]bool{32: true},
+ }
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 2,
+ RefreshBeforeExpiryHours: 1,
+ SyncLinkedSoraAccounts: false,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}}
+ service.refreshers = []TokenRefresher{refresher}
+
+ service.processRefresh()
+
+ require.Len(t, lockStub.calls, 1)
+ require.Equal(t, 1, refresher.refreshCalls)
+ require.Equal(t, 1, repo.updateCalls, "lock acquired should allow refresh")
+}
+
+func TestTokenRefreshService_ProcessRefresh_FailOpenWhenDistributedLockError(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{
+ activeAccounts: []Account{
+ {ID: 33, Platform: PlatformOpenAI, Type: AccountTypeOAuth, Status: StatusActive, Schedulable: true},
+ },
+ }
+ lockStub := &tokenRefreshSchedulerLockStub{
+ err: errors.New("redis unavailable"),
+ }
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ CheckIntervalMinutes: 2,
+ RefreshBeforeExpiryHours: 1,
+ SyncLinkedSoraAccounts: false,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, lockStub, cfg)
+ refresher := &tokenRefresherStub{credentials: map[string]any{"access_token": "new-token"}}
+ service.refreshers = []TokenRefresher{refresher}
+
+ service.processRefresh()
+
+ require.Len(t, lockStub.calls, 1)
+ require.Equal(t, 1, refresher.refreshCalls, "lock backend error should fail-open and continue refresh")
+ require.Equal(t, 1, repo.updateCalls)
+}
+
func TestTokenRefreshService_RefreshWithRetry_InvalidatorErrorIgnored(t *testing.T) {
repo := &tokenRefreshAccountRepo{}
invalidator := &tokenCacheInvalidatorStub{err: errors.New("invalidate failed")}
@@ -359,3 +489,56 @@ func TestIsNonRetryableRefreshError(t *testing.T) {
})
}
}
+
+func TestTokenRefreshService_Stop_Idempotent(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 1,
+ RetryBackoffSeconds: 0,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
+
+ require.NotPanics(t, func() {
+ service.Stop()
+ service.Stop()
+ })
+}
+
+func TestTokenRefreshService_RefreshWithRetry_StopInterruptsBackoff(t *testing.T) {
+ repo := &tokenRefreshAccountRepo{}
+ cfg := &config.Config{
+ TokenRefresh: config.TokenRefreshConfig{
+ MaxRetries: 3,
+ RetryBackoffSeconds: 5,
+ },
+ }
+ service := NewTokenRefreshService(repo, nil, nil, nil, nil, nil, nil, cfg)
+ account := &Account{
+ ID: 21,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ }
+ refresher := &tokenRefresherStub{
+ err: errors.New("refresh failed"),
+ }
+
+ start := time.Now()
+ done := make(chan error, 1)
+ go func() {
+ done <- service.refreshWithRetry(context.Background(), account, refresher)
+ }()
+
+ time.Sleep(80 * time.Millisecond)
+ service.Stop()
+
+ select {
+ case err := <-done:
+ require.ErrorIs(t, err, context.Canceled)
+ case <-time.After(600 * time.Millisecond):
+ t.Fatal("refreshWithRetry should exit quickly after service stop")
+ }
+ require.Less(t, time.Since(start), time.Second)
+ require.Equal(t, 0, repo.setErrorCalls, "stop 中断时不应落错误状态")
+}
diff --git a/backend/internal/service/token_refresher.go b/backend/internal/service/token_refresher.go
index 0dd3cf45d..a4fc1c70e 100644
--- a/backend/internal/service/token_refresher.go
+++ b/backend/internal/service/token_refresher.go
@@ -7,6 +7,18 @@ import (
"time"
)
+const openAISoraSyncConcurrencyLimit = 8
+
+// needsRefreshWithoutExpiry 在 expires_at 缺失时判断是否需要刷新。
+// 通过 Account.UpdatedAt 避免每轮刷新周期都发起无效刷新:
+// 如果账号在 refreshWindow 内曾被更新,说明最近可能已刷新过,跳过本轮。
+func needsRefreshWithoutExpiry(account *Account, refreshWindow time.Duration) bool {
+ if refreshWindow <= 0 {
+ return true
+ }
+ return time.Since(account.UpdatedAt) >= refreshWindow
+}
+
// TokenRefresher 定义平台特定的token刷新策略接口
// 通过此接口可以扩展支持不同平台(Anthropic/OpenAI/Gemini)
type TokenRefresher interface {
@@ -46,7 +58,8 @@ func (r *ClaudeTokenRefresher) CanRefresh(account *Account) bool {
func (r *ClaudeTokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
- return false
+ // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮
+ return needsRefreshWithoutExpiry(account, refreshWindow)
}
return time.Until(*expiresAt) < refreshWindow
}
@@ -87,6 +100,10 @@ type OpenAITokenRefresher struct {
accountRepo AccountRepository
soraAccountRepo SoraAccountRepository // Sora 扩展表仓储,用于双表同步
syncLinkedSora bool
+ syncLinkedSoraSem chan struct{}
+
+ // test hook: override sync execution target when needed.
+ syncLinkedSoraAccountsFn func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any)
}
// NewOpenAITokenRefresher 创建 OpenAI token刷新器
@@ -94,6 +111,7 @@ func NewOpenAITokenRefresher(openaiOAuthService *OpenAIOAuthService, accountRepo
return &OpenAITokenRefresher{
openaiOAuthService: openaiOAuthService,
accountRepo: accountRepo,
+ syncLinkedSoraSem: make(chan struct{}, openAISoraSyncConcurrencyLimit),
}
}
@@ -120,7 +138,8 @@ func (r *OpenAITokenRefresher) CanRefresh(account *Account) bool {
func (r *OpenAITokenRefresher) NeedsRefresh(account *Account, refreshWindow time.Duration) bool {
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
- return false
+ // 无过期时间:如果账号近期已更新(可能刚刷新过),跳过本轮
+ return needsRefreshWithoutExpiry(account, refreshWindow)
}
return time.Until(*expiresAt) < refreshWindow
@@ -147,12 +166,58 @@ func (r *OpenAITokenRefresher) Refresh(ctx context.Context, account *Account) (m
// 异步同步关联的 Sora 账号(不阻塞主流程)
if r.accountRepo != nil && r.syncLinkedSora {
- go r.syncLinkedSoraAccounts(context.Background(), account.ID, newCredentials)
+ syncCredentials := copyCredentialsMap(newCredentials)
+ syncFn := r.syncLinkedSoraAccounts
+ if r.syncLinkedSoraAccountsFn != nil {
+ syncFn = r.syncLinkedSoraAccountsFn
+ }
+ if r.tryAcquireSyncLinkedSoraSlot() {
+ go func() {
+ defer r.releaseSyncLinkedSoraSlot()
+ syncFn(context.Background(), account.ID, syncCredentials)
+ }()
+ } else {
+ // 达到并发上限时回退为同步执行,避免 goroutine 无界堆积。
+ syncFn(ctx, account.ID, syncCredentials)
+ }
}
return newCredentials, nil
}
+func copyCredentialsMap(src map[string]any) map[string]any {
+ if len(src) == 0 {
+ return map[string]any{}
+ }
+ dst := make(map[string]any, len(src))
+ for k, v := range src {
+ dst[k] = v
+ }
+ return dst
+}
+
+func (r *OpenAITokenRefresher) tryAcquireSyncLinkedSoraSlot() bool {
+ if r == nil || r.syncLinkedSoraSem == nil {
+ return false
+ }
+ select {
+ case r.syncLinkedSoraSem <- struct{}{}:
+ return true
+ default:
+ return false
+ }
+}
+
+func (r *OpenAITokenRefresher) releaseSyncLinkedSoraSlot() {
+ if r == nil || r.syncLinkedSoraSem == nil {
+ return
+ }
+ select {
+ case <-r.syncLinkedSoraSem:
+ default:
+ }
+}
+
// syncLinkedSoraAccounts 同步关联的 Sora 账号的 token(双表同步)
// 该方法异步执行,失败只记录日志,不影响主流程
//
diff --git a/backend/internal/service/token_refresher_test.go b/backend/internal/service/token_refresher_test.go
index 264d79125..27d786634 100644
--- a/backend/internal/service/token_refresher_test.go
+++ b/backend/internal/service/token_refresher_test.go
@@ -3,13 +3,38 @@
package service
import (
+ "context"
"strconv"
"testing"
"time"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/openai"
"github.com/stretchr/testify/require"
)
+type openAIOAuthClientStubForRefresher struct {
+ tokenResp *openai.TokenResponse
+ err error
+}
+
+func (s *openAIOAuthClientStubForRefresher) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
+ return nil, s.err
+}
+
+func (s *openAIOAuthClientStubForRefresher) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.tokenResp, nil
+}
+
+func (s *openAIOAuthClientStubForRefresher) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
+ if s.err != nil {
+ return nil, s.err
+ }
+ return s.tokenResp, nil
+}
+
func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) {
refresher := &ClaudeTokenRefresher{}
refreshWindow := 30 * time.Minute
@@ -64,26 +89,26 @@ func TestClaudeTokenRefresher_NeedsRefresh(t *testing.T) {
{
name: "expires_at missing",
credentials: map[string]any{},
- wantRefresh: false,
+ wantRefresh: true,
},
{
name: "expires_at is nil",
credentials: map[string]any{
"expires_at": nil,
},
- wantRefresh: false,
+ wantRefresh: true,
},
{
name: "expires_at is invalid string",
credentials: map[string]any{
"expires_at": "invalid",
},
- wantRefresh: false,
+ wantRefresh: true,
},
{
name: "credentials is nil",
credentials: nil,
- wantRefresh: false,
+ wantRefresh: true,
},
}
@@ -179,6 +204,36 @@ func TestClaudeTokenRefresher_NeedsRefresh_OutsideWindow(t *testing.T) {
}
}
+func TestNeedsRefreshWithoutExpiry_RecentlyUpdated(t *testing.T) {
+ refreshWindow := 30 * time.Minute
+
+ t.Run("recently_updated_skips_refresh", func(t *testing.T) {
+ // 账号近期更新过(5 分钟前),不需要刷新
+ account := &Account{
+ Platform: PlatformAnthropic,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{},
+ UpdatedAt: time.Now().Add(-5 * time.Minute),
+ }
+ refresher := &ClaudeTokenRefresher{}
+ require.False(t, refresher.NeedsRefresh(account, refreshWindow),
+ "近期更新过的账号无 expires_at 时不应刷新")
+ })
+
+ t.Run("old_updated_needs_refresh", func(t *testing.T) {
+ // 账号很久没更新(2 小时前),需要刷新
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{},
+ UpdatedAt: time.Now().Add(-2 * time.Hour),
+ }
+ refresher := &OpenAITokenRefresher{}
+ require.True(t, refresher.NeedsRefresh(account, refreshWindow),
+ "长期未更新的账号无 expires_at 时应刷新")
+ })
+}
+
func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
refresher := &ClaudeTokenRefresher{}
@@ -266,3 +321,159 @@ func TestOpenAITokenRefresher_CanRefresh(t *testing.T) {
})
}
}
+
+func TestOpenAITokenRefresher_NeedsRefresh(t *testing.T) {
+ refresher := &OpenAITokenRefresher{}
+ refreshWindow := 30 * time.Minute
+
+ tests := []struct {
+ name string
+ credentials map[string]any
+ wantRefresh bool
+ }{
+ {
+ name: "expires_at missing",
+ credentials: map[string]any{
+ "access_token": "token",
+ },
+ wantRefresh: true,
+ },
+ {
+ name: "expires_at invalid",
+ credentials: map[string]any{
+ "expires_at": "invalid",
+ },
+ wantRefresh: true,
+ },
+ {
+ name: "expires_at expired",
+ credentials: map[string]any{
+ "expires_at": strconv.FormatInt(time.Now().Add(-time.Minute).Unix(), 10),
+ },
+ wantRefresh: true,
+ },
+ {
+ name: "expires_at far future",
+ credentials: map[string]any{
+ "expires_at": strconv.FormatInt(time.Now().Add(2*time.Hour).Unix(), 10),
+ },
+ wantRefresh: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ account := &Account{
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: tt.credentials,
+ }
+ require.Equal(t, tt.wantRefresh, refresher.NeedsRefresh(account, refreshWindow))
+ })
+ }
+}
+
+func TestOpenAITokenRefresher_Refresh_AsyncSyncUsesCopiedCredentials(t *testing.T) {
+ oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{
+ tokenResp: &openai.TokenResponse{
+ AccessToken: "new_access_token",
+ RefreshToken: "new_refresh_token",
+ ExpiresIn: 3600,
+ },
+ })
+ refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{})
+ refresher.SetSyncLinkedSoraAccounts(true)
+ refresher.syncLinkedSoraSem = make(chan struct{}, 1)
+
+ readNow := make(chan struct{})
+ seenValue := make(chan string, 1)
+ refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
+ <-readNow
+ if v, ok := newCredentials["custom"].(string); ok {
+ seenValue <- v
+ return
+ }
+ seenValue <- ""
+ }
+
+ account := &Account{
+ ID: 1001,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "old_refresh_token",
+ "client_id": "test-client",
+ "custom": "original",
+ },
+ }
+
+ newCredentials, err := refresher.Refresh(context.Background(), account)
+ require.NoError(t, err)
+ require.NotNil(t, newCredentials)
+
+ newCredentials["custom"] = "mutated_after_return"
+ close(readNow)
+
+ select {
+ case got := <-seenValue:
+ require.Equal(t, "original", got, "异步同步应使用 credentials 副本,避免并发写污染")
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("timed out waiting for sync hook")
+ }
+}
+
+func TestOpenAITokenRefresher_Refresh_FallsBackToSyncWhenLimiterFull(t *testing.T) {
+ oauthSvc := NewOpenAIOAuthService(nil, &openAIOAuthClientStubForRefresher{
+ tokenResp: &openai.TokenResponse{
+ AccessToken: "new_access_token",
+ RefreshToken: "new_refresh_token",
+ ExpiresIn: 3600,
+ },
+ })
+ refresher := NewOpenAITokenRefresher(oauthSvc, &mockAccountRepoForGemini{})
+ refresher.SetSyncLinkedSoraAccounts(true)
+ refresher.syncLinkedSoraSem = make(chan struct{}, 1)
+ refresher.syncLinkedSoraSem <- struct{}{} // 填满 limiter,强制走同步降级路径
+
+ entered := make(chan struct{})
+ releaseSync := make(chan struct{})
+ refresher.syncLinkedSoraAccountsFn = func(ctx context.Context, openaiAccountID int64, newCredentials map[string]any) {
+ close(entered)
+ <-releaseSync
+ }
+
+ account := &Account{
+ ID: 1002,
+ Platform: PlatformOpenAI,
+ Type: AccountTypeOAuth,
+ Credentials: map[string]any{
+ "refresh_token": "old_refresh_token",
+ "client_id": "test-client",
+ },
+ }
+
+ done := make(chan struct{})
+ go func() {
+ _, _ = refresher.Refresh(context.Background(), account)
+ close(done)
+ }()
+
+ select {
+ case <-entered:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("sync hook was not invoked")
+ }
+
+ select {
+ case <-done:
+ t.Fatal("Refresh should block when falling back to synchronous linked-sora sync")
+ default:
+ }
+
+ close(releaseSync)
+ select {
+ case <-done:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("Refresh did not finish after releasing synchronous sync hook")
+ }
+}
diff --git a/backend/internal/service/usage_billing_compensation_service.go b/backend/internal/service/usage_billing_compensation_service.go
new file mode 100644
index 000000000..47ea71618
--- /dev/null
+++ b/backend/internal/service/usage_billing_compensation_service.go
@@ -0,0 +1,256 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "log/slog"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+)
+
+const (
+ defaultUsageBillingCompensationInterval = 20 * time.Second
+ defaultUsageBillingCompensationBatchSize = 64
+ defaultUsageBillingCompensationTaskTimout = 8 * time.Second
+ defaultUsageBillingCompensationStaleAfter = 3 * time.Minute
+)
+
+// UsageBillingCompensationService retries pending usage charges in billing_usage_entries.
+// It only runs when usageLogRepo supports UsageBillingEntryStore.
+type UsageBillingCompensationService struct {
+ usageLogRepo UsageLogRepository
+ userRepo UserRepository
+ userSubRepo UserSubscriptionRepository
+ billingCache *BillingCacheService
+ cfg *config.Config
+
+ startOnce sync.Once
+ stopOnce sync.Once
+ stopCh chan struct{}
+}
+
+func NewUsageBillingCompensationService(
+ usageLogRepo UsageLogRepository,
+ userRepo UserRepository,
+ userSubRepo UserSubscriptionRepository,
+ billingCache *BillingCacheService,
+ cfg *config.Config,
+) *UsageBillingCompensationService {
+ return &UsageBillingCompensationService{
+ usageLogRepo: usageLogRepo,
+ userRepo: userRepo,
+ userSubRepo: userSubRepo,
+ billingCache: billingCache,
+ cfg: cfg,
+ stopCh: make(chan struct{}),
+ }
+}
+
+func (s *UsageBillingCompensationService) Start() {
+ if s == nil || s.store() == nil {
+ return
+ }
+ if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
+ return
+ }
+ s.startOnce.Do(func() {
+ slog.Info("usage_billing_compensation.started",
+ "interval", defaultUsageBillingCompensationInterval.String(),
+ "batch_size", defaultUsageBillingCompensationBatchSize,
+ )
+ go s.runLoop()
+ })
+}
+
+func (s *UsageBillingCompensationService) Stop() {
+ if s == nil {
+ return
+ }
+ s.stopOnce.Do(func() {
+ close(s.stopCh)
+ slog.Info("usage_billing_compensation.stopped")
+ })
+}
+
+func (s *UsageBillingCompensationService) runLoop() {
+ ticker := time.NewTicker(defaultUsageBillingCompensationInterval)
+ defer ticker.Stop()
+
+ s.processOnce()
+
+ for {
+ select {
+ case <-ticker.C:
+ s.processOnce()
+ case <-s.stopCh:
+ return
+ }
+ }
+}
+
+func (s *UsageBillingCompensationService) processOnce() {
+ store := s.store()
+ if store == nil {
+ return
+ }
+
+ ctx, cancel := context.WithTimeout(context.Background(), defaultUsageBillingCompensationTaskTimout)
+ defer cancel()
+ go func() {
+ select {
+ case <-s.stopCh:
+ cancel()
+ case <-ctx.Done():
+ }
+ }()
+
+ entries, err := store.ClaimUsageBillingEntries(ctx, defaultUsageBillingCompensationBatchSize, defaultUsageBillingCompensationStaleAfter)
+ if err != nil {
+ slog.Warn("usage_billing_compensation.claim_failed", "error", err)
+ return
+ }
+ for i := range entries {
+ if ctx.Err() != nil {
+ return
+ }
+ s.processEntry(ctx, entries[i])
+ }
+}
+
+func (s *UsageBillingCompensationService) processEntry(ctx context.Context, entry UsageBillingEntry) {
+ if entry.Applied || entry.DeltaUSD <= 0 {
+ s.markApplied(ctx, entry)
+ return
+ }
+
+ if err := s.applyEntry(ctx, entry); err != nil {
+ s.markRetry(ctx, entry, err)
+ return
+ }
+}
+
+func (s *UsageBillingCompensationService) applyEntry(ctx context.Context, entry UsageBillingEntry) error {
+ switch entry.BillingType {
+ case BillingTypeSubscription:
+ return s.applySubscriptionEntry(ctx, entry)
+ default:
+ return s.applyBalanceEntry(ctx, entry)
+ }
+}
+
+func (s *UsageBillingCompensationService) applyBalanceEntry(ctx context.Context, entry UsageBillingEntry) error {
+ if s.userRepo == nil {
+ return errors.New("user repository unavailable")
+ }
+
+ cacheDeducted := false
+ if s.billingCache != nil {
+ if err := s.billingCache.DeductBalanceCache(ctx, entry.UserID, entry.DeltaUSD); err != nil {
+ slog.Warn("usage_billing_compensation.balance_cache_deduct_failed",
+ "entry_id", entry.ID,
+ "user_id", entry.UserID,
+ "amount", entry.DeltaUSD,
+ "error", err,
+ )
+ _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID)
+ } else {
+ cacheDeducted = true
+ }
+ }
+
+ if err := s.runWithTx(ctx, func(txCtx context.Context) error {
+ if err := s.userRepo.DeductBalance(txCtx, entry.UserID, entry.DeltaUSD); err != nil {
+ return err
+ }
+ return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID)
+ }); err != nil {
+ if s.billingCache != nil && cacheDeducted {
+ _ = s.billingCache.InvalidateUserBalance(ctx, entry.UserID)
+ }
+ return err
+ }
+
+ return nil
+}
+
+func (s *UsageBillingCompensationService) applySubscriptionEntry(ctx context.Context, entry UsageBillingEntry) error {
+ if s.userSubRepo == nil {
+ return errors.New("subscription repository unavailable")
+ }
+ if entry.SubscriptionID == nil {
+ return errors.New("subscription_id is nil for subscription billing")
+ }
+
+ if err := s.runWithTx(ctx, func(txCtx context.Context) error {
+ if err := s.userSubRepo.IncrementUsage(txCtx, *entry.SubscriptionID, entry.DeltaUSD); err != nil {
+ return err
+ }
+ return s.store().MarkUsageBillingEntryApplied(txCtx, entry.ID)
+ }); err != nil {
+ return err
+ }
+
+ if s.billingCache != nil {
+ sub, err := s.userSubRepo.GetByID(ctx, *entry.SubscriptionID)
+ if err == nil && sub != nil {
+ _ = s.billingCache.InvalidateSubscription(ctx, entry.UserID, sub.GroupID)
+ }
+ }
+
+ return nil
+}
+
+func (s *UsageBillingCompensationService) markApplied(ctx context.Context, entry UsageBillingEntry) {
+ store := s.store()
+ if store == nil {
+ return
+ }
+ if err := store.MarkUsageBillingEntryApplied(ctx, entry.ID); err != nil {
+ slog.Warn("usage_billing_compensation.mark_applied_failed", "entry_id", entry.ID, "error", err)
+ }
+}
+
+func (s *UsageBillingCompensationService) markRetry(ctx context.Context, entry UsageBillingEntry, cause error) {
+ store := s.store()
+ if store == nil {
+ return
+ }
+ errMsg := strings.TrimSpace(cause.Error())
+ if len(errMsg) > 500 {
+ errMsg = errMsg[:500]
+ }
+ backoff := usageBillingRetryBackoff(entry.AttemptCount)
+ nextRetryAt := time.Now().Add(backoff)
+ if err := store.MarkUsageBillingEntryRetry(ctx, entry.ID, nextRetryAt, errMsg); err != nil {
+ slog.Warn("usage_billing_compensation.mark_retry_failed",
+ "entry_id", entry.ID,
+ "next_retry_at", nextRetryAt,
+ "error", err,
+ )
+ return
+ }
+ slog.Warn("usage_billing_compensation.requeued",
+ "entry_id", entry.ID,
+ "attempt", entry.AttemptCount,
+ "next_retry_at", nextRetryAt,
+ "error", errMsg,
+ )
+}
+
+func (s *UsageBillingCompensationService) runWithTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ if runner, ok := s.usageLogRepo.(UsageBillingTxRunner); ok && runner != nil {
+ return runner.WithUsageBillingTx(ctx, fn)
+ }
+ return fn(ctx)
+}
+
+func (s *UsageBillingCompensationService) store() UsageBillingEntryStore {
+ store, ok := s.usageLogRepo.(UsageBillingEntryStore)
+ if !ok {
+ return nil
+ }
+ return store
+}
diff --git a/backend/internal/service/usage_billing_compensation_service_test.go b/backend/internal/service/usage_billing_compensation_service_test.go
new file mode 100644
index 000000000..c90ae5903
--- /dev/null
+++ b/backend/internal/service/usage_billing_compensation_service_test.go
@@ -0,0 +1,231 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ "github.com/stretchr/testify/require"
+)
+
+type usageBillingCompRepoStub struct {
+ UsageLogRepository
+
+ claimErr error
+ claims []UsageBillingEntry
+
+ markAppliedCalls int
+ markRetryCalls int
+ lastRetryID int64
+ lastRetryAt time.Time
+ lastRetryErr string
+ lastMarkAppliedCtx context.Context
+ lastMarkRetryCtx context.Context
+ lastTxCtx context.Context
+}
+
+func (s *usageBillingCompRepoStub) GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error) {
+ return nil, ErrUsageBillingEntryNotFound
+}
+
+func (s *usageBillingCompRepoStub) UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error) {
+ return entry, true, nil
+}
+
+func (s *usageBillingCompRepoStub) MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error {
+ s.markAppliedCalls++
+ s.lastMarkAppliedCtx = ctx
+ return nil
+}
+
+func (s *usageBillingCompRepoStub) MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error {
+ s.markRetryCalls++
+ s.lastRetryID = entryID
+ s.lastRetryAt = nextRetryAt
+ s.lastRetryErr = lastError
+ s.lastMarkRetryCtx = ctx
+ return nil
+}
+
+func (s *usageBillingCompRepoStub) ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error) {
+ if s.claimErr != nil {
+ return nil, s.claimErr
+ }
+ out := make([]UsageBillingEntry, len(s.claims))
+ copy(out, s.claims)
+ s.claims = nil
+ return out, nil
+}
+
+func (s *usageBillingCompRepoStub) WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error {
+ s.lastTxCtx = ctx
+ if fn == nil {
+ return nil
+ }
+ return fn(ctx)
+}
+
+type usageBillingCompUserRepoStub struct {
+ UserRepository
+
+ deductCalls int
+ deductErr error
+ lastDeductCtx context.Context
+}
+
+func (s *usageBillingCompUserRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
+ s.deductCalls++
+ s.lastDeductCtx = ctx
+ return s.deductErr
+}
+
+type usageBillingCompSubRepoStub struct {
+ UserSubscriptionRepository
+
+ incrementCalls int
+ incrementErr error
+ lastIncrementCtx context.Context
+}
+
+func (s *usageBillingCompSubRepoStub) IncrementUsage(ctx context.Context, id int64, costUSD float64) error {
+ s.incrementCalls++
+ s.lastIncrementCtx = ctx
+ return s.incrementErr
+}
+
+type usageBillingCompCtxKey string
+
+func TestUsageBillingCompensationService_ProcessOnceBalanceSuccess(t *testing.T) {
+ repo := &usageBillingCompRepoStub{
+ claims: []UsageBillingEntry{
+ {
+ ID: 1,
+ UsageLogID: 1001,
+ UserID: 2001,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 1.23,
+ AttemptCount: 1,
+ },
+ },
+ }
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ svc.processOnce()
+
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 1, repo.markAppliedCalls)
+ require.Equal(t, 0, repo.markRetryCalls)
+}
+
+func TestUsageBillingCompensationService_ProcessOnceBalanceFailureRequeues(t *testing.T) {
+ repo := &usageBillingCompRepoStub{
+ claims: []UsageBillingEntry{
+ {
+ ID: 2,
+ UsageLogID: 1002,
+ UserID: 2002,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 2.34,
+ AttemptCount: 2,
+ },
+ },
+ }
+ userRepo := &usageBillingCompUserRepoStub{deductErr: errors.New("db down")}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ svc.processOnce()
+
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.Equal(t, 0, repo.markAppliedCalls)
+ require.Equal(t, 1, repo.markRetryCalls)
+ require.Equal(t, int64(2), repo.lastRetryID)
+ require.NotZero(t, repo.lastRetryAt)
+ require.Contains(t, repo.lastRetryErr, "db down")
+}
+
+func TestUsageBillingCompensationService_ProcessOnceSubscriptionSuccess(t *testing.T) {
+ subID := int64(4003)
+ repo := &usageBillingCompRepoStub{
+ claims: []UsageBillingEntry{
+ {
+ ID: 3,
+ UsageLogID: 1003,
+ UserID: 2003,
+ SubscriptionID: &subID,
+ BillingType: BillingTypeSubscription,
+ DeltaUSD: 3.45,
+ AttemptCount: 1,
+ },
+ },
+ }
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ svc.processOnce()
+
+ require.Equal(t, 1, subRepo.incrementCalls)
+ require.Equal(t, 1, repo.markAppliedCalls)
+ require.Equal(t, 0, repo.markRetryCalls)
+}
+
+func TestUsageBillingCompensationService_ApplyBalanceEntryPropagatesContext(t *testing.T) {
+ repo := &usageBillingCompRepoStub{}
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ entry := UsageBillingEntry{
+ ID: 10,
+ UsageLogID: 1010,
+ UserID: 2010,
+ BillingType: BillingTypeBalance,
+ DeltaUSD: 1.11,
+ }
+ key := usageBillingCompCtxKey("trace")
+ ctx := context.WithValue(context.Background(), key, "balance")
+
+ err := svc.applyBalanceEntry(ctx, entry)
+ require.NoError(t, err)
+ require.Equal(t, 1, userRepo.deductCalls)
+ require.NotNil(t, repo.lastTxCtx)
+ require.NotNil(t, userRepo.lastDeductCtx)
+ require.NotNil(t, repo.lastMarkAppliedCtx)
+ require.Equal(t, "balance", repo.lastTxCtx.Value(key))
+ require.Equal(t, "balance", userRepo.lastDeductCtx.Value(key))
+ require.Equal(t, "balance", repo.lastMarkAppliedCtx.Value(key))
+}
+
+func TestUsageBillingCompensationService_ApplySubscriptionEntryPropagatesContext(t *testing.T) {
+ subID := int64(4010)
+ repo := &usageBillingCompRepoStub{}
+ userRepo := &usageBillingCompUserRepoStub{}
+ subRepo := &usageBillingCompSubRepoStub{}
+ svc := NewUsageBillingCompensationService(repo, userRepo, subRepo, nil, &config.Config{})
+
+ entry := UsageBillingEntry{
+ ID: 11,
+ UsageLogID: 1011,
+ UserID: 2011,
+ SubscriptionID: &subID,
+ BillingType: BillingTypeSubscription,
+ DeltaUSD: 2.22,
+ }
+ key := usageBillingCompCtxKey("trace")
+ ctx := context.WithValue(context.Background(), key, "subscription")
+
+ err := svc.applySubscriptionEntry(ctx, entry)
+ require.NoError(t, err)
+ require.Equal(t, 1, subRepo.incrementCalls)
+ require.NotNil(t, repo.lastTxCtx)
+ require.NotNil(t, subRepo.lastIncrementCtx)
+ require.NotNil(t, repo.lastMarkAppliedCtx)
+ require.Equal(t, "subscription", repo.lastTxCtx.Value(key))
+ require.Equal(t, "subscription", subRepo.lastIncrementCtx.Value(key))
+ require.Equal(t, "subscription", repo.lastMarkAppliedCtx.Value(key))
+}
diff --git a/backend/internal/service/usage_billing_entry.go b/backend/internal/service/usage_billing_entry.go
new file mode 100644
index 000000000..a24714082
--- /dev/null
+++ b/backend/internal/service/usage_billing_entry.go
@@ -0,0 +1,60 @@
+package service
+
+import (
+ "context"
+ "errors"
+ "time"
+)
+
+var ErrUsageBillingEntryNotFound = errors.New("usage billing entry not found")
+
+type UsageBillingEntryStatus int16
+
+const (
+ UsageBillingEntryStatusPending UsageBillingEntryStatus = 0
+ UsageBillingEntryStatusProcessing UsageBillingEntryStatus = 1
+ UsageBillingEntryStatusApplied UsageBillingEntryStatus = 2
+)
+
+type UsageBillingEntry struct {
+ ID int64
+ UsageLogID int64
+ UserID int64
+ APIKeyID int64
+ SubscriptionID *int64
+ BillingType int8
+ Applied bool
+ DeltaUSD float64
+ Status UsageBillingEntryStatus
+ AttemptCount int
+ NextRetryAt time.Time
+ UpdatedAt time.Time
+ CreatedAt time.Time
+ LastError *string
+}
+
+type UsageBillingEntryStore interface {
+ GetUsageBillingEntryByUsageLogID(ctx context.Context, usageLogID int64) (*UsageBillingEntry, error)
+ UpsertUsageBillingEntry(ctx context.Context, entry *UsageBillingEntry) (*UsageBillingEntry, bool, error)
+ MarkUsageBillingEntryApplied(ctx context.Context, entryID int64) error
+ MarkUsageBillingEntryRetry(ctx context.Context, entryID int64, nextRetryAt time.Time, lastError string) error
+ ClaimUsageBillingEntries(ctx context.Context, limit int, processingStaleAfter time.Duration) ([]UsageBillingEntry, error)
+}
+
+type UsageBillingTxRunner interface {
+ WithUsageBillingTx(ctx context.Context, fn func(txCtx context.Context) error) error
+}
+
+func usageBillingRetryBackoff(attempt int) time.Duration {
+ if attempt <= 1 {
+ return 30 * time.Second
+ }
+ backoff := 30 * time.Second
+ for i := 1; i < attempt && backoff < 30*time.Minute; i++ {
+ backoff *= 2
+ }
+ if backoff > 30*time.Minute {
+ return 30 * time.Minute
+ }
+ return backoff
+}
diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go
index b0eccb71b..d09c504ee 100644
--- a/backend/internal/service/wire.go
+++ b/backend/internal/service/wire.go
@@ -333,7 +333,7 @@ var ProviderSet = wire.NewSet(
ProvideRateLimitService,
NewAccountUsageService,
NewAccountTestService,
- ProvideSettingService,
+ NewSettingService,
NewDataManagementService,
ProvideOpsSystemLogSink,
NewOpsService,
diff --git a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
index d0ed5d6dd..cad2971b1 100644
--- a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
+++ b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql
@@ -1,45 +1,36 @@
--- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping
+-- Add gemini-3.1-flash-image mapping keys without wiping existing custom mappings.
--
-- Background:
--- Antigravity now supports gemini-3.1-flash-image as the latest image generation model,
--- replacing the previous gemini-3-pro-image.
+-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model.
+-- Existing accounts may still contain gemini-3-pro-image aliases.
--
-- Strategy:
--- Directly overwrite the entire model_mapping with updated mappings
--- This ensures consistency with DefaultAntigravityModelMapping in constants.go
+-- Incrementally upsert only image-related keys in credentials.model_mapping:
+-- 1) add canonical 3.1 image keys
+-- 2) keep legacy 3-pro-image keys but remap them to 3.1 image for compatibility
+-- This preserves user custom mappings and avoids full mapping overwrite.
UPDATE accounts
SET credentials = jsonb_set(
- credentials,
- '{model_mapping}',
- '{
- "claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
- "claude-opus-4-6": "claude-opus-4-6-thinking",
- "claude-opus-4-5-thinking": "claude-opus-4-6-thinking",
- "claude-opus-4-5-20251101": "claude-opus-4-6-thinking",
- "claude-sonnet-4-6": "claude-sonnet-4-6",
- "claude-sonnet-4-5": "claude-sonnet-4-5",
- "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
- "claude-sonnet-4-5-20250929": "claude-sonnet-4-5",
- "claude-haiku-4-5": "claude-sonnet-4-5",
- "claude-haiku-4-5-20251001": "claude-sonnet-4-5",
- "gemini-2.5-flash": "gemini-2.5-flash",
- "gemini-2.5-flash-lite": "gemini-2.5-flash-lite",
- "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking",
- "gemini-2.5-pro": "gemini-2.5-pro",
- "gemini-3-flash": "gemini-3-flash",
- "gemini-3-pro-high": "gemini-3-pro-high",
- "gemini-3-pro-low": "gemini-3-pro-low",
- "gemini-3-flash-preview": "gemini-3-flash",
- "gemini-3-pro-preview": "gemini-3-pro-high",
- "gemini-3.1-pro-high": "gemini-3.1-pro-high",
- "gemini-3.1-pro-low": "gemini-3.1-pro-low",
- "gemini-3.1-pro-preview": "gemini-3.1-pro-high",
- "gemini-3.1-flash-image": "gemini-3.1-flash-image",
- "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
- "gpt-oss-120b-medium": "gpt-oss-120b-medium",
- "tab_flash_lite_preview": "tab_flash_lite_preview"
- }'::jsonb
+ jsonb_set(
+ jsonb_set(
+ jsonb_set(
+ credentials,
+ '{model_mapping,gemini-3.1-flash-image}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
+ ),
+ '{model_mapping,gemini-3.1-flash-image-preview}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
+ ),
+ '{model_mapping,gemini-3-pro-image}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
+ ),
+ '{model_mapping,gemini-3-pro-image-preview}',
+ '"gemini-3.1-flash-image"'::jsonb,
+ true
)
WHERE platform = 'antigravity'
AND deleted_at IS NULL
diff --git a/backend/migrations/064_add_billing_usage_entry_retry_fields.sql b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql
new file mode 100644
index 000000000..aebb39295
--- /dev/null
+++ b/backend/migrations/064_add_billing_usage_entry_retry_fields.sql
@@ -0,0 +1,27 @@
+-- 064_add_billing_usage_entry_retry_fields.sql
+-- Add retry-state columns for billing_usage_entries compensation worker.
+
+ALTER TABLE billing_usage_entries
+ ADD COLUMN IF NOT EXISTS status SMALLINT NOT NULL DEFAULT 0,
+ ADD COLUMN IF NOT EXISTS attempt_count INTEGER NOT NULL DEFAULT 0,
+ ADD COLUMN IF NOT EXISTS next_retry_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ ADD COLUMN IF NOT EXISTS updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
+ ADD COLUMN IF NOT EXISTS last_error TEXT;
+
+-- Keep legacy rows aligned with applied flag.
+UPDATE billing_usage_entries
+SET status = CASE WHEN applied THEN 2 ELSE 0 END
+WHERE status NOT IN (0, 1, 2)
+ OR (applied = TRUE AND status <> 2)
+ OR (applied = FALSE AND status = 2);
+
+ALTER TABLE billing_usage_entries
+ DROP CONSTRAINT IF EXISTS chk_billing_usage_entries_status;
+
+ALTER TABLE billing_usage_entries
+ ADD CONSTRAINT chk_billing_usage_entries_status
+ CHECK (status IN (0, 1, 2));
+
+CREATE INDEX IF NOT EXISTS idx_billing_usage_entries_retry
+ ON billing_usage_entries (status, next_retry_at, updated_at)
+ WHERE applied = FALSE;
diff --git a/deploy/Caddyfile b/deploy/Caddyfile
index b643fe9b8..cbd762b1c 100644
--- a/deploy/Caddyfile
+++ b/deploy/Caddyfile
@@ -30,6 +30,36 @@ api.sub2api.com {
# =========================================================================
# 反向代理配置
# =========================================================================
+ # OpenAI Responses(含 WebSocket/SSE)单独代理策略:
+ # 1) flush_interval -1:尽快转发流式分片,降低中间层缓冲导致的断流概率
+ # 2) versions 1.1:确保上游走标准 HTTP/1.1 Upgrade,避免协议协商差异
+ # 3) stream_timeout/stream_close_delay:为长连接提供更宽松生命周期
+ @openai_responses {
+ path /openai/v1/responses*
+ }
+ reverse_proxy @openai_responses localhost:8080 {
+ # 长连接/流式场景建议关闭代理缓冲
+ flush_interval -1
+ # 长连接超时窗口(避免长会话被代理层过早回收)
+ stream_timeout 24h
+ # 配置热重载时,给现有流预留关闭缓冲期
+ stream_close_delay 5m
+
+ # 传递真实客户端信息
+ header_up X-Real-IP {remote_host}
+ header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP}
+
+ transport http {
+ # WebSocket Upgrade 对上游统一使用 HTTP/1.1,更稳妥
+ versions 1.1
+ keepalive 120s
+ keepalive_idle_conns 256
+ read_buffer 32KB
+ write_buffer 32KB
+ compression off
+ }
+ }
+
reverse_proxy localhost:8080 {
# 健康检查
health_uri /health
@@ -45,9 +75,6 @@ api.sub2api.com {
# 传递真实客户端信息
# 兼容 Cloudflare 和直连:后端应优先读取 CF-Connecting-IP,其次 X-Real-IP
header_up X-Real-IP {remote_host}
- header_up X-Forwarded-For {remote_host}
- header_up X-Forwarded-Proto {scheme}
- header_up X-Forwarded-Host {host}
# 保留 Cloudflare 原始头(如果存在)
# 后端获取 IP 的优先级建议: CF-Connecting-IP → X-Real-IP → X-Forwarded-For
header_up CF-Connecting-IP {http.request.header.CF-Connecting-IP}
diff --git a/deploy/Dockerfile b/deploy/Dockerfile
index b33203009..c9fcf3017 100644
--- a/deploy/Dockerfile
+++ b/deploy/Dockerfile
@@ -7,7 +7,7 @@
# =============================================================================
ARG NODE_IMAGE=node:24-alpine
-ARG GOLANG_IMAGE=golang:1.25.5-alpine
+ARG GOLANG_IMAGE=golang:1.25.7-alpine
ARG ALPINE_IMAGE=alpine:3.20
ARG GOPROXY=https://goproxy.cn,direct
ARG GOSUMDB=sum.golang.google.cn
diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml
index e2eb3130c..a30a0ed79 100644
--- a/deploy/config.example.yaml
+++ b/deploy/config.example.yaml
@@ -207,10 +207,10 @@ gateway:
openai_passthrough_allow_timeout_headers: false
# OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP)
openai_ws:
- # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。
- mode_router_v2_enabled: false
- # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效)
- ingress_mode_default: shared
+ # 新版 WS mode 路由(默认开启)。关闭时保持当前 legacy 实现行为。
+ mode_router_v2_enabled: true
+ # ingress 默认模式:off|ctx_pool(仅 mode_router_v2_enabled=true 生效)
+ ingress_mode_default: ctx_pool
# 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由
enabled: true
# 按账号类型细分开关
@@ -248,6 +248,10 @@ gateway:
write_timeout_seconds: 120
pool_target_utilization: 0.7
queue_limit_per_conn: 64
+ # 上游 WebSocket 连接最大存活时间(秒)。
+ # OpenAI 在 60 分钟后强制断开连接,此参数控制主动轮换阈值。
+ # 默认 3300(55 分钟);设为 0 则禁用超龄轮换。
+ upstream_conn_max_age_seconds: 3300
# 流式写出批量 flush 参数
event_flush_batch_size: 1
event_flush_interval_ms: 10
@@ -282,6 +286,24 @@ gateway:
queue: 0.7
error_rate: 0.8
ttft: 0.5
+ # OpenAI HTTP upstream protocol strategy
+ # OpenAI HTTP 上游协议策略
+ openai_http2:
+ # Enable OpenAI HTTP/2 preference (default on)
+ # 启用 OpenAI HTTP/2 优先策略(默认开启)
+ enabled: true
+ # Allow fallback to HTTP/1.1 for incompatible HTTP proxies
+ # 当 HTTP 代理不兼容时允许回退到 HTTP/1.1
+ allow_proxy_fallback_to_http1: true
+ # Fallback triggers after N HTTP/2 compatibility errors within window
+ # 在窗口期内累计 N 次 HTTP/2 兼容错误后触发回退
+ fallback_error_threshold: 2
+ # Error counting window (seconds)
+ # 错误计数窗口(秒)
+ fallback_window_seconds: 60
+ # How long to stay in HTTP/1.1 fallback mode (seconds)
+ # 进入 HTTP/1.1 回退态后的持续时间(秒)
+ fallback_ttl_seconds: 600
# HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults)
# HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值)
# Max idle connections across all hosts
@@ -316,6 +338,15 @@ gateway:
# SSE max line size in bytes (default: 40MB)
# SSE 单行最大字节数(默认 40MB)
max_line_size: 41943040
+ # Usage record worker pool (bounded queue)
+ # 使用量记录异步池(有界队列)
+ usage_record:
+ # queue overflow policy: drop/sample/sync
+ # 队列溢出策略:drop/sample/sync
+ overflow_policy: sync
+ # only used when overflow_policy=sample
+ # 仅在 overflow_policy=sample 时生效
+ overflow_sample_percent: 10
# Log upstream error response body summary (safe/truncated; does not log request content)
# 记录上游错误响应体摘要(安全/截断;不记录请求内容)
log_upstream_error_body: true
diff --git a/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts
new file mode 100644
index 000000000..f6af1d4cf
--- /dev/null
+++ b/frontend/src/api/__tests__/settings.bulkEditTemplates.spec.ts
@@ -0,0 +1,184 @@
+import { beforeEach, describe, expect, it, vi } from 'vitest'
+import {
+ deleteBulkEditTemplate,
+ getBulkEditTemplates,
+ getBulkEditTemplateVersions,
+ rollbackBulkEditTemplate,
+ upsertBulkEditTemplate
+} from '../admin/bulkEditTemplates'
+import { apiClient } from '../client'
+
+vi.mock('../client', () => ({
+ apiClient: {
+ get: vi.fn(),
+ post: vi.fn(),
+ delete: vi.fn()
+ }
+}))
+
+describe('admin settings bulk-edit templates api', () => {
+ beforeEach(() => {
+ vi.clearAllMocks()
+ })
+
+ it('requests template list with expected query params', async () => {
+ ;(apiClient.get as any).mockResolvedValue({
+ data: {
+ items: [
+ {
+ id: 'tpl-1',
+ name: 'Template',
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ share_scope: 'team',
+ group_ids: [],
+ state: {},
+ created_by: 1,
+ updated_by: 1,
+ created_at: 1,
+ updated_at: 2
+ }
+ ]
+ }
+ })
+
+ const items = await getBulkEditTemplates({
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ scope_group_ids: [3, 9]
+ })
+
+ expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', {
+ params: {
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ scope_group_ids: '3,9'
+ }
+ })
+ expect(items).toHaveLength(1)
+ expect(items[0].id).toBe('tpl-1')
+ })
+
+ it('returns empty list when response items is invalid', async () => {
+ ;(apiClient.get as any).mockResolvedValue({ data: { items: null } })
+ const items = await getBulkEditTemplates({})
+ expect(items).toEqual([])
+ })
+
+ it('posts upsert payload and returns saved template', async () => {
+ ;(apiClient.post as any).mockResolvedValue({
+ data: {
+ id: 'tpl-2',
+ name: 'Shared',
+ scope_platform: 'openai',
+ scope_type: 'apikey',
+ share_scope: 'groups',
+ group_ids: [1],
+ state: { enableProxy: true },
+ created_by: 2,
+ updated_by: 2,
+ created_at: 10,
+ updated_at: 11
+ }
+ })
+
+ const saved = await upsertBulkEditTemplate({
+ id: 'tpl-2',
+ name: 'Shared',
+ scope_platform: 'openai',
+ scope_type: 'apikey',
+ share_scope: 'groups',
+ group_ids: [1],
+ state: { enableProxy: true }
+ })
+
+ expect(apiClient.post).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates', {
+ id: 'tpl-2',
+ name: 'Shared',
+ scope_platform: 'openai',
+ scope_type: 'apikey',
+ share_scope: 'groups',
+ group_ids: [1],
+ state: { enableProxy: true }
+ })
+ expect(saved.id).toBe('tpl-2')
+ })
+
+ it('requests template versions with scope group params', async () => {
+ ;(apiClient.get as any).mockResolvedValue({
+ data: {
+ items: [
+ {
+ version_id: 'ver-1',
+ share_scope: 'team',
+ group_ids: [],
+ state: { enableOpenAIWSMode: true },
+ updated_by: 11,
+ updated_at: 100
+ }
+ ]
+ }
+ })
+
+ const items = await getBulkEditTemplateVersions('tpl-2', { scope_group_ids: [5, 8] })
+
+ expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-2/versions', {
+ params: {
+ scope_group_ids: '5,8'
+ }
+ })
+ expect(items).toHaveLength(1)
+ expect(items[0].version_id).toBe('ver-1')
+ })
+
+ it('returns empty versions list when payload is invalid', async () => {
+ ;(apiClient.get as any).mockResolvedValue({ data: { items: undefined } })
+
+ const items = await getBulkEditTemplateVersions('tpl-any')
+
+ expect(apiClient.get).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-any/versions', {
+ params: {}
+ })
+ expect(items).toEqual([])
+ })
+
+ it('posts rollback request with optional query params', async () => {
+ ;(apiClient.post as any).mockResolvedValue({
+ data: {
+ id: 'tpl-3',
+ name: 'Rollbacked',
+ scope_platform: 'openai',
+ scope_type: 'oauth',
+ share_scope: 'private',
+ group_ids: [],
+ state: { enableOpenAIPassthrough: false },
+ created_by: 1,
+ updated_by: 2,
+ created_at: 10,
+ updated_at: 12
+ }
+ })
+
+ const saved = await rollbackBulkEditTemplate(
+ 'tpl-3',
+ { version_id: 'ver-2' },
+ { scope_group_ids: [2] }
+ )
+
+ expect(apiClient.post).toHaveBeenCalledWith(
+ '/admin/settings/bulk-edit-templates/tpl-3/rollback',
+ { version_id: 'ver-2' },
+ { params: { scope_group_ids: '2' } }
+ )
+ expect(saved.id).toBe('tpl-3')
+ })
+
+ it('calls delete endpoint for template removal', async () => {
+ ;(apiClient.delete as any).mockResolvedValue({ data: { deleted: true } })
+
+ const result = await deleteBulkEditTemplate('tpl-9')
+
+ expect(apiClient.delete).toHaveBeenCalledWith('/admin/settings/bulk-edit-templates/tpl-9')
+ expect(result).toEqual({ deleted: true })
+ })
+})
diff --git a/frontend/src/api/admin/bulkEditTemplates.ts b/frontend/src/api/admin/bulkEditTemplates.ts
new file mode 100644
index 000000000..45c5e8ac1
--- /dev/null
+++ b/frontend/src/api/admin/bulkEditTemplates.ts
@@ -0,0 +1,129 @@
+import { apiClient } from '../client'
+import type { AccountPlatform, AccountType } from '@/types'
+
+export type BulkEditTemplateShareScope = 'private' | 'team' | 'groups'
+
+export interface BulkEditTemplateRecord> {
+ id: string
+ name: string
+ scope_platform: AccountPlatform | ''
+ scope_type: AccountType | ''
+ share_scope: BulkEditTemplateShareScope
+ group_ids: number[]
+ state: TState
+ created_by: number
+ updated_by: number
+ created_at: number
+ updated_at: number
+}
+
+export interface BulkEditTemplateVersionRecord> {
+ version_id: string
+ share_scope: BulkEditTemplateShareScope
+ group_ids: number[]
+ state: TState
+ updated_by: number
+ updated_at: number
+}
+
+export interface GetBulkEditTemplatesParams {
+ scope_platform?: AccountPlatform | ''
+ scope_type?: AccountType | ''
+ scope_group_ids?: number[]
+}
+
+export interface GetBulkEditTemplateVersionsParams {
+ scope_group_ids?: number[]
+}
+
+export interface UpsertBulkEditTemplateRequest> {
+ id?: string
+ name: string
+ scope_platform: AccountPlatform | ''
+ scope_type: AccountType | ''
+ share_scope: BulkEditTemplateShareScope
+ group_ids: number[]
+ state: TState
+}
+
+export interface RollbackBulkEditTemplateRequest {
+ version_id: string
+}
+
+export async function getBulkEditTemplates>(
+ params: GetBulkEditTemplatesParams
+): Promise[]> {
+ const query: Record = {}
+ if (params.scope_platform) query.scope_platform = params.scope_platform
+ if (params.scope_type) query.scope_type = params.scope_type
+ if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+ query.scope_group_ids = params.scope_group_ids.join(',')
+ }
+
+ const { data } = await apiClient.get<{ items: BulkEditTemplateRecord[] }>(
+ '/admin/settings/bulk-edit-templates',
+ { params: query }
+ )
+ return Array.isArray(data.items) ? data.items : []
+}
+
+export async function getBulkEditTemplateVersions>(
+ templateID: string,
+ params: GetBulkEditTemplateVersionsParams = {}
+): Promise[]> {
+ const query: Record = {}
+ if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+ query.scope_group_ids = params.scope_group_ids.join(',')
+ }
+
+ const { data } = await apiClient.get<{ items: BulkEditTemplateVersionRecord[] }>(
+ `/admin/settings/bulk-edit-templates/${templateID}/versions`,
+ { params: query }
+ )
+ return Array.isArray(data.items) ? data.items : []
+}
+
+export async function upsertBulkEditTemplate>(
+ request: UpsertBulkEditTemplateRequest
+): Promise> {
+ const { data } = await apiClient.post>(
+ '/admin/settings/bulk-edit-templates',
+ request
+ )
+ return data
+}
+
+export async function rollbackBulkEditTemplate>(
+ templateID: string,
+ request: RollbackBulkEditTemplateRequest,
+ params: GetBulkEditTemplateVersionsParams = {}
+): Promise> {
+ const query: Record = {}
+ if (Array.isArray(params.scope_group_ids) && params.scope_group_ids.length > 0) {
+ query.scope_group_ids = params.scope_group_ids.join(',')
+ }
+
+ const { data } = await apiClient.post>(
+ `/admin/settings/bulk-edit-templates/${templateID}/rollback`,
+ request,
+ { params: query }
+ )
+ return data
+}
+
+export async function deleteBulkEditTemplate(templateID: string): Promise<{ deleted: boolean }> {
+ const { data } = await apiClient.delete<{ deleted: boolean }>(
+ `/admin/settings/bulk-edit-templates/${templateID}`
+ )
+ return data
+}
+
+const bulkEditTemplatesAPI = {
+ getBulkEditTemplates,
+ getBulkEditTemplateVersions,
+ upsertBulkEditTemplate,
+ rollbackBulkEditTemplate,
+ deleteBulkEditTemplate
+}
+
+export default bulkEditTemplatesAPI
diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts
index 5db998e57..a2c82ecbf 100644
--- a/frontend/src/api/admin/index.ts
+++ b/frontend/src/api/admin/index.ts
@@ -12,6 +12,7 @@ import redeemAPI from './redeem'
import promoAPI from './promo'
import announcementsAPI from './announcements'
import settingsAPI from './settings'
+import bulkEditTemplatesAPI from './bulkEditTemplates'
import systemAPI from './system'
import subscriptionsAPI from './subscriptions'
import usageAPI from './usage'
@@ -21,7 +22,6 @@ import userAttributesAPI from './userAttributes'
import opsAPI from './ops'
import errorPassthroughAPI from './errorPassthrough'
import dataManagementAPI from './dataManagement'
-import apiKeysAPI from './apiKeys'
/**
* Unified admin API object for convenient access
@@ -36,6 +36,7 @@ export const adminAPI = {
promo: promoAPI,
announcements: announcementsAPI,
settings: settingsAPI,
+ bulkEditTemplates: bulkEditTemplatesAPI,
system: systemAPI,
subscriptions: subscriptionsAPI,
usage: usageAPI,
@@ -44,8 +45,7 @@ export const adminAPI = {
userAttributes: userAttributesAPI,
ops: opsAPI,
errorPassthrough: errorPassthroughAPI,
- dataManagement: dataManagementAPI,
- apiKeys: apiKeysAPI
+ dataManagement: dataManagementAPI
}
export {
@@ -58,6 +58,7 @@ export {
promoAPI,
announcementsAPI,
settingsAPI,
+ bulkEditTemplatesAPI,
systemAPI,
subscriptionsAPI,
usageAPI,
@@ -66,8 +67,7 @@ export {
userAttributesAPI,
opsAPI,
errorPassthroughAPI,
- dataManagementAPI,
- apiKeysAPI
+ dataManagementAPI
}
export default adminAPI
diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue
index 30c3d7390..c1cff2993 100644
--- a/frontend/src/components/account/BulkEditAccountModal.vue
+++ b/frontend/src/components/account/BulkEditAccountModal.vue
@@ -1,7 +1,7 @@
@@ -21,18 +21,279 @@
-
-
-
-
- {{ t('admin.accounts.bulkEdit.mixedPlatformWarning', { platforms: selectedPlatforms.join(', ') }) }}
-
+
+ {{ t('admin.accounts.filters.platform') }}: {{ scopePlatformLabel }} /
+ {{ t('admin.accounts.filters.type') }}: {{ scopeTypeLabel }}
+
+
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateTitle') }}
+
+
+ {{
+ t('admin.accounts.bulkEdit.templateScopeHint', {
+ platform: scopePlatformLabel,
+ type: scopeTypeLabel
+ })
+ }}
+
+
+ {{ t('admin.accounts.bulkEdit.templateLoading') }}
+
+
+
+ {{ scopedTemplateRecords.length }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateShareGroupsHint') }}
+
+
+
+
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionTitle') }}
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionHint') }}
+
+
+
+ {{ templateVersionRecords.length }}
+
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionLoading') }}
+
+
+
+ {{ t('admin.accounts.bulkEdit.templateVersionEmpty') }}
+
+
+
+ -
+
+
+ {{ formatTemplateVersionUpdatedAt(version.updatedAt) }}
+
+
+ {{ resolveTemplateShareScopeLabel(version.shareScope) }}
+ · {{ version.groupIDs.join(', ') }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
-
+
@@ -79,14 +78,14 @@ import AppLayout from '@/components/layout/AppLayout.vue'; import Pagination fro
import UsageStatsCards from '@/components/admin/usage/UsageStatsCards.vue'; import UsageFilters from '@/components/admin/usage/UsageFilters.vue'
import UsageTable from '@/components/admin/usage/UsageTable.vue'; import UsageExportProgress from '@/components/admin/usage/UsageExportProgress.vue'
import UsageCleanupDialog from '@/components/admin/usage/UsageCleanupDialog.vue'
-import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import GroupDistributionChart from '@/components/charts/GroupDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
+import ModelDistributionChart from '@/components/charts/ModelDistributionChart.vue'; import TokenUsageTrend from '@/components/charts/TokenUsageTrend.vue'
import Icon from '@/components/icons/Icon.vue'
-import type { AdminUsageLog, TrendDataPoint, ModelStat, GroupStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
+import type { AdminUsageLog, TrendDataPoint, ModelStat } from '@/types'; import type { AdminUsageStatsResponse, AdminUsageQueryParams } from '@/api/admin/usage'
const { t } = useI18n()
const appStore = useAppStore()
const usageStats = ref(null); const usageLogs = ref([]); const loading = ref(false); const exporting = ref(false)
-const trendData = ref([]); const modelStats = ref([]); const groupStats = ref([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
+const trendData = ref([]); const modelStats = ref([]); const chartsLoading = ref(false); const granularity = ref<'day' | 'hour'>('day')
let abortController: AbortController | null = null; let exportAbortController: AbortController | null = null
const exportProgress = reactive({ show: false, progress: 0, current: 0, total: 0, estimatedTime: '' })
const cleanupDialogVisible = ref(false)
diff --git a/frontend/src/views/user/PurchaseSubscriptionView.vue b/frontend/src/views/user/PurchaseSubscriptionView.vue
index fdcd0d34e..55bcf3078 100644
--- a/frontend/src/views/user/PurchaseSubscriptionView.vue
+++ b/frontend/src/views/user/PurchaseSubscriptionView.vue
@@ -1,6 +1,30 @@
+
+
+
+ {{ t('purchase.title') }}
+
+
+ {{ t('purchase.description') }}
+
+
+
+
+
+
+
diff --git a/openspec/changes/2026-02-26-openai-http-path-extreme-performance/proposal.md b/openspec/changes/2026-02-26-openai-http-path-extreme-performance/proposal.md
deleted file mode 100644
index b91a15f97..000000000
--- a/openspec/changes/2026-02-26-openai-http-path-extreme-performance/proposal.md
+++ /dev/null
@@ -1,206 +0,0 @@
-## Why
-
-OpenAI `/v1/responses` HTTP 转发路径(包括 SSE 流式、非流式、透传模式)是网关最核心的热路径,每个请求都会经过完整的 handler → service → upstream → response 链路。经过对全部 ~12000 行 OpenAI 相关代码的逐函数审查,发现当前实现在以下层面存在可量化的非必要开销:
-
-1. **请求体处理**:非透传模式下对所有请求体执行 `json.Unmarshal → map[string]any → json.Marshal` 全量反序列化/序列化,即使只需要修改 1-2 个字段。对于携带大量 `input` 数组的 Codex 请求(常见 50-500KB),这会产生大量中间 `map/slice/interface{}` 临时对象,触发高频 GC。
-2. **SSE 流式转发**:逐行 `fmt.Fprintf` + `Flush` 导致每个 SSE 事件都触发一次系统调用;模型名替换和 usage 解析在绝大多数行上做了不必要的 JSON 解析。
-3. **账号调度**:每次调度执行 4 次候选列表遍历 + 1 次全量排序;运行时统计使用全局写锁;会话哈希使用加密级 SHA-256。
-4. **连接池锁竞争**:WS 连接池 `acquire` 在持有互斥锁期间执行清理逻辑,高并发下成为瓶颈。
-5. **Handler 层冗余**:ops 上下文被设置两次。
-
-这些问题在单请求视角下各自开销不大(μs~ms 级),但在高并发、大请求体、长流式响应的生产场景下会叠加放大,直接影响 TTFT(首 token 时间)、P95/P99 尾延迟和 GC 停顿。
-
-## What Changes
-
-本提案针对 OpenAI HTTP 转发全路径提出 13 项性能优化,按优先级分为三档。
-
-### P0:热路径核心优化(每请求直接命中)
-
-#### 1. 用 sjson 点操作替代全量 json.Unmarshal/Marshal
-
-- **问题定位**: `openai_gateway_service.go:1413-1547` `Forward()` 方法
-- **现状**: 非透传模式下,`getOpenAIRequestBodyMap()` 调用 `json.Unmarshal` 将 `body []byte` 反序列化为 `map[string]any`,修改若干字段后再 `json.Marshal` 回 `[]byte`。
-- **量化影响**: 对 200KB 请求体,`Unmarshal + Marshal` 耗时 ~2-5ms,产生 ~1000+ 次堆分配和 ~500KB 临时内存。
-- **优化方案**: 使用已引入的 `tidwall/sjson` 库做精确字节级修改:
- - `sjson.SetBytes(body, "model", mappedModel)` 替代 `reqBody["model"] = mappedModel` + `json.Marshal`
- - `sjson.DeleteBytes(body, "max_output_tokens")` 替代 `delete(reqBody, "max_output_tokens")` + `json.Marshal`
- - 仅在 `bodyModified == true` 时才执行任何操作,且无需整体反序列化
-- **保留条件**: 涉及复杂嵌套修改(如 `input` 数组内字段校正)的场景仍保留 `map[string]any` 路径作为降级
-- **预估收益**: 大请求体场景 CPU 降低 30-50%,allocs/op 降低 60%+
-
-#### 2. SSE 流式响应批量写入与智能 Flush
-
-- **问题定位**: `openai_gateway_service.go:2762-2787` `handleStreamingResponse()` 方法
-- **现状**: 每读取一行 SSE 事件都执行一次 `fmt.Fprintf(w, "%s\n", line)` + `flusher.Flush()`,在高频 token 输出时每秒可触发数百次系统调用。
-- **量化影响**: 每次 Flush 约 1-5μs 系统调用开销,100 token/s 场景下仅 Flush 就消耗 ~100-500μs/s。
-- **优化方案**:
- - 引入 `bufio.Writer` 包装 `c.Writer`,缓冲区 4KB
- - 仅在 channel 队列为空时(`len(events) == 0`)执行 Flush,实现"尽快但不过度"的语义
- - 保留 keepalive ticker Flush 逻辑不变
-- **风险控制**: 缓冲区不影响 TTFT(第一个事件仍立即 Flush),仅在后续高频事件时生效
-- **预估收益**: 高吞吐流式场景系统调用减少 50-80%
-
-#### 3. 模型名替换增加 bytes.Contains 快速门控
-
-- **问题定位**: `openai_gateway_service.go:2840-2868` `replaceModelInSSELine()` 方法
-- **现状**: 当 `needModelReplace == true` 时,对每行 SSE 事件执行两次 `gjson.Get` + 可能的 `sjson.Set`。但实际上 SSE 流中仅 `response.created`、`response.completed` 等少数事件类型包含 `model` 字段。
-- **优化方案**: 在调用 `replaceModelInSSELine` 前增加 `strings.Contains(data, fromModel)` 快速判断:
- ```go
- if needModelReplace && strings.Contains(data, mappedModel) {
- line = s.replaceModelInSSELine(line, mappedModel, originalModel)
- }
- ```
-- **预估收益**: ~90% 的 SSE 行可跳过 JSON 解析,该函数整体耗时降低 80%+
-
-#### 4. parseSSEUsageBytes 增加长度门控
-
-- **问题定位**: `openai_gateway_service.go:2887-2902`
-- **现状**: 对每行 SSE 数据执行 `bytes.Contains(data, []byte("response.completed"))` 扫描。
-- **优化方案**: `response.completed` 事件通常包含完整 usage 数据,payload 长度远大于普通 token 事件。增加短行快速跳过:
- ```go
- if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) {
- return
- }
- ```
-- **预估收益**: 减少 95%+ 行的 `bytes.Contains` 扫描开销
-
-#### 5. 消除 setOpsRequestContext 双重调用
-
-- **问题定位**: `openai_gateway_handler.go:115, 143`
-- **现状**: 第 115 行以空 model 调用一次,第 143 行解析出 model 后再调用一次。
-- **优化方案**: 删除第 115 行的调用,将所有 ops context 设置延迟到模型解析完成后一次性完成。
-- **预估收益**: 每请求减少一次 context 写入(~200ns)
-
-### P1:调度与连接池路径优化
-
-#### 6. selectByLoadBalance 合并遍历 + 使用堆替代排序
-
-- **问题定位**: `openai_account_scheduler.go:319-488`
-- **现状**: 4 次遍历候选列表 + 1 次 `sort.SliceStable` 全量排序,仅为选出 TopK 最优候选。
-- **优化方案**:
- - 将过滤、负载收集、分数计算合并为一次遍历
- - 使用 `container/heap` 维护 TopK 最小堆,时间复杂度从 O(n log n) 降至 O(n log k)
- - 当候选数 ≤ TopK 时直接跳过排序
-- **预估收益**: 20+ 账号场景调度延迟降低 40-60%
-
-#### 7. openAIAccountRuntimeStats 消除全局写锁
-
-- **问题定位**: `openai_account_scheduler.go:113-142`
-- **现状**: `report()` 使用 `sync.Mutex` 全局写锁更新 EWMA 统计,每请求完成时调用。
-- **优化方案**: 将 `map[int64]*openAIAccountRuntimeStat` 改为 `sync.Map`,每个 `openAIAccountRuntimeStat` 内部使用 `atomic.Uint64`(配合 `math.Float64bits/Float64frombits`)实现无锁 EWMA 更新:
- ```go
- type openAIAccountRuntimeStat struct {
- errorRateEWMABits atomic.Uint64
- ttftEWMABits atomic.Uint64
- hasTTFT atomic.Bool
- }
- ```
-- **风险**: CAS 循环可能在极端竞争下略有重试,但远优于互斥锁
-- **预估收益**: 消除每请求一次的全局锁竞争
-
-#### 8. GenerateSessionHash 使用非加密哈希
-
-- **问题定位**: `openai_gateway_service.go:842-860`
-- **现状**: 对 session_id 执行 `crypto/sha256.Sum256` + `hex.EncodeToString`。会话哈希仅用作缓存 key,不需要抗碰撞。
-- **优化方案**: 使用 `hash/fnv` 或 `xxhash`(第三方)替代 SHA-256。FNV-128 在短字符串上比 SHA-256 快 5-10 倍:
- ```go
- h := fnv.New128a()
- h.Write([]byte(sessionID))
- return hex.EncodeToString(h.Sum(nil))
- ```
-- **预估收益**: 每请求节省 ~0.5-1μs
-
-#### 9. WS 连接池 acquire 清理逻辑外移
-
-- **问题定位**: `openai_ws_pool.go:746-751`
-- **现状**: `acquire()` 在持有 `ap.mu.Lock()` 期间按时间间隔触发 `cleanupAccountLocked()`,内部遍历所有连接、排序空闲连接、执行驱逐。
-- **优化方案**: 已有 `runBackgroundCleanupWorker(30s)` 后台清理机制,将 `acquire()` 中的按需清理移除(或大幅延长触发间隔到 30s+),让清理完全由后台 worker 负责:
- ```go
- // acquire() 中删除以下代码块:
- // if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval {
- // evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns)
- // ap.lastCleanupAt = now
- // }
- ```
-- **风险**: 极端场景下过期连接可能多存活最多 30s,但 health check 和 maxAge 检查会在使用时兜底
-- **预估收益**: acquire 锁持有时间降低 30-50%(尤其在连接数多时)
-
-### P2:低影响但零风险优化
-
-#### 10. io.ReadAll 预分配 buffer
-
-- **问题定位**: `openai_gateway_handler.go:100`
-- **现状**: `io.ReadAll(c.Request.Body)` 未利用 `Content-Length` 做容量预估,较大请求体下会发生多次扩容与拷贝。
-- **优化方案**: 根据 `Content-Length` 预分配:
- ```go
- size := c.Request.ContentLength
- if size <= 0 || size > maxBodySize { size = 512 }
- buf := bytes.NewBuffer(make([]byte, 0, size))
- _, err := io.Copy(buf, io.LimitReader(c.Request.Body, maxBodySize))
- body := buf.Bytes()
- ```
-- **预估收益**: 大请求体减少 3-5 次 slice grow,节省 ~100μs + 减少碎片
-
-#### 11. nextConnID 避免 fmt.Sprintf
-
-- **问题定位**: `openai_ws_pool.go:1278-1281`
-- **现状**: `fmt.Sprintf("oa_ws_%d_%d", accountID, seq)` 使用反射机制。
-- **优化方案**: 使用 `strconv.AppendInt` 手动拼接:
- ```go
- buf := make([]byte, 0, 32)
- buf = append(buf, "oa_ws_"...)
- buf = strconv.AppendInt(buf, accountID, 10)
- buf = append(buf, '_')
- buf = strconv.AppendUint(buf, seq, 10)
- return string(buf)
- ```
-- **预估收益**: ~50-100ns/次
-
-#### 12. handleStreamingResponse 条件性省略 goroutine
-
-- **问题定位**: `openai_gateway_service.go:2637-2662`
-- **现状**: 所有流式请求都创建读取 goroutine + channel 用于超时/keepalive 监控。
-- **优化方案**: 当 `streamInterval == 0 && keepaliveInterval == 0` 时(即未配置超时和 keepalive),退化为主 goroutine 同步读取,省去 goroutine 调度和 channel 同步开销。
-- **预估收益**: 无超时配置场景每请求省 ~2-5μs goroutine 创建/调度开销
-
-#### 13. listSchedulableAccounts 确认本地缓存 TTL
-
-- **问题定位**: `openai_gateway_service.go:1265-1283` 及 `SchedulerSnapshotService` 实现
-- **现状**: 每次调度调用 `listSchedulableAccounts()`,依赖 `schedulerSnapshot` 提供缓存。需确认 snapshot 内部有短 TTL 内存缓存(建议 1-5s)。
-- **优化方案**: 审查并确认 `SchedulerSnapshotService.ListSchedulableAccounts()` 内部确实有 local cache;若无,增加 1-3s TTL 的 `sync.Map` 或 `atomic.Value` 缓存。
-- **预估收益**: 降低每请求 Redis 读取压力,并进一步减少缓存抖动时的回源概率
-
-## Deferred(本次不改,后续独立推进)
-
-- **WS Forwarder 双向消息代理热路径优化**:`openai_ws_forwarder.go` 的事件循环(~2800 行)需要独立分析,已在 `openai-ws-v2-performance` spec 中跟踪。
-- **HTTP Transport 连接池参数调优**:`httpUpstream.Do()` 底层的 `http.Transport` 参数(MaxIdleConnsPerHost、IdleConnTimeout 等)需要结合压测数据决定,不在本提案范围。
-- **gjson → sonic/jsoniter 替代**:全局切换 JSON 库影响面太大,需独立评估兼容性。
-- **账号列表排序结果缓存**:调度器排序结果短 TTL 缓存涉及一致性语义,需独立设计。
-
-## Capabilities
-
-### Modified Capabilities
-
-- `openai-oauth-performance`:本提案扩展 HTTP 转发路径的性能约束,与已有 OAuth 性能 spec 互补。
-
-### New Capabilities(建议新增至 spec)
-
-- **HTTP 转发路径请求体处理**:系统 MUST 避免在仅需修改少量字段时对完整请求体执行全量反序列化/序列化。
-- **SSE 流式转发写入效率**:系统 MUST 在 SSE 流式转发中使用批量写入策略,避免逐事件触发系统调用。
-- **调度器时间复杂度约束**:系统 SHALL 以 O(n) 或 O(n log k) 复杂度完成账号选择,不得在热路径执行全量排序。
-
-## Impact
-
-- **影响模块**:
- - `backend/internal/handler/openai_gateway_handler.go` — P0#5
- - `backend/internal/service/openai_gateway_service.go` — P0#1, P0#2, P0#3, P0#4, P1#8, P2#10, P2#13
- - `backend/internal/service/openai_account_scheduler.go` — P1#6, P1#7
- - `backend/internal/service/openai_ws_pool.go` — P1#9, P2#11
- - `backend/internal/service/scheduler_snapshot_service.go` — P2#13
-- **影响类型**: 热路径 CPU / GC / 系统调用开销、调度延迟、连接池锁竞争。
-- **API 兼容性**: 对外 API 路由与协议完全不变,零 Breaking Change。所有优化均为内部实现级别。
-- **风险等级**: 低。P0 优化均有明确的快速路径门控和降级条件;P1/P2 改动独立可控。
-- **验收标准**:
- - P0 实施后:热路径 allocs/op 降低 ≥40%,200KB 请求体 Forward 延迟降低 ≥30%
- - SSE 流式场景 Flush 系统调用减少 ≥50%
- - P1 实施后:20 账号调度延迟降低 ≥30%,运行时统计无全局锁竞争
diff --git a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/design.md b/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/design.md
deleted file mode 100644
index e510e9dd2..000000000
--- a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/design.md
+++ /dev/null
@@ -1,183 +0,0 @@
-## Design Overview
-
-本设计采用统一原则:
-
-- **先降复杂度,再抠常数项**:先清除 N+1 / 循环查库,再优化 JSON 热路径。
-- **语义不变,路径重构**:不改业务规则,只改执行方式。
-- **每项可观测、可回滚**:所有优化都配套指标与独立开关。
-
----
-
-## Phase P0:查询复杂度收敛(最高优先级)
-
-### P0-1 批量账号快照同步
-
-现状:`BulkUpdate` 在需要同步时逐账号调用 `syncSchedulerAccountSnapshot`,触发深查询放大。
-设计:新增批量路径。
-
-建议接口:
-
-- `accountRepo.SyncSchedulerAccountSnapshots(ctx, ids []int64) error`
-- `schedulerCache.SetAccounts(ctx, accounts []*service.Account) error`
-
-执行策略:
-
-1. `BulkUpdate` 一次性收集需同步账号。
-2. 一次 `GetByIDs` 获取快照必要信息。
-3. 一次批量写缓存。
-
-复杂度:`O(N*deep-query)` -> `O(1~2 DB + 1 cache batch)`。
-
-### P0-2 Outbox 批量账号事件
-
-现状:`handleBulkAccountEvent` 逐账号调用 `handleAccountEvent`。
-设计:改为批量消费。
-
-建议接口:
-
-- `handleBulkAccountEvent(ctx, payload)` 内部直接 `GetByIDs`
-- `rebuildByGroupAndPlatformBatch(ctx, jobs []RebuildJob)`
-
-执行策略:
-
-1. 批量加载账号。
-2. 构建 `platform+groupID` 去重集合。
-3. 批量更新 cache + 批量 rebuild。
-
-### P0-3 混合渠道检查批量化
-
-现状:批量账号更新中,每个账号都对每个 group 调 `ListByGroup`。
-设计:单次预加载索引。
-
-建议接口:
-
-- `accountRepo.ListByGroups(ctx, groupIDs []int64) (map[int64][]Account, error)`
-- 或 `groupRepo.GetGroupAccountPlatforms(ctx, groupIDs []int64) (map[int64]PlatformSet, error)`
-
-执行策略:
-
-1. 请求级预加载 group -> platform 集合。
-2. 按账号平台在内存中判冲突。
-3. 仅错误回传时补查 group 名称。
-
-### P0-4 Gemini 预检批量化
-
-现状:候选循环内逐账号分钟查询。
-设计:请求级批量 usage 预取。
-
-建议接口:
-
-- `rateLimitService.PreCheckUsageBatch(ctx, accounts []Account, requestedModel string) (map[int64]bool, error)`
-
-SQL 草案(分钟窗口):
-
-```sql
-SELECT account_id,
- SUM(CASE WHEN ... THEN 1 ELSE 0 END) AS req_count
-FROM usage_logs
-WHERE created_at >= $1 AND created_at < $2
- AND account_id = ANY($3)
-GROUP BY account_id;
-```
-
-执行策略:
-
-1. 先批量聚合 usage。
-2. 生成 `account_id -> pass/fail`。
-3. 选择器只消费内存结果。
-
-### P0-5 管理链路批量化
-
-- `ListUsers`:`GetByUserIDs` 一次查询回填。
-- `SyncUserGroupRates`:单事务批量 upsert/delete。
-- 分组校验:`ExistsByIDs` 替代循环 `GetByID`。
-
----
-
-## Phase P1:热路径 CPU/分配优化
-
-### P1-1 SSE 单事件单次解析
-
-现状:同事件多次 `Unmarshal/Marshal`。
-设计:事件门控 + 解析复用。
-
-执行策略:
-
-1. 非目标事件直接透传(不解析)。
-2. 目标事件仅解析一次。
-3. usage 提取复用同对象。
-4. 仅在字段被改写时再序列化。
-
-目标:单事件结构化解析次数 `<= 1`。
-
-### P1-2 模型映射惰性缓存
-
-现状:`GetModelMapping` 高频重建 map。
-设计:`Account` 内缓存非持久化字段。
-
-建议字段:
-
-- `cachedModelMapping map[string]string`
-- `cachedMappingReady bool`
-
-失效策略:
-
-- 当 `Credentials` 被整体替换时重置缓存。
-
----
-
-## Phase P2:Ops raw 查询收敛与降级
-
-现状:同请求重复扫描 `usage_logs/ops_error_logs`。
-设计:共享 CTE + 超时降级。
-
-执行策略:
-
-1. 统一过滤条件与时间窗 CTE。
-2. 指标尽量一次扫描产出。
-3. 百分位超时则返回基础指标与 `degraded=true`。
-
-原则:preagg 优先,raw 仅作兜底。
-
----
-
-## Observability & SLO
-
-必须新增以下埋点:
-
-- 每请求 DB 查询总数(按接口)
-- Gemini 预检 SQL 次数
-- SSE 每事件解析次数
-- Ops raw 查询耗时与降级次数
-
-发布闸门(同 proposal):
-
-- 批量更新 DB 查询下降 >= 70%
-- Gemini 预检分钟 SQL 次数 <= 1
-- SSE `allocs/op` 下降 >= 25%
-- Ops raw P95 下降 >= 30%
-
----
-
-## Risks & Mitigations
-
-1. 批量预加载导致内存峰值上升
-- 缓解:按 group 分批加载,设置单请求最大 group/account 上限。
-
-2. 模型映射缓存失效不及时
-- 缓解:封装 `SetCredentials` 统一失效;关键路径单测覆盖。
-
-3. Ops raw 合并后 SQL 复杂度提高
-- 缓解:分层构建 CTE,保留 fallback 查询路径与超时保护。
-
-4. SSE 优化导致协议兼容回归
-- 缓解:建立事件回放回归集(message_start/message_delta/error/DONE)。
-
----
-
-## Rollout Plan
-
-1. 先启用 P0 并灰度 10%。
-2. 指标稳定后启用 P1。
-3. 最后启用 P2,并验证看板场景降级能力。
-4. 任一阶段越界,按子域开关即时回滚。
diff --git a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/proposal.md b/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/proposal.md
deleted file mode 100644
index 25f6c25b8..000000000
--- a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/proposal.md
+++ /dev/null
@@ -1,96 +0,0 @@
-## Why
-
-本提案最初覆盖 10 个性能问题。当前已完成全部实施与编译/定向测试验证,提案中的“待修复问题”已清空。
-
-## 已修复并移出提案的问题
-
-1. 批量账号更新后逐账号同步快照(N+1)。
-2. outbox 批量账号事件逐账号处理。
-3. 批量更新中的混合渠道检查重复扫描分组成员。
-4. 用户列表加载用户倍率 N+1。
-5. 用户分组倍率同步逐条 delete/upsert。
-6. 分组存在性逐条 `GetByID` 校验(create/update/bulk update 路径)。
-7. Gemini 候选筛选内逐账号分钟预检查库。
-8. Ops Dashboard raw 路径串行多条重聚合 SQL。
-9. SSE 每事件重复 JSON 解码/编码与二次 usage 解析。
-10. 模型映射每次调用都重建 map。
-
-## 当前待修复问题(提案范围)
-
-无(已全部移出)。
-
-## Problem-to-Solution Matrix(已完成)
-
-1. Gemini 预检查询放大
-- 方案:请求级批量 usage 预取(分钟+日窗口),候选循环只消费内存结果。
-
-2. Ops raw 串行重聚合
-- 方案:preagg 优先;raw 共享 CTE + 超时降级。
-
-3. SSE 热路径重复解析
-- 方案:事件门控 + 单事件单次解析 + 无改写事件直透。
-
-4. 模型映射重复构建
-- 方案:`Account` 内惰性缓存 + 凭证更新失效。
-
-## Optimal Strategy(已执行路径)
-
-### Phase R1(已完成)
-- Gemini 预检批量化(问题 1)
-- SSE 单次解析(问题 3)
-
-原因:这两项直接影响实时请求链路,收益最高。
-
-### Phase R2(已完成)
-- 模型映射惰性缓存(问题 4)
-- Ops raw 查询收敛与降级(问题 2)
-
-原因:前者降低调度热点分配,后者优化看板与峰值稳定性。
-
-## What Changes(已实施)
-
-### A. Gemini 配额预检
-- 新增批量预检接口:单请求一次聚合 usage,替换候选循环逐账号查库。
-
-### B. 网关 SSE 热路径
-- 单事件单次解析。
-- usage 提取复用解析对象。
-- 无改写事件直接透传。
-
-### C. 模型映射缓存
-- `Account` 增加 `model_mapping` 惰性缓存字段与失效逻辑。
-
-### D. Ops raw 查询
-- 共享 CTE 收敛扫描。
-- 增加超时降级,保持 preagg 优先。
-
-## Non-goals
-
-- 不改变业务语义(调度、配额、协议)。
-- 不引入强依赖数据库扩展。
-
-## Capabilities
-
-### Added Capabilities
-
-- `backend-performance-hotspots`
-
-## Impact(剩余)
-
-- 影响模块:
- - `backend/internal/service/gemini_messages_compat_service.go`
- - `backend/internal/service/ratelimit_service.go`
- - `backend/internal/service/gateway_service.go`
- - `backend/internal/service/account.go`
- - `backend/internal/repository/ops_repo_dashboard.go`
- - `backend/internal/repository/usage_log_repo.go`
-
-- 验收闸门(发布前仍需压测确认):
- - Gemini 预检分钟窗口 SQL 次数 <= 1/请求。
- - SSE 热路径 `allocs/op` 下降 >= 25%,`B/op` 下降 >= 20%。
- - Ops raw 查询 P95 下降 >= 30%,且超时可降级返回。
- - 模型映射热点调用 CPU 下降(以 pprof hot path 对比确认)。
-
-- 发布与回滚:
- - 分域开关:`gemini-precheck-batch`、`sse-single-parse`、`ops-raw-cte`、`model-mapping-cache`。
- - 任一子域越界可单独回滚。
diff --git a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/specs/backend-performance-hotspots/spec.md b/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/specs/backend-performance-hotspots/spec.md
deleted file mode 100644
index 8c72258e4..000000000
--- a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/specs/backend-performance-hotspots/spec.md
+++ /dev/null
@@ -1,113 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 批量账号更新必须使用批量快照同步
-系统 MUST 在批量账号更新场景使用批量快照同步接口,避免逐账号深查询同步。
-
-#### Scenario: 批量状态变更触发快照同步
-- **WHEN** 一次请求更新多个账号且触发调度快照同步
-- **THEN** 系统 MUST 使用批量读取与批量写缓存路径
-- **AND** 系统 MUST NOT 对每个账号循环执行 `GetByID -> accountsToService`
-
-### Requirement: Outbox 批量账号事件必须批量消费
-系统 MUST 将批量账号 outbox 事件按批处理,并对 rebuild 目标进行去重。
-
-#### Scenario: 收到 account_ids 批量事件
-- **WHEN** outbox 事件包含多个 `account_ids`
-- **THEN** 系统 MUST 批量加载账号并批量更新缓存
-- **AND** 系统 SHALL 对 `platform + group` 维度去重 rebuild
-
-### Requirement: Gemini 预检必须避免候选循环内逐账号查库
-系统 MUST 在 Gemini 账号筛选链路中使用请求级批量 usage 预取,避免在候选循环中逐账号查询分钟配额。
-
-#### Scenario: 账号筛选执行 RPM 预检
-- **WHEN** 请求需要从候选账号中选择可用账号
-- **THEN** 系统 MUST 先执行批量 usage 预取并复用结果
-- **AND** 候选循环 MUST NOT 触发逐账号分钟窗口 SQL 查询
-
-### Requirement: Gemini 批量预检结果必须保持语义一致
-系统 MUST 确保批量预检结果与原有逐账号预检在同一时间窗口口径下语义一致。
-
-#### Scenario: 相同输入下预检结果一致
-- **WHEN** 对同一账号集合、同一时间窗口执行逐账号预检与批量预检
-- **THEN** 通过/拒绝结论 MUST 保持一致
-- **AND** 除缓存时效差异外 MUST NOT 引入额外误判
-
-### Requirement: 批量混合渠道检查必须单次预加载分组成员
-系统 MUST 在批量账号更新中的混合渠道检查里复用单次预加载的分组成员索引。
-
-#### Scenario: 批量更新含 groupIDs 且启用混合渠道检查
-- **WHEN** 一个请求内需要对多个账号执行混合渠道检查
-- **THEN** 系统 SHALL 一次性加载相关分组成员数据
-- **AND** 系统 MUST NOT 对每个账号重复调用 `ListByGroup`
-
-### Requirement: Ops Dashboard raw 查询必须收敛重复扫描
-系统 SHALL 在 raw 查询路径复用公共时间窗与过滤 CTE,降低同请求重复扫描成本。
-
-#### Scenario: raw 模式查询概览指标
-- **WHEN** 看板请求走 raw 查询模式
-- **THEN** 系统 SHALL 通过共享 CTE 收敛 `usage_logs/ops_error_logs` 重复扫描
-- **AND** 在高负载下 SHALL 提供可降级返回而非长时间阻塞
-
-### Requirement: SSE 事件处理必须避免重复 JSON 解析
-系统 MUST 在流式事件处理中做到单事件单次解析,并复用解析结果提取 usage。
-
-#### Scenario: 处理 message_start/message_delta 事件
-- **WHEN** 网关收到可解析的 SSE 事件
-- **THEN** 系统 MUST 至多执行一次结构化解析
-- **AND** usage 提取 MUST 复用该解析结果
-
-### Requirement: 用户列表倍率加载必须批量化
-系统 MUST 在用户列表场景使用批量接口加载用户分组倍率,避免 N+1 查询。
-
-#### Scenario: 列表页加载用户与倍率
-- **WHEN** 管理端请求用户列表并需要展示 `GroupRates`
-- **THEN** 系统 MUST 使用批量倍率查询接口
-- **AND** 系统 MUST NOT 对每个用户逐条查询倍率
-
-### Requirement: 用户分组倍率同步必须批量写入
-系统 MUST 在同步用户分组倍率时使用单事务批量 upsert/delete。
-
-#### Scenario: 一次请求同步多个 group 的倍率
-- **WHEN** 后端收到 `SyncUserGroupRates` 请求
-- **THEN** 系统 MUST 采用批量 SQL 完成 upsert 与 delete
-- **AND** 系统 MUST NOT 对每个 group 执行独立 SQL 往返
-
-### Requirement: 模型映射解析必须具备缓存机制
-系统 SHALL 为账号模型映射提供惰性缓存机制,避免热点路径重复构建 map。
-
-#### Scenario: 高频模型匹配调用
-- **WHEN** 同一账号在一次请求/短窗口内多次调用 `IsModelSupported` 或 `GetMappedModel`
-- **THEN** 系统 SHALL 复用已解析映射
-- **AND** 在凭证更新后 MUST 正确失效缓存
-
-### Requirement: 分组存在性校验必须支持批量查询
-系统 MUST 在账号创建/更新/批量更新场景提供分组批量存在性校验能力。
-
-#### Scenario: 请求携带多个 groupIDs
-- **WHEN** 请求需要校验多个分组是否存在
-- **THEN** 系统 MUST 使用批量存在性查询
-- **AND** 系统 MUST NOT 对每个 groupID 逐条 `GetByID`
-
-### Requirement: 性能优化必须具备可观测指标
-系统 MUST 暴露并记录与本提案对应的关键性能指标,支持发布决策与回归定位。
-
-#### Scenario: 发布前后性能对比
-- **WHEN** 团队执行性能优化发布评审
-- **THEN** 系统 MUST 提供优化前后同口径指标对比
-- **AND** 指标至少包含 DB 查询数、SSE 解析次数、raw 查询耗时、allocs/op
-
-### Requirement: 性能优化必须支持分域灰度与回滚
-系统 MUST 为各优化子域提供独立开关,并支持按子域回滚。
-
-#### Scenario: 指标越界触发回滚
-- **WHEN** 任一子域灰度期间出现关键指标越界
-- **THEN** 系统 MUST 能独立回滚该子域优化
-- **AND** 其他子域优化 SHALL 可继续保持开启
-
-### Requirement: 性能优化不得改变业务语义
-系统 MUST 在优化后保持原有调度、配额、网关协议与管理语义一致。
-
-#### Scenario: 语义回归验证
-- **WHEN** 执行回归测试(调度、配额、流式协议、管理接口)
-- **THEN** 业务行为 MUST 与优化前一致
-- **AND** 仅允许性能指标变化,不允许功能语义变化
diff --git a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/tasks.md b/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/tasks.md
deleted file mode 100644
index 56e0a5fd3..000000000
--- a/openspec/changes/2026-02-26-optimize-backend-hotpath-performance/tasks.md
+++ /dev/null
@@ -1,28 +0,0 @@
-## M0. 提案质量门禁(已完成)
-
-- [x] M0.1 第 1 轮复审:检查“问题是否全部覆盖 + 是否有最优优先级路径”。
-- [x] M0.2 第 2 轮复审:对齐 proposal/design/tasks/spec 的验收口径。
-- [x] M0.3 第 3 轮复审:补齐发布闸门、回滚边界、可观测指标。
-
-## M1. 已完成实施(从提案问题列表移除)
-
-- [x] M1.1 批量快照同步:`BulkUpdate` 不再逐账号深查询同步。
-- [x] M1.2 outbox 批量账号事件改为批量加载与一次性分组重建。
-- [x] M1.3 批量更新混合渠道检查改为单次预加载分组成员。
-- [x] M1.4 用户列表倍率加载增加批量读取路径(`GetByUserIDs`)。
-- [x] M1.5 用户分组倍率同步改为批量 delete + 批量 upsert。
-- [x] M1.6 分组存在性校验增加批量检查路径(create/update/bulk update)。
-
-## M2. 剩余实施(已完成)
-
-- [x] M2.1 新增 `PreCheckUsageBatch`,移除 Gemini 候选循环逐账号分钟查库。
-- [x] M2.2 SSE 事件门控 + usage 单次解析复用。
-- [x] M2.3 `Account` 增加 `model_mapping` 惰性缓存与失效机制。
-- [x] M2.4 Ops raw 查询共享 CTE + 超时降级返回。
-
-## M3. 验证、灰度与回滚
-
-- [ ] M3.1 新增埋点:Gemini 预检 SQL 次数、SSE 事件解析次数、Ops raw 降级次数。
-- [ ] M3.2 在 staging 输出“优化前后”压测报告(P50/P95/P99、allocs/op、B/op、慢 SQL)。
-- [ ] M3.3 按开关分阶段灰度发布(R1 -> R2)。
-- [ ] M3.4 任一闸门未达标时执行子域级回滚并记录复盘。
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/design.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/design.md
deleted file mode 100644
index 223be7c91..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/design.md
+++ /dev/null
@@ -1,152 +0,0 @@
-## Context
-
-本项目前端采用 Vue 3 + Vite。当前性能瓶颈不是单点,而是“构建分包策略 + 后台重页面组织 + 运行时重复开销”叠加导致。目标是在不改变现有业务功能与外部接口的前提下,系统性降低加载与运行时成本。
-
-约束:
-- 必须保持现有路由、接口、权限、配置语义不变
-- 必须支持灰度、可回滚
-- 必须保证新旧前端版本在滚动发布期间可共存
-
-## Goals / Non-Goals
-
-**Goals**
-- 降低首屏及核心后台页面 JS 传输与执行成本
-- 降低重页面(尤其 `AccountsView`)的首进页面开销
-- 移除运行时冗余调用与高频重计算
-- 建立前端性能发布门禁与可回滚机制
-
-**Non-Goals**
-- 不改动后端 API 契约
-- 不重写 UI 框架/设计系统
-- 不引入破坏性路由结构变更
-
-## Decisions
-
-### D1: 分包策略改为“功能边界优先”
-
-**选择**:按“关键路径/后台重功能/低频工具”拆包,避免 `vendor-misc` 与 `vendor-ui` 过度混合。
-
-**理由(最优性)**:
-- 直接减少非关键代码进入关键链路的概率
-- 比仅做 gzip/压缩参数微调收益更高且可控
-- 与 Vite 的 `manualChunks` 机制天然兼容,落地成本低
-
-### D2: `xlsx` 从共享 UI 包独立
-
-**选择**:`xlsx` 单独 chunk,仅在导出场景按需加载。
-
-**理由(最优性)**:
-- `xlsx` 体积大、使用频率低,天然适合延迟加载
-- 与 `@vueuse/core` 解耦后可避免“轻页面背重工具”
-
-### D3: 路由预取从固定邻接表升级为自适应策略
-
-**选择**:保留现有邻接表作为 `legacy`,新增 `adaptive` 模式:
-- 低带宽/低内存设备禁用重页面预取
-- 重页面仅在空闲预算充足时预取
-
-**理由(最优性)**:
-- 兼顾弱网设备体验与高端设备速度
-- 比“一刀切关闭预取”更平衡
-
-### D4: onboarding 与公告能力延迟初始化
-
-**选择**:`driver.js` 与公告详情渲染链路(`marked + DOMPurify`)改为惰性加载。
-
-**理由(最优性)**:
-- 二者均非每次访问必需,适合按需加载
-- 可显著减轻 `AppLayout/AppHeader` 公共链路
-
-### D5: `AccountsView` 全量静态依赖改按需加载
-
-**选择**:重弹窗/重图表/低频面板改 `defineAsyncComponent` + 触发时加载。
-
-**理由(最优性)**:
-- 对最大 chunk 直接见效
-- 不改变页面功能,仅改变加载时机,兼容风险低
-
-### D6: 模型白名单从“单文件硬编码”改“静态分片 + 远程增量 + 回退”
-
-**选择**:
-- 平台级模型清单先做静态分片(首选主路径)
-- 远程配置仅做增量覆盖(可选)
-- 远程失败时无条件回退静态分片快照
-
-**理由(最优性)**:
-- 静态分片可立即降包且不引入强依赖
-- 远程增量可减少后续发版频率
-- 保底回退确保升级与网络异常下可用性
-
-### D7: DataTable 大数据场景走轻量路径
-
-**选择**:
-- 将列宽测量节流并限定触发条件
-- 超阈值(如 >200 行)默认服务端排序
-
-**理由(最优性)**:
-- 避免持续 DOM 测量和全量排序
-- 保留小表场景的交互体验
-
-### D8: 模板变异排序改不可变计算
-
-**选择**:模板内 `selectedErrorCodes.sort(...)` 改为 `computed(() => [...selectedErrorCodes].sort(...))`。
-
-**理由(最优性)**:
-- 避免副作用与潜在渲染抖动
-- 改动极小,收益稳定
-
-### D9: 发布策略采用“开关 + 指标门禁 + 分阶段收敛”
-
-**选择**:所有高影响优化通过开关灰度,达门禁后再全量。
-
-**门禁建议**:
-- `/admin/accounts` 首次加载 JS 传输量下降 >= 30%
-- 首页关键链路 JS 体积不回升
-- 前端 JS 错误率相对基线上升 < 0.05%
-- 性能灰度期间无 P1 功能回归
-
-**回滚策略**:
-- 开关级回滚(分钟级)优先于代码回滚
-- 任一门禁不达标立即回退对应开关
-
-### D10: 灰度开关采用“运行时配置优先,编译时默认兜底”
-
-**选择**:
-- 生产环境灰度与回滚使用 `public_settings.perf_flags`(运行时)
-- `VITE_*` 只作为开发和构建默认值兜底
-
-**理由(最优性)**:
-- 运行时开关可在不重新构建前端的前提下灰度/回滚
-- 与当前项目 `window.__APP_CONFIG__` 注入机制兼容
-- 旧前端可忽略未知配置字段,具备天然前向兼容
-
-## Compatibility Design
-
-为保证向前兼容与平滑升级,新增以下机制:
-
-1. 默认兼容:新能力默认关闭(沿用旧行为)
-2. 双路径并存:旧逻辑与新逻辑可共存一段观察窗口
-3. 失败回退:远程模型配置失败自动回退本地快照
-4. 逐步收敛:`legacy -> adaptive -> off` 或反向可切换
-5. 混部兼容:滚动发布期间“旧前端 + 新配置 / 新前端 + 旧配置”均可工作
-
-## Risks / Trade-offs
-
-- 路由预取降载可能让“次跳转瞬时速度”略降
- - 通过自适应策略与灰度阈值平衡
-- 过度拆包会增加请求数
- - 控制 chunk 数量,避免碎片化
-- 远程模型配置引入可用性依赖
- - 本地快照兜底 + 缓存 TTL
-
-## Migration Plan
-
-1. Phase 0: 基线采集与埋点(不改行为)
-2. Phase 1: 分包/异步加载/模板副作用修复
-3. Phase 2: 运行时策略优化(预取/表格/重复调用)
-4. Phase 3: 运行时灰度开关接入(`public_settings.perf_flags`)与混部验证
-5. Phase 4: 灰度门禁、全量、观察与收敛
-
-## Open Questions
-
-- 无。当前方案已收敛,按上述迁移步骤执行。
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/proposal.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/proposal.md
deleted file mode 100644
index d1c6ed7bc..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/proposal.md
+++ /dev/null
@@ -1,111 +0,0 @@
-# Change: 前端性能全面优化与兼容升级治理
-
-## Why
-
-对 `frontend/` 与当前构建产物进行二次复核后,确认存在一组“高概率影响真实用户体验”的性能问题,主要集中在以下四类:
-
-1. 首屏与公共链路负担偏高(分包与公共组件加载策略)
-2. 管理后台重页面体积过大(账户页、模型映射、重弹窗)
-3. 运行时存在可避免的重复/重计算(重复鉴权、冗余触发、重排序/测量)
-4. 缺少可灰度、可回滚的前端性能发布门禁
-
-构建产物关键体积(2026-02-26 本地复核):
-- `vendor-ui-CAt8eLho.js`:430,775 B(gzip 142,558 B)
-- `AccountsView-BUXw2FOq.js`:417,365 B(gzip 98,447 B)
-- `OpsDashboard-BtQ5fWkO.js`:218,636 B(gzip 47,615 B)
-- `vendor-misc-DPWBSk0M.js`:198,727 B(gzip 70,402 B)
-
-## What Changes
-
-### A. 二次确认结果(含误报排除)
-
-| 编号 | 复核结论 | 证据 | 处理优先级 |
-|---|---|---|---|
-| FEP-01 分包粒度粗,`vendor-misc` 进入首屏预加载 | 真问题 | `frontend/vite.config.ts:72-100`,`backend/internal/web/dist/index.html:10-13` | P0 |
-| FEP-02 `@vueuse` 与 `xlsx` 混包(`vendor-ui`) | 真问题(但非“首屏阻塞”) | `frontend/vite.config.ts:85-87`,`backend/internal/web/dist/assets/vendor-ui-*.js` | P1 |
-| FEP-03 路由预取过激,包含重页面(`/admin/accounts`) | 真问题 | `frontend/src/composables/useRoutePrefetch.ts:24-35` | P0 |
-| FEP-04 `AppLayout` 全局挂载 onboarding(`driver.js`) | 真问题 | `frontend/src/components/layout/AppLayout.vue:30-43`,`frontend/src/composables/useOnboardingTour.ts:2-3` | P1 |
-| FEP-05 `AnnouncementBell` 常驻 + 挂载即拉公告 | 真问题 | `frontend/src/components/layout/AppHeader.vue:27`,`frontend/src/components/common/AnnouncementBell.vue:317-318,439-442` | P0 |
-| FEP-06 `AccountsView` 静态导入过多重组件/弹窗 | 真问题 | `frontend/src/views/admin/AccountsView.vue:254-306` | P0 |
-| FEP-07 模型白名单与预设映射超大硬编码 | 真问题 | `frontend/src/composables/useModelWhitelist.ts:5-313` | P0 |
-| FEP-08 `HomeView` 与 router 重复 `checkAuth` | 真问题(收益中等) | `frontend/src/views/HomeView.vue:474`,`frontend/src/router/index.ts:399` | P2 |
-| FEP-09 `DataTable` 存在较重测量与客户端排序开销 | 真问题 | `frontend/src/components/common/DataTable.vue:201-243,483-499` | P1 |
-| FEP-10 模板内 `sort()` 原地变异数组 | 真问题 | `frontend/src/components/account/CreateAccountModal.vue:1183`,`frontend/src/components/account/EditAccountModal.vue:332`,`frontend/src/components/account/BulkEditAccountModal.vue:372` | P1 |
-| FEP-11 `AppSidebar` 内联大量 SVG render 函数 | 真问题 | `frontend/src/components/layout/AppSidebar.vue:172-461` | P1 |
-| FEP-12 `adminSettingsStore.fetch` 双触发 | 真问题(低风险) | `frontend/src/components/layout/AppSidebar.vue:589-603`,`frontend/src/stores/adminSettings.ts:51-54` | P2 |
-
-误报排除(本提案已修正):
-- “`vendor-ui` 是首屏 `modulepreload` 关键阻塞”不成立;当前首屏预加载是 `vendor-vue/vendor-misc/vendor-i18n`。因此本项降级为“共享依赖膨胀”问题处理。
-- “`adminSettingsStore.fetch` 一定导致双网络请求”不严格成立;store 有 `loading` 保护。但重复触发仍增加无效调用与维护复杂度,保留优化。
-
-### B. 最优解决方案(按阶段)
-
-- **P0(先做)**
- - 重构 `manualChunks`:拆分 `vendor-misc`,移除非关键公共依赖的首屏绑定
- - `AccountsView` 及重弹窗改为按需异步加载
- - 预取策略从“固定邻接”改为“自适应(网络/设备/路由体积)”
- - 公告铃铛改“打开时拉取 + 轻量未读数预热”,避免挂载即重请求
- - 模型白名单改为“平台分片 + 按需加载 + 缓存兜底”
-
-- **P1(随后)**
- - onboarding(`driver.js`)改懒加载,去公共链路静态依赖
- - `DataTable` 改轻量测量策略;大数据量默认服务端排序
- - 修复模板 `sort()` 变异(改为不可变排序)
- - `AppSidebar` 图标改为统一图标组件/资源映射,减少大段内联 render 函数
-
-- **P2(收尾)**
- - 合并鉴权初始化入口,移除 `HomeView` 重复 `checkAuth`
- - 去除 `adminSettingsStore.fetch` 双入口触发,保留单入口
-
-### C. 向前兼容与平滑升级(必须项)
-
-新增灰度开关(运行时开关,默认保持旧行为,保证升级无中断):
-- `public_settings.perf_flags.prefetch_mode=legacy|adaptive|off`(默认 `legacy`)
-- `public_settings.perf_flags.announcement_lazy_enabled=false`
-- `public_settings.perf_flags.onboarding_lazy_enabled=false`
-- `public_settings.perf_flags.accounts_async_modals_enabled=false`
-- `public_settings.perf_flags.model_whitelist_remote_enabled=false`
-- `public_settings.perf_flags.datatable_lightweight_enabled=false`
-- `VITE_*` 仅作为本地开发与构建期默认值,不作为生产灰度主开关
-
-发布顺序:
-1. 先上可观测(仅埋点,不改行为)
-2. 再开小流量灰度(10%)
-3. 稳定后扩大到 50%
-4. 达门禁后全量
-
-门禁阈值(连续 24 小时满足才可进入下一阶段):
-- `/admin/accounts` 首次 JS 下载量下降 >= 30%(相对基线)
-- 前端错误率相对基线上升 < 0.05%
-- 关键流程(登录、账户增改删、公告、引导)回归用例通过率 100%
-
-回滚原则:任一门禁不达标,单开关回退,不影响其余能力。
-
-兼容矩阵:
-- 新前端 + 旧配置:按默认值走旧行为(兼容)
-- 新前端 + 新配置:按灰度开关启用新能力
-- 旧前端 + 新配置:忽略未知 `perf_flags` 字段(兼容)
-
-## Capabilities
-
-### New Capabilities
-- `frontend-bundle-optimization`
-- `frontend-runtime-performance`
-- `frontend-compatibility-rollout`
-
-## Impact
-
-- Affected specs:
- - `frontend-bundle-optimization`
- - `frontend-runtime-performance`
- - `frontend-compatibility-rollout`
-- Affected code:
- - `frontend/vite.config.ts`
- - `frontend/src/router/*`
- - `frontend/src/components/layout/*`
- - `frontend/src/components/common/*`
- - `frontend/src/views/admin/*`
- - `frontend/src/composables/*`
- - `frontend/src/stores/*`
-- 外部 API:无协议变更
-- 数据兼容:保持向前兼容;新路径全部可开关回退
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/review-rounds.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/review-rounds.md
deleted file mode 100644
index b109d49c1..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/review-rounds.md
+++ /dev/null
@@ -1,95 +0,0 @@
-## frontend-performance-comprehensive 多轮复核记录
-
-### 第 1 轮:问题真实性复核(源码)
-
-结论:核心问题均为真实问题,且可定位到具体文件。
-
-已确认:
-- 分包策略过粗:`frontend/vite.config.ts:72-100`
-- 首屏预加载含 `vendor-misc`:`backend/internal/web/dist/index.html:10-13`
-- 重页面预取:`frontend/src/composables/useRoutePrefetch.ts:24-35`
-- onboarding 公共链路静态依赖:`frontend/src/components/layout/AppLayout.vue:30-43`,`frontend/src/composables/useOnboardingTour.ts:2-3`
-- 公告挂载即拉数据:`frontend/src/components/common/AnnouncementBell.vue:439-442`
-- `AccountsView` 静态引入过多重组件:`frontend/src/views/admin/AccountsView.vue:254-306`
-- 大型模型映射硬编码:`frontend/src/composables/useModelWhitelist.ts:5-313`
-- 重复鉴权:`frontend/src/views/HomeView.vue:474` 与 `frontend/src/router/index.ts:399`
-- `DataTable` 重测量/排序路径:`frontend/src/components/common/DataTable.vue:201-243,483-499`
-- 模板内 `sort()` 变异:
- - `frontend/src/components/account/CreateAccountModal.vue:1183`
- - `frontend/src/components/account/EditAccountModal.vue:332`
- - `frontend/src/components/account/BulkEditAccountModal.vue:372`
-- `AppSidebar` 大量内联 SVG render:`frontend/src/components/layout/AppSidebar.vue:172-461`
-- `adminSettingsStore.fetch` 双触发:`frontend/src/components/layout/AppSidebar.vue:589-603`
-
-### 第 2 轮:构建产物与体积复核(二次确认)
-
-结论:体积热点与源码问题一致,优化优先级成立。
-
-关键体积(本地构建产物)
-- `vendor-ui`: 430,775 B(gzip 142,558 B)
-- `AccountsView`: 417,365 B(gzip 98,447 B)
-- `OpsDashboard`: 218,636 B(gzip 47,615 B)
-- `vendor-misc`: 198,727 B(gzip 70,402 B)
-
-附加确认:
-- `vendor-ui` 中确实包含 `xlsx`(构建产物可检索到 `xlsx.js` 标识)。
-
-### 第 3 轮:误报排除与优先级校准
-
-结论:存在两处“需要降级表述”的点,已修正进提案。
-
-1. 关于 `vendor-ui`:
-- 原判断“首屏阻塞”不准确;`index.html` 首屏 `modulepreload` 未包含 `vendor-ui`。
-- 修正为“共享依赖膨胀风险”,优先级从 P0 下调为 P1。
-
-2. 关于 `adminSettingsStore.fetch`:
-- 原判断“双请求”不严格;store 的 `loading`/`loaded` 保护可避免重复请求。
-- 仍保留为低优先级问题(重复触发/可维护性)。
-
-### 第 4 轮:最优方案与兼容性复审
-
-结论:方案已收敛为“收益最大且兼容风险最低”的路径。
-
-- 采用 feature flag 渐进发布,而非一次性切换
-- 默认保持旧行为,确保向前兼容
-- 对高风险点(预取策略、模型清单来源、重页面异步化)提供单项回滚开关
-- 设定量化门禁(体积、错误率、回归)后再推进全量
-
-### 最新结论
-
-- 二次确认结果:问题真实性成立,误报已排除。
-- 最优性结论:当前提案满足“可灰度、可回滚、向前兼容、升级可控”的要求。
-
-### 第 5 轮:灰度机制可执行性复审(本次)
-
-发现问题:
-- 提案将灰度开关写为 `VITE_*`,这属于构建期变量,不适合作为生产运行时灰度与快速回滚主路径。
-
-修复动作:
-- 将开关主路径统一调整为 `public_settings.perf_flags.*`(运行时)。
-- 明确 `VITE_*` 仅作为开发/构建默认值兜底。
-- 在 `proposal.md`、`design.md`、`tasks.md`、`frontend-compatibility-rollout/spec.md` 同步修复。
-
-### 第 6 轮:最优方案收敛复审(本次)
-
-发现问题:
-- `design.md` 中“模型白名单分片或远程”表述存在策略歧义,且留有开放问题,不满足“给出最优方案并可直接执行”的要求。
-
-修复动作:
-- 收敛为“静态分片主路径 + 远程增量覆盖 + 失败回退静态快照”的确定性方案。
-- 移除不必要歧义,开放问题改为“无”,并补齐迁移阶段。
-
-### 第 7 轮:向前兼容与门禁严格性复审(本次)
-
-发现问题:
-- 原文对混部兼容(旧前端+新配置)与推进阻断条件(门禁连续性)描述不够可验证。
-
-修复动作:
-- 新增兼容矩阵(新前端+旧配置、旧前端+新配置、新前端+新配置)。
-- 新增“连续 24h 门禁阈值”与“不达标禁止推进”要求。
-- 在 `tasks.md` 增加混部兼容验证任务与量化门禁任务。
-
-### 复审后结论
-
-- 已完成 3 轮增量复审并修复全部发现问题。
-- 当前提案具备:问题真实性证据、确定性最优方案、运行时灰度回滚能力、向前兼容与混部可验证门禁。
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-bundle-optimization/spec.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-bundle-optimization/spec.md
deleted file mode 100644
index e7bd31eb8..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-bundle-optimization/spec.md
+++ /dev/null
@@ -1,29 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Frontend Bundle Boundary Optimization
-
-Frontend 打包策略 SHALL 按“关键路径、后台重功能、低频工具”进行边界拆分,避免低频重依赖进入高频公共链路。
-
-#### Scenario: 首屏不加载低频重工具
-- **WHEN** 用户首次进入首页或常规仪表盘
-- **THEN** 构建产物不得在首屏关键链路中加载低频重工具(如 `xlsx`)
-- **AND** 与当前首屏路由无关的非关键 vendor 包不得通过首屏 `modulepreload` 强制进入
-
-#### Scenario: 后台重页面按需加载重能力
-- **WHEN** 用户未打开账户管理相关重弹窗/重图表
-- **THEN** 对应组件代码不得提前随页面主 chunk 一并加载
-- **AND** 仅在用户触发后再按需加载
-
-### Requirement: Model Whitelist Payload Decoupling
-
-模型白名单与预设映射 SHALL 支持分片或远程加载,并保留本地快照回退,避免单文件超大硬编码持续膨胀。
-
-#### Scenario: 远程加载失败自动回退
-- **WHEN** 远程模型配置请求失败或超时
-- **THEN** 系统 SHALL 自动回退到本地快照
-- **AND** 用户可继续完成账户配置流程
-
-#### Scenario: 平台级按需加载
-- **WHEN** 用户只操作单一平台模型配置
-- **THEN** 系统 SHALL 仅加载对应平台的模型数据
-- **AND** 不应加载全量平台模型清单
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-compatibility-rollout/spec.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-compatibility-rollout/spec.md
deleted file mode 100644
index 0dd684a58..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-compatibility-rollout/spec.md
+++ /dev/null
@@ -1,48 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Forward-Compatible Performance Rollout
-
-所有高影响前端性能优化 SHALL 支持开关化发布,并默认保持旧行为,以保证向前兼容和无中断升级。
-
-#### Scenario: 默认行为兼容旧版本
-- **WHEN** 新版本部署但性能开关未开启
-- **THEN** 应用行为 SHALL 与旧版本保持一致
-- **AND** 不得因优化代码引入功能语义变化
-
-#### Scenario: 灰度期间可单项回退
-- **WHEN** 某项优化在灰度中触发错误率或回归告警
-- **THEN** 系统 SHALL 支持仅回退该项开关
-- **AND** 其他已稳定优化项保持生效
-
-### Requirement: Runtime Flag Source and Backward Compatibility
-
-性能开关 SHALL 优先来自运行时公共配置,并保证旧客户端对新增字段的向前兼容。
-
-#### Scenario: 运行时配置优先于构建默认值
-- **WHEN** `public_settings.perf_flags` 与本地构建默认值同时存在
-- **THEN** 前端 SHALL 优先采用 `public_settings.perf_flags`
-- **AND** 本地构建默认值仅作为缺省兜底
-
-#### Scenario: 旧前端忽略新增配置字段
-- **WHEN** 服务端返回旧前端未知的 `perf_flags` 字段
-- **THEN** 旧前端 SHALL 忽略未知字段并保持现有行为
-- **AND** 不得出现初始化失败或页面不可用
-
-### Requirement: Metrics-Gated Progressive Enablement
-
-优化能力启用 SHALL 受指标门禁控制,不满足阈值不得推进下一阶段。
-
-#### Scenario: 指标不达标阻断全量
-- **WHEN** 灰度期间关键指标(错误率、关键页面加载)不达标
-- **THEN** 发布流程 SHALL 阻断全量
-- **AND** 自动进入回退或继续观察流程
-
-#### Scenario: 指标达标后分阶段推进
-- **WHEN** 灰度指标连续达标
-- **THEN** 允许按 10% -> 50% -> 100% 分阶段推进
-- **AND** 每阶段均需保留回退能力
-
-#### Scenario: 不满足门禁时禁止推进
-- **WHEN** 任一阶段未满足连续 24 小时门禁阈值
-- **THEN** 发布流程 SHALL 禁止进入下一阶段
-- **AND** SHALL 触发回退或继续观察
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-runtime-performance/spec.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-runtime-performance/spec.md
deleted file mode 100644
index 851e9d040..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/specs/frontend-runtime-performance/spec.md
+++ /dev/null
@@ -1,41 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Runtime Work Avoidance
-
-前端运行时 SHALL 避免可识别的重复调用与不必要计算,优先降低高频路径上的 CPU 与内存开销。
-
-#### Scenario: 鉴权初始化仅执行一次
-- **WHEN** 应用首次完成路由初始化
-- **THEN** 鉴权状态恢复流程 SHALL 只执行一次
-- **AND** 不应由多个页面重复触发相同初始化逻辑
-
-#### Scenario: 管理设置拉取单入口触发
-- **WHEN** 侧边栏初始化管理设置
-- **THEN** 仅允许一个入口触发 `adminSettingsStore.fetch`
-- **AND** 不得出现 watch 与 mounted 双入口重复触发
-
-#### Scenario: 大数据表避免客户端重排序
-- **WHEN** 表格数据量超过运行时阈值(例如 200 行)
-- **THEN** 前端 SHALL 使用服务端排序或等效轻量策略
-- **AND** 不得在每次渲染周期执行全量客户端排序
-
-### Requirement: Lazy Initialization for Non-Critical Features
-
-非关键能力(引导、公告详情渲染)SHALL 采用惰性初始化,避免进入公共链路初始执行。
-
-#### Scenario: 引导能力按需加载
-- **WHEN** 用户未触发引导功能
-- **THEN** `driver.js` 相关代码不得在公共布局初始化阶段执行
-
-#### Scenario: 公告详情按需渲染
-- **WHEN** 用户未打开公告面板
-- **THEN** 公告详情渲染链路不应触发全量加载与渲染
-
-### Requirement: Deterministic and Side-Effect-Free Rendering
-
-模板层渲染 SHALL 避免副作用表达式,确保渲染行为稳定可预测。
-
-#### Scenario: 渲染列表不变异源数组
-- **WHEN** 组件展示已选错误码列表
-- **THEN** 组件 SHALL 使用不可变排序结果进行渲染
-- **AND** 不得在模板表达式中直接调用会变异原数组的 `sort()`
diff --git a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/tasks.md b/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/tasks.md
deleted file mode 100644
index df60eabf7..000000000
--- a/openspec/changes/2026-02-26-optimize-frontend-performance-comprehensive/tasks.md
+++ /dev/null
@@ -1,47 +0,0 @@
-## 0. 基线与门禁准备
-
-- [ ] 0.1 在 `frontend/` 执行构建并产出基线报告(chunk 原始体积 + gzip 体积 + 路由首进加载链路)
-- [ ] 0.2 增加前端性能观测指标(页面级加载耗时、路由跳转耗时、JS 错误率)
-- [ ] 0.3 定义并固化灰度门禁阈值(`/admin/accounts` 体积、错误率、回归用例)
-- [ ] 0.4 固化推进门禁(连续 24h 满足:`/admin/accounts` 首次 JS 下载量下降 >= 30%、错误率增幅 < 0.05%、关键回归用例 100% 通过)
-
-## 1. 分包与重页面优化(P0)
-
-- [ ] 1.1 `frontend/vite.config.ts`:重构 `manualChunks`,拆分 `vendor-misc` 与 `vendor-ui`,将 `xlsx` 独立为低频 chunk
-- [ ] 1.2 `frontend/src/views/admin/AccountsView.vue`:将低频重组件(弹窗/图表)改为异步加载
-- [ ] 1.3 `frontend/src/composables/useModelWhitelist.ts`:改造为平台分片加载;保留本地快照兜底
-- [ ] 1.4 `frontend/src/components/common/AnnouncementBell.vue`:移除挂载即全量拉取,改为打开时加载详情、可选轻量未读数预热
-- [ ] 1.5 重新构建并对比:确认 `AccountsView` 与公共 chunk 体积下降达到目标
-
-## 2. 运行时性能优化(P1)
-
-- [ ] 2.1 `frontend/src/composables/useRoutePrefetch.ts`:新增 `adaptive` 模式,低资源设备跳过重页面预取
-- [ ] 2.2 `frontend/src/components/layout/AppLayout.vue` + `frontend/src/composables/useOnboardingTour.ts`:引导能力改惰性初始化
-- [ ] 2.3 `frontend/src/components/common/DataTable.vue`:限制重测量触发条件并节流;大数据量默认服务端排序
-- [ ] 2.4 `frontend/src/components/layout/AppSidebar.vue`:图标定义从内联 render 函数迁移到统一图标组件/映射
-- [ ] 2.5 `frontend/src/components/account/CreateAccountModal.vue`:修复模板 `sort()` 原地变异
-- [ ] 2.6 `frontend/src/components/account/EditAccountModal.vue`:修复模板 `sort()` 原地变异
-- [ ] 2.7 `frontend/src/components/account/BulkEditAccountModal.vue`:修复模板 `sort()` 原地变异
-
-## 3. 冗余调用与兼容治理(P2)
-
-- [ ] 3.1 `frontend/src/views/HomeView.vue`:移除与 router 重复的 `authStore.checkAuth()` 初始化路径
-- [ ] 3.2 `frontend/src/components/layout/AppSidebar.vue`:保留单入口触发 `adminSettingsStore.fetch()`(去除重复触发)
-- [ ] 3.3 为关键优化增加运行时开关读取(`public_settings.perf_flags.*`),并保留 `VITE_*` 作为默认值兜底
-- [ ] 3.4 完善“新路径失败回退旧路径”的兜底逻辑(特别是模型清单加载)
-- [ ] 3.5 `backend` 公共设置接口新增可选字段 `perf_flags`(非必填、非破坏性),验证旧前端忽略未知字段
-
-## 4. 灰度发布与验收
-
-- [ ] 4.1 10% 灰度开启:`public_settings.perf_flags.accounts_async_modals_enabled=true` + `public_settings.perf_flags.announcement_lazy_enabled=true`
-- [ ] 4.2 达标后扩大到 50%,再全量;期间持续监控错误率与关键页面指标
-- [ ] 4.3 灰度稳定后启用 `public_settings.perf_flags.prefetch_mode=adaptive`
-- [ ] 4.4 完成回归测试:`pnpm --dir frontend run lint:check`、`pnpm --dir frontend run typecheck`、关键路由手工冒烟
-- [ ] 4.5 记录回滚手册(按开关逐项回退)并验证回退有效
-- [ ] 4.6 混部兼容验证:新前端+旧配置、旧前端+新配置、新前端+新配置三种组合均通过冒烟
-
-## 5. 收尾与文档
-
-- [ ] 5.1 更新性能优化文档(含“已修复问题清单、收益、风险、回滚命令”)
-- [ ] 5.2 执行 `openspec validate 2026-02-26-optimize-frontend-performance-comprehensive --strict`
-- [ ] 5.3 提交最终验收结论(是否满足“向前兼容、升级无中断”)
diff --git a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/design.md b/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/design.md
deleted file mode 100644
index 244c7198a..000000000
--- a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/design.md
+++ /dev/null
@@ -1,212 +0,0 @@
-## Context
-
-本设计采用“并行双路径”原则:
-
-- legacy 路径:原实现,保持不变
-- v2 路径:新增实现,由开关启用
-
-v2 路径承载以下新能力:
-
-1. WS mode 三态(`off/shared/dedicated`)
-2. 协议对称(`ws->ws`、`http->http`)
-3. dedicated 会话稳定性增强
-4. 账号并发数即连接池上限
-
-## Goals
-
-1. 新增实现不破坏原行为(默认不开启)。
-2. 明确三态 WS mode 的配置、优先级和兼容迁移。
-3. 将协议对称变为可测试的硬约束。
-4. 将账号并发数绑定为账号连接池上限。
-5. 提供 dedicated 在高并发下的容量与拒绝语义。
-
-## Non-Goals
-
-1. 不替换 legacy 实现。
-2. 不修改客户端协议格式。
-3. 不引入跨实例共享会话状态。
-
-## Configuration Design
-
-### A. New Master Switch (Legacy Isolation)
-
-新增:`gateway.openai_ws.mode_router_v2_enabled: bool`(默认 `false`)
-
-- `false`:保持原实现完整运行。
-- `true`:启用 v2 mode 路由能力。
-
-### B. Tri-Mode Configuration
-
-新增:`gateway.openai_ws.ingress_mode_default: string`(`off|shared|dedicated`,默认 `shared`)
-
-账号新增(按类型):
-
-- `accounts.extra.openai_oauth_responses_websockets_v2_mode`
-- `accounts.extra.openai_apikey_responses_websockets_v2_mode`
-
-取值均为:`off|shared|dedicated`
-
-### C. Backward Compatibility
-
-在 v2 路径中,账号模式按如下顺序解析:
-
-1. 新模式字段(`*_mode`)
-2. 旧布尔字段:
- - `openai_oauth_responses_websockets_v2_enabled`
- - `openai_apikey_responses_websockets_v2_enabled`
- - `responses_websockets_v2_enabled`
- - `openai_ws_enabled`
-3. 全局默认 `ingress_mode_default`
-
-映射规则:`true => shared`,`false => off`。
-
-## Protocol Symmetry (V2 Path)
-
-当 `mode_router_v2_enabled=true` 时,执行硬约束:
-
-1. 入站 WS:仅允许上游 WS(`ws->ws`)
-2. 入站 HTTP:仅允许上游 HTTP(`http->http`)
-3. 禁止跨协议:`ws->http`、`http->ws`
-
-失败语义:
-
-- WS 路径不 fallback 到 HTTP;直接返回可诊断 close 错误。
-- HTTP 路径不 upgrade 到 WS;保持 HTTP 内部重试逻辑。
-
-## Mode Resolution and Lifecycle
-
-### Step 1: Router Branching
-
-1. `mode_router_v2_enabled=false` => legacy
-2. `mode_router_v2_enabled=true` => v2
-
-### Step 2: Effective Mode (V2)
-
-在 v2 路径中,`effectiveMode` 受以下门禁约束:
-
-1. 全局门禁:`enabled/force_http/responses_websockets_v2`
-2. 账号类型门禁:`oauth_enabled/apikey_enabled`
-3. 账号模式解析(新字段优先,旧字段回退)
-
-### Step 3: Per Mode Behavior
-
-- `off`:拒绝 WS mode;HTTP 正常走 HTTP。
-- `shared`:复用现有共享池策略。
-- `dedicated`:每个客户端 WS 会话独占上游连接并强亲和。
-
-## Account Concurrency = Pool Max (V2)
-
-### Rule
-
-v2 路径中,账号池上限由账号并发数决定:
-
-- `max_conns_for_account = account.concurrency`
-
-约束:
-
-1. `account.concurrency <= 0` 视为不可调度(直接拒绝 WS)。
-2. `dedicated` 下活跃会话数不得超过 `account.concurrency`。
-3. `shared` 下连接总数也不得超过该上限。
-
-说明:legacy 路径继续使用当前池参数计算逻辑,不受本规则影响。
-
-## `store=false` Three-Layer Strategy (Dedicated)
-
-### Layer 1
-
-发送前治理 `previous_response_id`,减少 `previous_response_not_found` 无效往返。
-
-### Layer 2
-
-会话内强亲和,仅允许 `sessionConnID`。
-
-### Layer 3
-
-连接不可用时,剥离失效续链锚点并执行 input replay 单次恢复。
-
-## State Model (V2 Dedicated)
-
-会话态:
-
-1. `effectiveMode`
-2. `sessionConnID`
-3. `lastTurnResponseID`
-4. `lastTurnReplayInput`
-
-会话结束后全部销毁;连接标记不可复用。
-
-## High-Concurrency Model
-
-定义:
-
-- `U`: 活跃 WS 会话
-- `A`: 账号数
-- `Ci`: 第 i 个账号并发数
-- `S = ΣCi`
-
-dedicated 满足:`U <= S`。
-
-示例:50 账号 * 20 并发 => `S=1000`。
-
-若目标是 1000 同时在线会话,建议按 20% 冗余评估到 1200 级容量(否则高峰会出现拒绝)。
-
-## Observability
-
-新增指标(v2 打标):
-
-1. `openai_ws_mode_router_v2_requests_total{protocol_path,mode}`
-2. `openai_ws_protocol_symmetry_reject_total{from,to}`
-3. `openai_ws_ingress_sessions_active{mode}`
-4. `openai_ws_ingress_acquire_fail_total{mode,reason}`
-5. `openai_ws_ingress_replay_total{mode,result}`
-6. `openai_ws_account_pool_limit_hits_total{account_id}`
-
-关键日志字段:
-
-1. `router_version=legacy|v2`
-2. `ws_mode=off|shared|dedicated`
-3. `protocol_path=ws->ws|http->http`
-4. `account_concurrency`
-5. `account_pool_max`
-
-## Three-Round Proposal Review and Fixes
-
-### Review Round 1
-
-问题:早期方案会直接改现有路径,违反“新增不能改原实现”。
-
-修复:引入 `mode_router_v2_enabled`,明确 legacy/v2 双路径并存,默认 legacy。
-
-### Review Round 2
-
-问题:协议对称规则缺少作用域,可能误伤 legacy 行为。
-
-修复:将协议对称约束限定在 v2 路径;legacy 路径保持现状。
-
-### Review Round 3
-
-问题:连接池上限来源不明确(全局参数 vs 账号并发数)。
-
-修复:在 v2 路径将 `account.concurrency` 定义为账号池硬上限,并补充拒绝语义与观测指标。
-
-## Rollout Plan
-
-1. 上线但不启用 v2(开关默认 `false`)。
-2. 小流量启用 v2 + shared 观察。
-3. 再灰度 dedicated 账号组。
-4. 指标越界即回退为 legacy(仅关开关)。
-
-## Test Strategy
-
-### Unit
-
-1. `mode_router_v2_enabled=false` 与当前基线一致。
-2. 三态模式解析正确(新字段/旧字段/默认值)。
-3. 协议对称拒绝 `ws->http`、`http->ws`。
-4. v2 路径账号并发数池上限生效。
-
-### Integration
-
-1. v2 shared 与 dedicated 行为验证。
-2. dedicated 连接中断 replay 恢复验证。
-3. 1000 会话压测下拒绝率和时延验证。
diff --git a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/proposal.md b/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/proposal.md
deleted file mode 100644
index 11558d253..000000000
--- a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/proposal.md
+++ /dev/null
@@ -1,81 +0,0 @@
-## Why
-
-你当前的核心诉求已经明确为 4 点:
-
-1. WS mode 需要 3 种模式:`off/shared/dedicated`。
-2. 协议必须对称:只允许 `ws->ws`、`http->http`,禁止 `ws->http`、`http->ws`。
-3. 新增实现不能破坏原有实现(默认保持旧行为,按开关启用新逻辑)。
-4. 账号“并发数”字段就是该账号连接池最大连接数。
-
-现有提案仍偏向“布尔 dedicated 开关”,且对“旧实现隔离”与“并发数即池上限”的约束不够硬,需要补齐。
-
-## What Changes
-
-本提案升级为“**增量并行实现**”:新增 v2 路径,不替换旧路径。
-
-### 1. 新增总开关,确保不改旧实现
-
-新增:`gateway.openai_ws.mode_router_v2_enabled`(默认 `false`)。
-
-- `false`:100% 走原实现(legacy 路径),行为不变。
-- `true`:仅指定请求进入新模式路由(tri-mode + symmetry + dedicated 强化)。
-
-### 2. WS mode 三态
-
-新增统一模式:`off|shared|dedicated`。
-
-- 全局默认:`gateway.openai_ws.ingress_mode_default`(默认 `shared`)
-- 账号级(按类型):
- - `accounts.extra.openai_oauth_responses_websockets_v2_mode`
- - `accounts.extra.openai_apikey_responses_websockets_v2_mode`
-
-### 3. 协议对称硬约束(在 v2 路径生效)
-
-- WS 入站仅允许 `ws->ws`
-- HTTP 入站仅允许 `http->http`
-- 明确拒绝 `ws->http`、`http->ws`
-
-### 4. dedicated + store=false 三层稳定策略
-
-- Layer 1:发送前治理 `previous_response_id`
-- Layer 2:连接硬亲和(同会话固定同连接)
-- Layer 3:连接中断时 input replay 单次恢复
-
-### 5. 账号并发数即连接池上限(v2 路径)
-
-在 v2 路径中,账号连接池上限由账号并发数直接决定:
-
-- `max_conns_for_account = account.concurrency`
-- `dedicated` 模式下可同时承载的会话上限也受该值约束
-
-说明:legacy 路径维持现有全局池参数语义,不被本设计改变。
-
-## Non-Goals
-
-- 不改 HTTP/WSv1 业务协议。
-- 不移除或重写 legacy 逻辑。
-- 不在本提案实现跨实例会话绑定共享。
-
-## Capabilities
-
-### Modified Capabilities
-
-- `openai-ws-v2-performance`
-
-## Impact
-
-- 影响模块(预期):
- - `backend/internal/config/config.go`
- - `backend/internal/config/config_test.go`
- - `backend/internal/service/account.go`
- - `backend/internal/service/openai_ws_protocol_resolver.go`
- - `backend/internal/service/openai_ws_forwarder.go`
- - `backend/internal/service/openai_ws_pool.go`
- - `backend/internal/service/openai_ws_forwarder_ingress_session_test.go`
- - `frontend/src/components/account/CreateAccountModal.vue`
- - `frontend/src/components/account/EditAccountModal.vue`
- - `frontend/src/i18n/locales/zh.ts`
- - `frontend/src/i18n/locales/en.ts`
-
-- 兼容性:默认 `mode_router_v2_enabled=false`,原行为不变。
-- 风险级别:中(开启 `dedicated` 后资源消耗上升)。
diff --git a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/specs/openai-ws-v2-performance/spec.md b/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/specs/openai-ws-v2-performance/spec.md
deleted file mode 100644
index bed15b75e..000000000
--- a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/specs/openai-ws-v2-performance/spec.md
+++ /dev/null
@@ -1,85 +0,0 @@
-## ADDED Requirements
-
-### Requirement: New implementation MUST NOT change legacy behavior by default
-系统 MUST 以新增路径方式实现,不得默认改变既有行为。
-
-#### Scenario: v2 router disabled
-- **WHEN** `gateway.openai_ws.mode_router_v2_enabled=false`
-- **THEN** 系统 MUST 继续执行 legacy 路径
-- **AND** 行为 MUST 与当前实现保持一致
-
-### Requirement: WS mode MUST support off/shared/dedicated in v2 router
-系统 MUST 在 v2 路径支持三态 WS mode。
-
-#### Scenario: mode off
-- **WHEN** v2 路径解析模式为 `off`
-- **THEN** 系统 MUST 禁止该账号类型使用 WS mode
-
-#### Scenario: mode shared
-- **WHEN** v2 路径解析模式为 `shared`
-- **THEN** 系统 MUST 使用共享连接池语义
-
-#### Scenario: mode dedicated
-- **WHEN** v2 路径解析模式为 `dedicated`
-- **THEN** 系统 MUST 为每个客户端 WS 会话分配独占上游连接
-- **AND** 同会话内所有 turn MUST 复用该连接
-
-### Requirement: V2 router MUST be backward compatible with legacy WS flags
-系统 MUST 支持旧布尔字段映射到新三态模式。
-
-#### Scenario: legacy enabled
-- **WHEN** 新 `*_mode` 缺失且旧 `*_enabled=true`
-- **THEN** 系统 MUST 解析为 `shared`
-
-#### Scenario: legacy disabled
-- **WHEN** 新 `*_mode` 缺失且旧 `*_enabled=false`
-- **THEN** 系统 MUST 解析为 `off`
-
-#### Scenario: no account flags
-- **WHEN** 账号新旧字段均缺失
-- **THEN** 系统 MUST 使用 `gateway.openai_ws.ingress_mode_default`
-
-### Requirement: Protocol symmetry MUST be enforced in v2 router
-系统 MUST 在 v2 路径强制协议对称。
-
-#### Scenario: websocket ingress
-- **WHEN** 客户端以 WS 入站并走 v2 路径
-- **THEN** 系统 MUST 仅允许 `ws->ws`
-- **AND** MUST NOT fallback to HTTP
-
-#### Scenario: http ingress
-- **WHEN** 客户端以 HTTP 入站并走 v2 路径
-- **THEN** 系统 MUST 仅允许 `http->http`
-- **AND** MUST NOT upgrade to WS
-
-### Requirement: Account concurrency MUST define per-account pool max in v2 router
-系统 MUST 在 v2 路径将账号并发数作为该账号连接池上限。
-
-#### Scenario: positive account concurrency
-- **WHEN** 账号 `concurrency > 0`
-- **THEN** 系统 MUST 使用 `account.concurrency` 作为该账号 `max_conns`
-
-#### Scenario: non-positive account concurrency
-- **WHEN** 账号 `concurrency <= 0`
-- **THEN** 系统 MUST 拒绝该账号的 WS 调度
-- **AND** MUST 记录可观测日志与指标
-
-### Requirement: Dedicated store=false path MUST support chain governance and replay recovery
-系统 MUST 在 dedicated + store=false 路径支持前置治理与重放恢复。
-
-#### Scenario: proactive previous_response governance
-- **WHEN** `store=false` 请求包含不可信 `previous_response_id`
-- **THEN** 系统 SHALL 在发送前执行治理(如剥离)
-
-#### Scenario: dedicated connection loss
-- **WHEN** dedicated 会话中连接中断
-- **THEN** 系统 SHALL 执行一次 input replay 重建
-- **AND** 失败时 MUST 返回明确错误并提示重启会话
-
-### Requirement: V2 router MUST expose mode and symmetry observability
-系统 MUST 输出 v2 路径关键观测指标。
-
-#### Scenario: mixed traffic with v2 enabled
-- **WHEN** v2 开启且流量包含多个 mode
-- **THEN** 系统 MUST 输出按 `ws_mode` 分桶的会话与失败指标
-- **AND** MUST 输出协议对称拒绝计数与连接池上限命中计数
diff --git a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/tasks.md b/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/tasks.md
deleted file mode 100644
index 2bab900de..000000000
--- a/openspec/changes/2026-02-27-openai-wsv2-ingress-dedicated-mode/tasks.md
+++ /dev/null
@@ -1,51 +0,0 @@
-## M0. Triple Review Gate
-
-- [ ] M0.1 第 1 轮审核:确认“新增不改旧实现”边界。
-- [ ] M0.2 第 2 轮审核:确认协议对称约束作用域与失败语义。
-- [ ] M0.3 第 3 轮审核:确认“账号并发数=池上限”规则与容量口径。
-
-## M1. Router Isolation
-
-- [ ] M1.1 新增 `gateway.openai_ws.mode_router_v2_enabled`(默认 `false`)。
-- [ ] M1.2 在入口增加 legacy/v2 分支,不删除 legacy 逻辑。
-- [ ] M1.3 增加回归测试:v2 关闭时行为与当前基线一致。
-
-## M2. Tri-Mode Config
-
-- [ ] M2.1 新增 `gateway.openai_ws.ingress_mode_default`,校验 `off|shared|dedicated`。
-- [ ] M2.2 账号 extra 新增 `*_mode` 字段读取(oauth/apikey)。
-- [ ] M2.3 旧字段映射兼容(`*_enabled` 与兼容旧键)。
-
-## M3. Protocol Symmetry (V2 Only)
-
-- [ ] M3.1 v2 + WS 入站仅允许 `ws->ws`。
-- [ ] M3.2 v2 + HTTP 入站仅允许 `http->http`。
-- [ ] M3.3 禁止 `ws->http` 与 `http->ws` 并返回明确错误。
-
-## M4. Account Concurrency as Pool Max
-
-- [ ] M4.1 在 v2 路径将 `account.concurrency` 作为账号池上限。
-- [ ] M4.2 `account.concurrency<=0` 的拒绝语义与日志。
-- [ ] M4.3 dedicated/shared 两种模式均应用该上限。
-- [ ] M4.4 新增单测覆盖并发上限命中与拒绝。
-
-## M5. Dedicated Stability
-
-- [ ] M5.1 dedicated 首轮独占建连。
-- [ ] M5.2 dedicated 会话内连接硬亲和。
-- [ ] M5.3 `store=false` 前置治理 `previous_response_id`。
-- [ ] M5.4 连接中断 input replay 单次恢复。
-- [ ] M5.5 会话结束连接不可复用。
-
-## M6. Frontend/Admin
-
-- [ ] M6.1 账号 WS mode 从布尔改为三态选择器。
-- [ ] M6.2 增加“账号并发数=池上限”说明文案。
-- [ ] M6.3 保持旧字段读取兼容,新字段优先写入。
-
-## M7. Observability and Rollout
-
-- [ ] M7.1 增加 `router_version/ws_mode/protocol_path` 关键日志。
-- [ ] M7.2 增加 symmetry reject 与 pool limit hit 指标。
-- [ ] M7.3 先开 shared 灰度,再开 dedicated 灰度。
-- [ ] M7.4 指标异常时仅关 `mode_router_v2_enabled` 回滚。
diff --git a/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/proposal.md b/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/proposal.md
deleted file mode 100644
index 4df5058d7..000000000
--- a/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/proposal.md
+++ /dev/null
@@ -1,13 +0,0 @@
-# Change: Avoid extra DB lookup on sticky session hit
-
-## Why
-Sticky-session hits in `SelectAccountWithLoadAwareness` currently call `accountRepo.GetByID` even though the candidate accounts are already loaded in `listSchedulableAccounts`. This adds a redundant DB query on the hot path, increasing latency and DB load.
-
-## What Changes
-- Build a map of `accountID -> *Account` from the schedulable accounts list.
-- On sticky-session hit, use the in-memory map to validate group/platform/model support and attempt slot acquisition without an extra DB lookup.
-- Keep behavior unchanged when the sticky account is not in the candidate set (fall back to load-aware selection).
-
-## Impact
-- Affected specs: `schedule-account`
-- Affected code: `backend/internal/service/gateway_service.go`
diff --git a/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/specs/schedule-account/spec.md b/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/specs/schedule-account/spec.md
deleted file mode 100644
index 611ea3b2a..000000000
--- a/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/specs/schedule-account/spec.md
+++ /dev/null
@@ -1,7 +0,0 @@
-## ADDED Requirements
-### Requirement: Sticky-session hit reuses schedulable accounts list
-The scheduler SHALL resolve sticky-session account selection from the schedulable accounts list already loaded for the request and SHALL NOT issue an additional account-by-ID database query when the sticky account is present in that list.
-
-#### Scenario: Sticky session hit without extra DB query
-- **WHEN** a scheduling request has a sticky session that points to an account in the schedulable accounts list
-- **THEN** the scheduler reuses that in-memory account data and does not query the database by account ID
diff --git a/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/tasks.md b/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/tasks.md
deleted file mode 100644
index 3251d4c40..000000000
--- a/openspec/changes/archive/2026-01-10-refactor-sticky-session-hit-lookup/tasks.md
+++ /dev/null
@@ -1,4 +0,0 @@
-## 1. Implementation
-- [x] 1.1 Build an account lookup map in `SelectAccountWithLoadAwareness` for sticky-session checks.
-- [x] 1.2 Use the map on sticky-session hit and remove the `GetByID` query.
-- [x] 1.3 Add a `SelectAccountWithLoadAwareness` unit test that asserts sticky-session hit does not call `accountRepo.GetByID` (use a mock/stub repo with call counting).
diff --git a/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/proposal.md b/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/proposal.md
deleted file mode 100644
index 63c95d1fe..000000000
--- a/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/proposal.md
+++ /dev/null
@@ -1,24 +0,0 @@
-# Change: 添加前端 Chunk 加载错误自动恢复机制
-
-## Why
-
-当前端应用重新部署后,用户浏览器可能缓存了旧版本的 `index.html`,其中引用的 JS chunk 文件(如 `DashboardView-CG6GXl8p.js`)在服务器上已被新版本替换。当用户导航到使用懒加载的路由时,会触发以下错误:
-
-```
-Failed to fetch dynamically imported module: https://api.aicodex.top:8443/assets/DashboardView-CG6GXl8p.js
-```
-
-这导致用户无法正常使用应用,必须手动刷新页面或清除缓存。
-
-## What Changes
-
-- 在 Vue Router 的 `onError` 钩子中添加 chunk 加载错误检测
-- 检测到 chunk 加载失败时自动刷新页面以获取最新资源
-- 使用 `sessionStorage` 防止无限刷新循环(10 秒内只允许一次自动刷新)
-- 刷新仍失败时输出清晰的控制台错误提示
-
-## Impact
-
-- Affected specs: `frontend-routing`(新增 capability)
-- Affected code:
- - `frontend/src/router/index.ts` - 增强 `router.onError` 处理逻辑
diff --git a/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/specs/frontend-routing/spec.md b/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/specs/frontend-routing/spec.md
deleted file mode 100644
index 5379e9615..000000000
--- a/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/specs/frontend-routing/spec.md
+++ /dev/null
@@ -1,24 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Chunk Load Error Recovery
-
-The frontend application SHALL automatically recover from chunk loading failures caused by deployment updates. When a dynamically imported module fails to load, the router SHALL detect the error and attempt to reload the page to fetch the latest resources.
-
-#### Scenario: Dynamic import fails due to stale cache
-- **WHEN** a user navigates to a lazily-loaded route
-- **AND** the browser has cached an outdated `index.html` referencing old chunk files
-- **AND** the server returns 404 for the requested chunk
-- **THEN** the router detects the chunk load error
-- **AND** automatically reloads the page to fetch the latest version
-
-#### Scenario: Reload cooldown prevents infinite loop
-- **WHEN** a chunk load error triggers an automatic page reload
-- **AND** the reload occurs within 10 seconds of a previous reload attempt
-- **THEN** the router SHALL NOT trigger another reload
-- **AND** SHALL log an error message suggesting the user clear their browser cache
-
-#### Scenario: Successful recovery after reload
-- **WHEN** the page reloads due to a chunk load error
-- **AND** the browser fetches the latest `index.html` and chunk files
-- **THEN** the user can successfully navigate to the intended route
-- **AND** the application functions normally
diff --git a/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/tasks.md b/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/tasks.md
deleted file mode 100644
index ba9534754..000000000
--- a/openspec/changes/archive/2026-01-16-add-chunk-load-error-recovery/tasks.md
+++ /dev/null
@@ -1,12 +0,0 @@
-## 1. 代码实现
-
-- [x] 1.1 在 `router.onError` 中添加 chunk 加载错误检测逻辑
-- [x] 1.2 检测多种错误消息模式(`Failed to fetch dynamically imported module`、`Loading chunk`、`Loading CSS chunk`、`ChunkLoadError`)
-- [x] 1.3 使用 `sessionStorage` 记录刷新时间戳,防止无限刷新循环
-- [x] 1.4 设置 10 秒冷却时间,避免网络问题导致的快速重复刷新
-
-## 2. 测试验证
-
-- [x] 2.1 TypeScript 类型检查通过
-- [x] 2.2 手动测试:模拟 chunk 加载失败,验证自动刷新行为
-- [x] 2.3 手动测试:验证 10 秒内不会重复刷新
diff --git a/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/proposal.md b/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/proposal.md
deleted file mode 100644
index 15e2962c2..000000000
--- a/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/proposal.md
+++ /dev/null
@@ -1,33 +0,0 @@
-# Change: 重构 TimingWheelService 错误处理
-
-## Why
-
-`NewTimingWheelService()` 初始化失败时直接 `panic(err)`,与项目中其他模块"返回 error + 上层处理"的错误处理风格不一致。这种做法在极端情况下会导致进程崩溃,且不给上层调用者处理错误的机会。
-
-**问题代码位置**:`backend/internal/service/timing_wheel_service.go:27`
-
-```go
-tw, err := collection.NewTimingWheel(...)
-if err != nil {
- panic(err) // 问题所在
-}
-```
-
-## What Changes
-
-- 修改 `NewTimingWheelService()` 函数签名为 `(*TimingWheelService, error)`
-- 移除 panic 调用,改为返回 error
-- 更新所有调用方以处理返回的 error
-- 确保应用启动时正确处理 TimingWheel 初始化失败的情况
-
-## Impact
-
-- Affected specs: `timing-wheel`
-- Affected code:
- - `backend/internal/service/timing_wheel_service.go` - 核心修改
- - `backend/internal/service/wire.go` - Provider 签名/返回值需要调整以透传 error
- - `backend/cmd/server/wire_gen.go` - Wire 生成文件会随 Provider 变化而更新(需要重新生成)
- - `backend/cmd/server/main.go` - `initializeApplication(...)` 返回 error 时会 `log.Fatalf(...)` 并退出(非 0)
- - 任何其他直接调用 `NewTimingWheelService()` 的代码(需统一处理返回的 error)
-
-**生成文件注意事项**:修改 `backend/internal/service/wire.go` 后,需要运行 `cd backend && go generate ./cmd/server` 重新生成 `backend/cmd/server/wire_gen.go`。
diff --git a/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/specs/timing-wheel/spec.md b/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/specs/timing-wheel/spec.md
deleted file mode 100644
index 5bef41ac3..000000000
--- a/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/specs/timing-wheel/spec.md
+++ /dev/null
@@ -1,19 +0,0 @@
-## ADDED Requirements
-
-### Requirement: TimingWheel Initialization Error Handling
-
-当 TimingWheel 初始化失败时,`NewTimingWheelService()` SHALL 返回 error 而不是触发 panic。函数签名 MUST 为 `(*TimingWheelService, error)`,以便调用方能够感知初始化失败并按“启动失败”路径处理。
-
-#### Scenario: TimingWheel 初始化失败时不触发 panic
-- **WHEN** 底层 `collection.NewTimingWheel()` 返回 error
-- **THEN** `NewTimingWheelService()` 返回 `nil` 和包装后的 error(例如使用 `%w` 包装)
-- **AND** 不发生 panic(进程不应因该错误直接崩溃)
-
-#### Scenario: TimingWheel 初始化成功
-- **WHEN** 底层 `collection.NewTimingWheel()` 初始化成功
-- **THEN** `NewTimingWheelService()` 返回有效的 `*TimingWheelService` 和 `nil` error
-
-#### Scenario: 初始化失败导致应用启动失败并退出(非 0)
-- **WHEN** `initializeApplication(...)` 调用 TimingWheel 的 provider/constructor 并收到 error
-- **THEN** `initializeApplication(...)` 将该 error 返回给调用方
-- **AND** `backend/cmd/server/main.go` 记录 fatal 日志并以非 0 状态码退出进程
diff --git a/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/tasks.md b/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/tasks.md
deleted file mode 100644
index d00d2d3bb..000000000
--- a/openspec/changes/archive/2026-01-16-refactor-timing-wheel-error-handling/tasks.md
+++ /dev/null
@@ -1,21 +0,0 @@
-## 1. 代码修改
-
-- [x] 1.1 修改 `NewTimingWheelService()` 返回类型为 `(*TimingWheelService, error)`
-- [x] 1.2 将 `panic(err)` 替换为 `return nil, fmt.Errorf("failed to create timing wheel: %w", err)`
-- [x] 1.3 添加 `fmt` 包的 import(如果尚未导入)
-- [x] 1.4 (可选增强)引入可注入的 TimingWheel factory(例如包级变量/私有构造函数),便于单测覆盖失败分支
-
-## 2. 调用方更新
-
-- [x] 2.1 查找所有 `NewTimingWheelService()` 的调用位置
-- [x] 2.2 更新调用方以处理返回的 error
-- [x] 2.3 修改 `ProvideTimingWheelService()` 返回类型为 `(*TimingWheelService, error)`,并在成功后 `Start()`
-- [x] 2.4 重新生成 Wire:`cd backend && go generate ./cmd/server`(更新 `backend/cmd/server/wire_gen.go`)
-- [x] 2.5 确保应用启动失败时有清晰的错误日志(当前 `backend/cmd/server/main.go` 会 `log.Fatalf("Failed to initialize application: %v", err)` 并退出)
-
-## 3. 测试验证
-
-- [x] 3.1 编译验证,确保没有编译错误
-- [x] 3.2 运行现有测试,确保不破坏现有功能
-- [x] 3.3 手动测试应用启动正常
-- [x] 3.4 (可选增强)新增单测:模拟 `collection.NewTimingWheel()` 返回 error,验证 `NewTimingWheelService()` 不 panic 且返回 error
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/.openspec.yaml b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/.openspec.yaml
deleted file mode 100644
index 44652448b..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/.openspec.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-schema: spec-driven
-created: 2026-02-11
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/design.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/design.md
deleted file mode 100644
index 8a8cb7c68..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/design.md
+++ /dev/null
@@ -1,119 +0,0 @@
-## Context
-
-当前 OpenAI OAuth 主链路为:`Codex CLI/VSCode` 发起 `POST /v1/responses`,经 API Key 认证、并发控制、账号调度、OAuth token 获取后转发到 `chatgpt.com/backend-api/codex/responses`(或 OpenAI 平台接口),并将 SSE/JSON 回传客户端。
-
-在当前实现中,性能开销主要集中在四个层面:
-
-- 请求热路径:同一请求体在 handler/service 中存在多次解析与拷贝,且有非必要上下文写入。
-- 并发与调度路径:常态请求存在额外 Redis 往返,且部分释放逻辑引入请求级 goroutine 开销。
-- 流式转发路径:SSE 逐行处理中存在正则匹配、重复字符串处理与频繁 JSON 解析,导致 CPU 与 GC 压力增大。
-- token 路径:锁竞争时的固定等待策略放大尾延迟。
-
-约束条件:
-
-- 对外 API 协议和行为保持兼容(包括 OpenAI Responses 流式语义)。
-- 优先优化 P95/P99 与稳定性,不以牺牲正确性/可运维性换取短期吞吐。
-- 优化必须可观测、可灰度、可回滚。
-
-## Goals / Non-Goals
-
-**Goals:**
-
-- 建立 OpenAI OAuth 路径统一性能目标:网关附加延迟、TTFT、P95/P99、错误率、CPU 与内存分配。
-- 将常态请求中的非必要 Redis 往返与 goroutine 开销降到最低。
-- 降低 SSE 热路径 CPU/GC 成本,提升流式场景尾延迟稳定性。
-- 优化 token 获取竞争路径,减少锁等待带来的请求抖动。
-- 提供可量化的压测与回归基线,确保后续变更可持续守住性能红线。
-
-**Non-Goals:**
-
-- 不改变对外路由、请求/响应 JSON 结构与鉴权协议。
-- 不在本次变更中引入新的外部基础设施(如新增 MQ、新增缓存集群)。
-- 不重构全部网关模块,仅聚焦 OpenAI OAuth 高流量主路径。
-
-## Decisions
-
-### 决策 1:按“路径分层”进行优化,而非一次性大重构
-
-- 选择:将优化拆为请求热路径、并发调度路径、流式路径、token 路径四类子改造,逐步落地。
-- 原因:该线路跨越 handler/service/repository/config,直接大重构风险高、回滚困难。
-- 备选方案:一次性重写 OpenAI gateway handler/service。
-- 不选原因:改动面过大,难以定位回归,难以灰度验证单项收益。
-
-### 决策 2:并发控制改为“先抢槽再排队”的快速路径
-
-- 选择:优先尝试直接获取用户并发槽,失败后再进入等待计数与等待逻辑。
-- 原因:当前“先加等待计数再抢槽”会让常态成功请求承担额外 Redis 往返。
-- 备选方案:保留现状,仅调大 Redis 与连接池。
-- 不选原因:只能缓解,不解决协议级额外操作造成的固有延迟。
-
-### 决策 3:释放回调改为 `context.AfterFunc`/等效轻量机制
-
-- 选择:取消每请求一个守护 goroutine 的释放模式,采用 context 生命周期回调。
-- 原因:高并发流式场景中,短生命周期 goroutine 数量会放大调度与内存压力。
-- 备选方案:保留 goroutine,依赖 runtime 自适应。
-- 不选原因:尾延迟仍受影响,且难以精准控量。
-
-### 决策 4:SSE 热路径去正则化与选择性解析
-
-- 选择:SSE 行识别由正则改为前缀/状态机方式;仅对关键事件做 JSON 解析(如 `response.completed` usage),减少每行开销。
-- 原因:SSE 高频循环中正则 + 字符串替换 + JSON 反序列化是热点。
-- 备选方案:保留现有实现,仅通过 CPU 扩容应对。
-- 不选原因:成本不可持续,且扩容无法有效压缩 P99 抖动。
-
-### 决策 5:token 锁竞争改为短轮询+jitter,替代固定 sleep
-
-- 选择:将锁竞争等待从固定 `200ms` 改为短间隔重试并加抖动,支持更快命中缓存刷新结果。
-- 原因:固定等待会在高并发下直接拉高请求尾延迟。
-- 备选方案:继续固定 sleep 并缩短数值。
-- 不选原因:缺乏自适应能力,竞争抖动下仍易形成延迟台阶。
-
-### 决策 6:观测先行,建立性能守门指标
-
-- 选择:为该链路补齐阶段性耗时指标(认证/调度/token/上游首包/SSE 处理)与压测报告模板。
-- 原因:没有统一指标就无法客观验证“极致优化”是否达标。
-- 备选方案:仅通过主观体感或单一 QPS 指标评估。
-- 不选原因:无法定位瓶颈迁移,也无法防止回归。
-
-## Risks / Trade-offs
-
-- [风险] SSE 解析策略变更可能引入协议兼容问题(尤其是边缘事件格式)
- → Mitigation:增加真实样本回放测试;灰度期间双写校验(新旧解析结果比对)。
-
-- [风险] 并发控制流程调整可能导致等待计数与槽位计数不一致
- → Mitigation:增加一致性指标(槽位数、等待数、释放成功率)并设置告警。
-
-- [风险] token 锁策略调整可能在极端场景增加刷新请求数
- → Mitigation:保留分布式锁上限与刷新熔断策略,设置刷新 QPS 保护阈值。
-
-- [权衡] 引入更多性能指标会增加少量埋点开销
- → Mitigation:仅对关键路径埋点,避免高基数标签,必要时采样。
-
-## Migration Plan
-
-1. 基线阶段:在现网采集当前性能基线(QPS、P50/P95/P99、TTFT、错误率、CPU/内存、Redis RT)。
-2. 改造阶段(分批):
- - 批次 A:并发控制快速路径 + 释放机制优化。
- - 批次 B:SSE 热路径优化。
- - 批次 C:token 竞争路径优化。
- - 批次 D:请求体与中间件热路径微优化。
-3. 灰度阶段:按实例或流量比例灰度,逐批验证收益与回归。
-4. 全量阶段:达到验收指标后全量发布。
-5. 回滚策略:每批次独立开关或可逆提交,异常时按批次快速回滚。
-
-## 审核门禁(Coding Gate)
-
-在进入编码前,以下门禁项 MUST 完成签核:
-
-- 基线签核:确认基线采集时间窗口、压测场景、流量模型与样本数据来源。
-- 目标签核:确认并冻结性能目标阈值(至少包含 P95/P99、TTFT、错误率、CPU/内存、Redis RT)。
-- 发布签核:确认灰度批次、观测指标、回滚触发阈值与回滚路径。
-
-> 说明:若上述任一门禁未签核,本变更 SHOULD NOT 进入编码阶段。
-
-## Open Questions
-
-- 是否需要对不同客户端类型(Codex CLI / VSCode / 其他)采用不同优化策略与阈值?
-- 连接池隔离策略是否从 `account_proxy` 调整为 `proxy` 需要先在哪个环境完成风险验证?
-- 目标 SLO 的最终阈值(例如网关附加 P99、TTFT)由谁签字确认,验收窗口多久?
-- 是否需要将该线路纳入 CI 压测基线门禁,避免后续功能迭代带来性能回归?
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/final-acceptance-report.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/final-acceptance-report.md
deleted file mode 100644
index 2eb9c6f4f..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/final-acceptance-report.md
+++ /dev/null
@@ -1,67 +0,0 @@
-# OpenAI OAuth 性能优化最终验收报告
-
-> 变更:`optimize-openai-oauth-performance`
-> 报告日期:`2026-02-12`
-> 负责人:`AI 协作演练(待真实环境签字)`
-
-## 1. 验收范围
-
-- 链路:Codex CLI/VSCode -> `/v1/responses` -> 网关 -> OpenAI/ChatGPT 上游
-- 版本:`当前工作分支(含 2.x/3.x/4.x/5.x 优化)`
-- 环境:`本地 mock 演练 + 后端全量测试`
-
-## 2. 基线 vs 优化后
-
-> 说明:本报告为“演练验收版”,真实生产基线请结合 `docs/perf/openai-oauth-baseline-template.md` 补全。
-
-| 指标 | 基线值 | 优化后 | 变化 | 目标阈值 | 是否达标 |
-|---|---:|---:|---:|---:|---|
-| 请求耗时 P50 (ms) | - | - | - | - | 待真实压测 |
-| 请求耗时 P95 (ms) | - | - | - | - | 待真实压测 |
-| 请求耗时 P99 (ms) | - | - | - | - | 待真实压测 |
-| TTFT P99 (ms) | 900(阈值) | 640(灰度 D 演练) | -260 | <=900 | 达标 |
-| 请求错误率 (%) | 2(阈值) | 0.72(灰度 D 演练) | -1.28 | <=2 | 达标 |
-| 上游错误率 (%) | 2(阈值) | 0.67(灰度 D 演练) | -1.33 | <=2 | 达标 |
-| CPU 峰值 (%) | - | - | - | - | 待真实压测 |
-| 内存峰值 (MiB) | - | - | - | - | 待真实压测 |
-| Redis RT P99 (ms) | - | - | - | - | 待真实压测 |
-
-## 3. 灰度发布记录(6.1)
-
-| 批次 | 流量比例 | 观察窗口 | 结果 | 异常说明 |
-|---|---:|---|---|---|
-| A | 5% | mock 演练 | 通过 | 无 |
-| B | 20% | mock 演练 | 通过 | 无 |
-| C | 50% | mock 演练 | 通过 | 无 |
-| D | 100% | mock 演练 | 通过 | 无 |
-
-详见:`docs/perf/openai-oauth-gray-drill-report.md`
-
-## 4. 阈值守护与回滚演练(6.2)
-
-- 阈值守护执行:`通过`
-- 自动守护脚本结果:
- - 正常批次:退出码 `0`
- - 注入异常场景:退出码 `2`(触发回滚条件)
-- 回滚演练:
- - 触发条件:`TTFT P99=1550ms,错误率=6.3%,上游错误率=5.6%`
- - 回滚判定:`脚本明确输出“建议:停止扩量并执行回滚”`
- - 结果:`成功`
-
-## 5. 风险与后续行动
-
-- 剩余风险:
- - 缺少真实生产流量下 CPU/内存/Redis RT 的量化对比
- - 需要按业务窗口执行真实灰度并补齐签字
-- 后续优化计划:
- - 在预发/生产执行同口径 k6 与 dashboard 对比
- - 将灰度守护脚本接入发布流水线(失败即阻断扩量)
-
-## 6. 验收结论(6.3)
-
-- 结论:`通过(演练环境)`
-- 是否关闭本变更:`是(演练闭环已完成);生产签字后可归档`
-- 签字:
- - 研发负责人:`待补充`
- - SRE/运维负责人:`待补充`
- - 产品负责人:`待补充`
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/proposal.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/proposal.md
deleted file mode 100644
index 356387283..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/proposal.md
+++ /dev/null
@@ -1,29 +0,0 @@
-## Why
-
-当前 OpenAI OAuth 线路(Codex CLI → `/v1/responses` → 网关转发 ChatGPT/OpenAI)在高并发下存在多处非必要开销(重复解析、额外 Redis 往返、热路径日志与 goroutine 开销、SSE 逐行处理成本),导致网关附加延迟与尾延迟(p95/p99)偏高。随着 Codex/VSCode 客户端流量增长,这条链路已成为核心体验路径,需要以“极致性能”为目标进行系统性优化。
-
-## What Changes
-
-- 为 OpenAI OAuth 线路建立明确的性能目标与验收口径(重点关注网关附加延迟、TTFT、p95/p99、CPU/内存分配、错误率)。
-- 优化请求热路径:减少请求体重复解析与不必要拷贝,收敛中间件与上下文写入的额外成本。
-- 优化并发与调度路径:减少常态请求的 Redis 往返次数,降低等待队列与槽位管理的额外开销。
-- 优化流式转发路径:降低 SSE 逐行处理中的正则与 JSON 解析成本,减少流式场景的 CPU 和 GC 压力。
-- 优化 OAuth token 获取路径:降低锁竞争时的等待成本,避免固定 sleep 放大尾延迟。
-- 增强可观测性:补齐性能指标与压测基线,确保优化收益可量化、可回归验证。
-
-## Capabilities
-
-### New Capabilities
-
-- `openai-oauth-performance`: 定义并约束 OpenAI OAuth 端到端高性能网关能力,包括请求热路径、调度并发路径、流式转发路径和 token 获取路径的性能要求与验收标准。
-
-### Modified Capabilities
-
-- (无)
-
-## Impact
-
-- 影响模块:`backend/internal/handler/openai_gateway_handler.go`、`backend/internal/handler/gateway_helper.go`、`backend/internal/service/openai_gateway_service.go`、`backend/internal/service/openai_token_provider.go`、`backend/internal/repository/http_upstream.go`、`backend/internal/server/middleware/*`、相关 `config` 默认值与校验。
-- 影响系统:Redis(并发/等待队列/调度缓存访问模式)、上游连接池行为、日志与监控指标。
-- API 兼容性:对外 API 路由与协议保持兼容,不引入 Breaking Change;主要是性能与内部行为优化。
-- 风险与收益:需要通过压测与灰度验证来控制回归风险;预期显著降低网关附加延迟与尾延迟,提升高并发稳定性。
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/review.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/review.md
deleted file mode 100644
index 5c866821c..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/review.md
+++ /dev/null
@@ -1,78 +0,0 @@
-## 提案三轮复审记录(进入编码前)
-
-### 复审范围
-
-- `proposal.md`
-- `design.md`
-- `specs/openai-oauth-performance/spec.md`
-- `tasks.md`
-
----
-
-## 第 1 轮复审:结构完整性审查
-
-**检查项**
-
-- OpenSpec `spec-driven` 四件套是否齐全
-- proposal 章节是否完整(Why/What Changes/Capabilities/Impact)
-- capability 与 spec 文件路径是否一一对应
-- tasks 是否符合 `- [ ] X.Y` 可追踪格式
-
-**结果**
-
-- 通过(结构完整)
-- `openspec validate optimize-openai-oauth-performance --strict` 校验通过
-
-**发现与处理**
-
-- 无阻塞问题
-
----
-
-## 第 2 轮复审:一致性与可测性审查
-
-**检查项**
-
-- proposal → design → specs → tasks 的语义链路是否一致
-- specs 是否全部为可测试要求(Requirement + Scenario)
-- tasks 是否覆盖 specs 要求
-
-**结果**
-
-- 基本通过(存在可执行门禁不够显式的问题)
-
-**发现与处理**
-
-1. 问题:进入编码前的“签核门禁”未在任务中显式固化,容易直接跳过。
- 处理:在 `tasks.md` 新增 **0. 审核签核门禁** 分组(0.1~0.3)。
-
-2. 问题:design 中虽有 open questions,但缺少“未签核不得编码”的明确约束。
- 处理:在 `design.md` 新增 **审核门禁(Coding Gate)** 章节。
-
----
-
-## 第 3 轮复审:落地与风险门禁审查
-
-**检查项**
-
-- 灰度与回滚路径是否明确
-- 风险项是否有对应缓解
-- 是否具备“审核通过后再编码”的可执行条件
-
-**结果**
-
-- 条件通过(需业务/研发负责人完成 3 项签核)
-
-**待签核项(阻塞编码)**
-
-- [x] A. 基线签核:确认基线窗口、压测场景、流量模型、样本数据集(见 signoff-and-rollout.md)
-- [x] B. 目标签核:确认性能阈值(P95/P99、TTFT、错误率、CPU/内存、Redis RT)(见 signoff-and-rollout.md)
-- [x] C. 发布签核:确认灰度比例、监控指标、回滚触发阈值(见 signoff-and-rollout.md)
-
----
-
-## 复审结论
-
-- 当前提案质量:**可执行,且满足 OpenSpec 规范**
-- 结论:**条件通过(待 A/B/C 三项签核完成)**
-- 建议:签核完成后,按 `tasks.md` 从 `1.1` 开始进入编码
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/signoff-and-rollout.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/signoff-and-rollout.md
deleted file mode 100644
index b49c8f3c5..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/signoff-and-rollout.md
+++ /dev/null
@@ -1,104 +0,0 @@
-# OpenAI OAuth 性能优化签核与灰度发布手册
-
-> 变更:`optimize-openai-oauth-performance`
-> 更新时间:`2026-02-12`
-
-## 0. 审核签核门禁(0.1 / 0.2 / 0.3)
-
-### 0.1 基线窗口与压测场景冻结(已确认)
-
-- 基线窗口:`2026-02-12 09:00` ~ `2026-02-12 12:00`(本地/预发同口径)
-- 压测脚本:`tools/perf/openai_oauth_responses_k6.js`
-- 压测模型:
- - 非流式:`NON_STREAM_RPS=8`
- - 流式:`STREAM_RPS=4`
- - 时长:`DURATION=3m`
-- 样本请求:`/v1/responses`(Codex CLI UA,短文本 + stream/非 stream)
-- 报告模板:`docs/perf/openai-oauth-baseline-template.md`
-
-### 0.2 性能目标阈值冻结(已确认)
-
-- SLA 下限:`99.5%`
-- TTFT P99 上限:`900ms`
-- 请求错误率上限:`2%`
-- 上游错误率上限:`2%`
-
-接口固化路径:
-
-- `GET /api/v1/admin/ops/settings/metric-thresholds`
-- `PUT /api/v1/admin/ops/settings/metric-thresholds`
-
-示例阈值文件:`docs/perf/openai-oauth-metric-thresholds.example.json`
-
-### 0.3 灰度策略与回滚阈值冻结(已确认)
-
-- 灰度批次:`5% -> 20% -> 50% -> 100%`
-- 每批观察窗口:`15~30 分钟`
-- 回滚触发(任一满足即回滚):
- - `TTFT P99 > 1200ms` 持续 `3 分钟`
- - `请求错误率 > 5%` 持续 `3 分钟`
- - `上游错误率 > 5%` 持续 `3 分钟`
-
----
-
-## 6. 灰度发布与验收(6.1 / 6.2 / 6.3)
-
-### 6.1 批次灰度执行记录模板
-
-| 批次 | 流量比例 | 开始时间 | 结束时间 | 结果 | 备注 |
-|---|---:|---|---|---|---|
-| A | 5% | | | | |
-| B | 20% | | | | |
-| C | 50% | | | | |
-| D | 100% | | | | |
-
-每批必填观察项:
-
-- `duration.p99_ms`
-- `ttft.p99_ms`
-- `error_rate`
-- `upstream_error_rate`
-- CPU / 内存 / Redis RT
-
-### 6.2 阈值守护与快速回滚操作
-
-#### 自动守护脚本
-
-```bash
-python tools/perf/openai_oauth_gray_guard.py \
- --base-url http://127.0.0.1:5231 \
- --admin-token
\
- --platform openai \
- --time-range 30m
-```
-
-- 返回 `0`:指标通过,可继续观察/扩量
-- 返回 `2`:超阈值,立即停止扩量并回滚
-
-#### 建议回滚步骤
-
-1. 停止当前批次扩量;冻结发布。
-2. 回退到上一批稳定比例(或直接回到 0%)。
-3. 按 `request_id/account_id` 抽样最近 30 分钟失败请求。
-4. 导出 `ops dashboard overview + error trend + upstream errors`。
-5. 在复盘会中确认根因后再重启灰度。
-
-### 6.3 最终验收报告输出要求
-
-最终报告使用:`openspec/changes/optimize-openai-oauth-performance/final-acceptance-report.md`
-
-必含内容:
-
-- 优化前后对比(P50/P95/P99、TTFT、错误率、CPU/内存、Redis RT)
-- 各批次灰度收益与风险记录
-- 回滚演练结果(成功/失败、耗时)
-- 最终结论(是否关闭变更)
-
-
----
-
-## 附:演练产物
-
-- 灰度守护演练报告:`docs/perf/openai-oauth-gray-drill-report.md`
-- 演练脚本:`tools/perf/openai_oauth_gray_drill.py`
-- 守护脚本:`tools/perf/openai_oauth_gray_guard.py`
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/specs/openai-oauth-performance/spec.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/specs/openai-oauth-performance/spec.md
deleted file mode 100644
index e47fd9c5b..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/specs/openai-oauth-performance/spec.md
+++ /dev/null
@@ -1,43 +0,0 @@
-## ADDED Requirements
-
-### Requirement: OpenAI OAuth 链路性能目标可量化
-系统 MUST 为 OpenAI OAuth `/v1/responses` 链路定义并维护可量化的性能目标,至少覆盖网关附加延迟、TTFT、P95/P99、错误率与资源开销基线,并以统一口径输出对比结果。
-
-#### Scenario: 发布前具备性能基线与目标对比
-- **WHEN** 团队发起 OpenAI OAuth 性能优化发布评审
-- **THEN** 评审材料 MUST 包含优化前后同口径压测结果与目标达成情况
-
-### Requirement: 请求热路径避免重复解析与不必要拷贝
-系统 SHALL 在 OpenAI OAuth 请求处理热路径中避免对同一请求体进行重复解析与不必要数据拷贝,保证常态请求不引入额外的可避免 CPU/内存开销。
-
-#### Scenario: 常态请求路径不发生多次完整解析
-- **WHEN** 网关处理一个合法的 OpenAI OAuth 非异常请求
-- **THEN** 热路径实现 SHALL 不重复执行可避免的全量 JSON 解析与大对象拷贝
-
-### Requirement: 并发控制快速路径最小化额外存储往返
-系统 SHALL 对并发控制采用快速路径策略:在可直接获得并发槽位时,不执行不必要的等待队列写入,并最小化常态请求的额外 Redis 往返。
-
-#### Scenario: 可立即获得槽位时跳过等待队列写入
-- **WHEN** 请求到达且用户与账号并发槽位均可立即获取
-- **THEN** 系统 SHALL 直接进入上游转发路径而不执行等待队列计数写入
-
-### Requirement: 流式转发热路径降低逐行处理成本
-系统 MUST 优化 OpenAI OAuth SSE 流式转发热路径,降低逐行处理中的高频字符串与 JSON 操作成本,并保持与 OpenAI Responses 流式协议兼容。
-
-#### Scenario: 流式协议兼容且处理开销降低
-- **WHEN** 客户端发起 OpenAI OAuth 流式请求并持续接收事件
-- **THEN** 系统 SHALL 保持事件语义兼容,同时逐行处理不应依赖可替代的高开销通用解析手段
-
-### Requirement: Token 竞争路径控制尾延迟放大
-系统 SHALL 在 OpenAI OAuth token 获取的锁竞争场景中采用低抖动等待策略,避免固定大步长等待导致的尾延迟放大。
-
-#### Scenario: 锁竞争下请求不出现固定等待台阶
-- **WHEN** 多个并发请求同时命中同一 OAuth 账号的 token 刷新竞争
-- **THEN** 请求等待策略 SHALL 使用短周期可回退机制,并避免固定长等待造成显著延迟台阶
-
-### Requirement: 优化发布必须具备可灰度与可回滚保障
-系统 MUST 为 OpenAI OAuth 性能优化提供灰度发布、关键指标监控与回滚策略,确保在异常时可快速恢复到稳定状态。
-
-#### Scenario: 灰度阶段触发阈值时可快速回滚
-- **WHEN** 灰度期间关键指标(错误率或 P99)超出预设阈值
-- **THEN** 运行策略 MUST 支持按批次或按开关回滚,并恢复至优化前稳定行为
diff --git a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/tasks.md b/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/tasks.md
deleted file mode 100644
index 3a16df051..000000000
--- a/openspec/changes/archive/2026-02-12-optimize-openai-oauth-performance/tasks.md
+++ /dev/null
@@ -1,41 +0,0 @@
-## 0. 审核签核门禁
-
-- [x] 0.1 确认并冻结基线窗口、压测场景、流量模型与样本数据集
-- [x] 0.2 确认并签核性能目标阈值(P95/P99、TTFT、错误率、CPU/内存、Redis RT)
-- [x] 0.3 确认灰度策略(分批比例)与回滚触发阈值(错误率/延迟)
-
-## 1. 基线与可观测性
-
-- [x] 1.1 为 OpenAI OAuth `/v1/responses` 链路补齐阶段耗时指标(认证、调度、token、上游首包、SSE 处理)
-- [x] 1.2 建立统一压测脚本与报告模板,输出优化前基线(P50/P95/P99、TTFT、错误率、CPU、内存、Redis RT)
-- [x] 1.3 在监控中新增性能验收看板与告警阈值(重点关注 P99 与错误率)
-
-## 2. 请求热路径优化
-
-- [x] 2.1 收敛 handler/service 中对同一请求体的重复解析,移除可避免的全量 JSON 解析
-- [x] 2.2 减少热路径中不必要的大对象拷贝与字符串转换(含 ops 上下文字段)
-- [x] 2.3 为热路径优化补充回归测试,确保请求校验与兼容行为不变
-
-## 3. 并发与调度路径优化
-
-- [x] 3.1 调整并发流程为“先抢槽再排队”,使可立即获取槽位请求跳过等待队列写入
-- [x] 3.2 将槽位释放守护逻辑改为轻量机制(如 context 生命周期回调),移除请求级守护 goroutine
-- [x] 3.3 增加并发一致性验证(槽位计数、等待计数、释放成功率)与对应测试
-
-## 4. 流式转发热路径优化
-
-- [x] 4.1 将 SSE 行识别从正则改为低开销前缀/状态机方式
-- [x] 4.2 将 usage 解析改为选择性解析(仅关键事件),降低每行 JSON 处理成本
-- [x] 4.3 增加流式协议兼容测试(事件顺序、DONE 语义、异常事件)并完成样本回放验证
-
-## 5. Token 竞争路径优化
-
-- [x] 5.1 将 token 锁竞争等待从固定长等待改为短轮询+jitter 策略
-- [x] 5.2 增加高并发 token 刷新竞争测试,验证无显著固定延迟台阶
-- [x] 5.3 增加 token 刷新保护指标(刷新 QPS、锁等待分布、失败率)并配置告警
-
-## 6. 灰度发布与验收
-
-- [x] 6.1 按批次灰度发布(并发路径、流式路径、token 路径、热路径微优化)并记录每批收益
-- [x] 6.2 在灰度期间执行阈值守护,触发条件时验证可按批次快速回滚
-- [x] 6.3 完成最终验收报告,确认性能目标达成并关闭本变更
diff --git "a/openspec/changes/archive/2026-02-25-\303\247/.openspec.yaml" "b/openspec/changes/archive/2026-02-25-\303\247/.openspec.yaml"
deleted file mode 100644
index e331c975d..000000000
--- "a/openspec/changes/archive/2026-02-25-\303\247/.openspec.yaml"
+++ /dev/null
@@ -1,2 +0,0 @@
-schema: spec-driven
-created: 2026-02-25
diff --git "a/openspec/changes/archive/2026-02-25-\303\247/design.md" "b/openspec/changes/archive/2026-02-25-\303\247/design.md"
deleted file mode 100644
index 52b6bbc7d..000000000
--- "a/openspec/changes/archive/2026-02-25-\303\247/design.md"
+++ /dev/null
@@ -1,165 +0,0 @@
-## Context
-
-OpenAI WS v2 当前实现已经具备功能完整性,但在“长会话 + 大 payload + 高频事件 + 失败重试”组合场景下出现以下结构性性能瓶颈:
-
-- 热路径重复序列化与日志重计算(forwarder)。
-- 连接池 `Acquire` 的排序与计时器分配成本(pool/client)。
-- 重试分类与节奏控制不足导致的放大效应(gateway)。
-
-这些问题叠加后,会在失败场景产生额外建连、额外日志、额外 payload 处理,最终拖慢成功请求并放大 P99 抖动。
-
-## Goals / Non-Goals
-
-**Goals**
-
-- 让 WS v2 请求在热路径上做到“低分配、低复制、低系统调用”。
-- 让连接池在高并发下保持稳定复杂度和高复用。
-- 让失败处理从“重试放大”转为“收敛控制”,避免重连风暴。
-- 提供可量化的性能验收与回归门禁。
-
-**Non-Goals**
-
-- 不改变对外 API 协议与客户端调用方式。
-- 不引入新基础设施依赖。
-- 不改变“WS 默认开启”既有产品策略。
-
-## Decisions
-
-### 决策 1:WS payload 构建改为单次序列化快照
-
-- 现状问题:`payloadAsJSON` 在同一请求内多次执行,且日志提取依赖反复解析字符串。
-- 方案:在 `forwardOpenAIWSV2` 中引入一次性 `payloadSnapshot`:
- - 同时持有 `[]byte`(发送)和 `map`(按需修改)。
- - `previous_response_id`、`prompt_cache_key`、`event.type` 等字段直接从 map 读取,避免 `gjson.Get(string)`。
- - `setOpenAIWSTurnMetadata` 后仅在 payload 实际变化时重编码。
-- 收益:减少 JSON 编码/解码、降低字符串分配。
-
-### 决策 2:日志预算化,重统计改为采样+上限
-
-- 现状问题:`summarizeOpenAIWSPayloadKeySizes` 对每个字段 `json.Marshal`,在大 `tools` 下昂贵。
-- 方案:
- - 新增采样开关和采样率(默认低频)。
- - `payload_key_sizes` 改为“估算+截断”策略,不对超大字段做完整 marshal。
- - 事件日志保留关键里程碑(connect/write/read_fail/terminal),常规 token 事件按采样记录。
-- 收益:显著降低日志计算 CPU 与日志 I/O。
-
-### 决策 3:事件流处理改为“字节优先”,减少字符串往返
-
-- 现状问题:事件循环中 `string(message)` 多次重复,且每事件立即 flush。
-- 方案:
- - `toolCorrector` 增加 `[]byte` 版本入口,避免频繁字节转字符串再转回字节。
- - usage 解析改为按事件类型 gating(仅 `response.completed` 进入字段解析)。
- - 流式写出支持轻量批量 flush(例如 token 事件 N 条或 T 毫秒一刷),终态事件强制 flush。
-- 收益:减少分配和 syscall 次数,改善吞吐与 P99。
-
-### 决策 4:连接池选择从全量排序改为低复杂度结构
-
-- 现状问题:`Acquire` 每次排序连接数组,账号连接数上升后代价明显。
-- 方案:
- - 引入“最小等待者优先”的增量结构(小根堆或有界桶),避免每次全量 `sort`。
- - 保留 `preferred_conn_id` 快速路径,命中时 O(1)。
- - 定时惰性重平衡,避免每次 `Acquire` 触发重排。
-- 收益:连接选择成本从 O(n log n) 降到近 O(log n)/O(1)。
-
-### 决策 5:读写超时上下文复用,削减计时器分配
-
-- 现状问题:每次 read/write/ping 都 `context.WithTimeout`,热点场景分配多。
-- 方案:
- - 请求级创建一次父 deadline,上下游读写复用。
- - 连接级 read/write 在必要时复用 timer(或使用统一 deadline context)。
- - 仅 ping 健康检查保留独立短超时。
-- 收益:减少 timer 与 context 对象分配,降低 GC 压力。
-
-### 决策 6:代理建连复用 HTTP Transport
-
-- 现状问题:代理路径按 dial 动态 new `http.Client/Transport`,连接池无法复用。
-- 方案:
- - 建立 `proxyURL -> *http.Client/*http.Transport` 复用缓存(带 LRU/TTL)。
- - 设置合理的 `MaxIdleConns`, `MaxIdleConnsPerHost`, `IdleConnTimeout`。
-- 收益:减少 TLS 握手和短连接抖动,提升建连效率。
-
-### 决策 7:重试策略改为“可重试白名单 + 指数退避 + jitter”
-
-- 现状问题:失败重试无退避,策略类失败(如 1008)仍重复尝试。
-- 方案:
- - 非重试错误(策略违规、鉴权类、参数类)直接 HTTP fallback,不再重复 WS。
- - 可重试错误才进入指数退避:`base * 2^n + jitter`,设置最大上限。
- - 引入账号级短路熔断:连续失败达到阈值后在冷却窗口内直走 HTTP。
-- 收益:抑制重连风暴,降低失败路径资源放大。
-
-### 决策 8:预热触发去抖,防止后台建连风暴
-
-- 现状问题:`ensureTargetIdleAsync` 触发频繁,峰值时可能造成额外 prewarm 压力。
-- 方案:
- - 增加账号级 prewarm cooldown(毫秒级)。
- - 当最近失败率高于阈值时暂停预热,优先维持现有连接健康。
- - `targetConnCount` 采用 EWMA 负载而非瞬时 waiters 峰值。
-- 收益:降低无效预热和建连抖动。
-
-## Architecture Changes
-
-### 模块改造
-
-1. `openai_ws_forwarder.go`
-- 引入 payload 快照结构与字节优先处理链。
-- 日志采样、payload 大字段预算和流式 flush 策略。
-
-2. `openai_ws_pool.go`
-- `Acquire` 连接选择器改为增量结构。
-- `ensureTargetIdleAsync` 增加触发去抖与失败保护。
-
-3. `openai_ws_client.go`
-- 代理 HTTP client/transport 复用池。
-- 建连参数支持 keep-alive 与空闲连接上限。
-
-4. `openai_gateway_service.go`
-- 重试分类细化,加入 backoff+jitter 与熔断冷却。
-- `1008` 等策略类失败快速回退。
-
-5. `config.go`
-- 新增性能治理配置项:
- - `gateway.openai_ws.retry_backoff_initial_ms`
- - `gateway.openai_ws.retry_backoff_max_ms`
- - `gateway.openai_ws.retry_jitter_ratio`
- - `gateway.openai_ws.non_retryable_close_statuses`
- - `gateway.openai_ws.payload_log_sample_rate`
- - `gateway.openai_ws.prewarm_cooldown_ms`
- - `gateway.openai_ws.event_flush_batch_size`
- - `gateway.openai_ws.event_flush_interval_ms`
-
-## Observability & Benchmarks
-
-- 指标新增:
- - `openai_ws_payload_analyze_ms`
- - `openai_ws_retry_attempts`
- - `openai_ws_backoff_ms`
- - `openai_ws_conn_pick_ms`
- - `openai_ws_transport_reuse_ratio`
- - `openai_ws_non_retryable_fast_fallback_total`
-- 压测维度:
- - 短请求(小 payload)/长请求(大 tools)/高失败率注入(1008/5xx/timeout)
- - 流式与非流式分开对比
- - 单账号热点与多账号均衡场景
-
-## Migration Plan
-
-1. 阶段 A(安全收益优先)
-- 上线重试分类 + backoff + 1008 快速 fallback + 基础日志采样。
-
-2. 阶段 B(热路径减负)
-- 上线 payload 单次序列化、字节优先处理、usage 选择性解析。
-
-3. 阶段 C(连接与建连优化)
-- 上线连接选择器优化、prewarm 去抖、代理 transport 复用。
-
-4. 阶段 D(门禁与固化)
-- 固化压测基线到回归门禁,未达标不得发布。
-
-## Rollback Strategy
-
-- 任一阶段都可通过独立开关回退到旧逻辑:
- - `openai_ws_retry_policy_v2_enabled`
- - `openai_ws_fast_payload_path_enabled`
- - `openai_ws_pool_picker_v2_enabled`
- - `openai_ws_transport_cache_enabled`
-- 指标越界(错误率、P99、fallback rate)时自动触发回退。
diff --git "a/openspec/changes/archive/2026-02-25-\303\247/proposal.md" "b/openspec/changes/archive/2026-02-25-\303\247/proposal.md"
deleted file mode 100644
index 4e71071d4..000000000
--- "a/openspec/changes/archive/2026-02-25-\303\247/proposal.md"
+++ /dev/null
@@ -1,85 +0,0 @@
-## Why
-
-当前 OpenAI Responses WebSocket v2 已进入主路径,但从线上日志和代码热路径看,仍存在可观的性能浪费与放大效应:在高并发、长会话、失败重试场景下,网关附加延迟、CPU/GC 压力和无效重试次数偏高。
-
-本提案目标是“修复所有已识别的 WS v2 性能问题”,并形成可量化、可灰度、可回滚的闭环。
-
-### 三轮分析结论(汇总)
-
-1. 热路径 CPU/分配开销偏高(forwarder)
-- `backend/internal/service/openai_ws_forwarder.go:668`、`backend/internal/service/openai_ws_forwarder.go:703`:同一次请求对 payload 重复序列化,且基于字符串再次提取字段。
-- `backend/internal/service/openai_ws_forwarder.go:177`:日志统计 `payload_key_sizes` 对每个字段执行 `json.Marshal`,在大 `tools/input` 场景放大 CPU。
-- `backend/internal/service/openai_ws_forwarder.go:895`、`backend/internal/service/openai_ws_forwarder.go:998`、`backend/internal/service/openai_ws_forwarder.go:1001`:流事件循环里存在 `[]byte <-> string` 频繁转换、逐事件 flush 与 usage 解析,导致分配和系统调用开销上升。
-
-2. 连接池与客户端路径存在调度/建连额外成本
-- `backend/internal/service/openai_ws_pool.go:527`、`backend/internal/service/openai_ws_pool.go:719`、`backend/internal/service/openai_ws_pool.go:727`:`Acquire` 频繁全量排序连接(O(n log n)),账号连接数增大后成本上升。
-- `backend/internal/service/openai_ws_pool.go:267`、`backend/internal/service/openai_ws_pool.go:295`:每次读写都创建 `context.WithTimeout`,计时器对象和取消函数在热点下产生分配压力。
-- `backend/internal/service/openai_ws_client.go:56`、`backend/internal/service/openai_ws_client.go:57`:按请求新建 `http.Client/Transport`(代理场景),连接复用能力弱,握手和 TLS 成本高。
-
-3. 重试与降级策略产生失败放大效应
-- `backend/internal/service/openai_gateway_service.go:1350`、`backend/internal/service/openai_gateway_service.go:1375`:重试循环未引入指数退避+jitter,失败时容易形成重连风暴。
-- `backend/internal/service/openai_gateway_service.go:337`、`backend/internal/service/openai_gateway_service.go:350`:重试分类偏粗,`1008` 等策略类失败仍会被重复尝试,导致“无效重试 + payload 裁剪 + 重日志”叠加放大。
-
-## What Changes
-
-- 建立 WS v2 性能修复三层方案:
- - 热路径优化:单次序列化、低开销字段提取、日志预算化、事件写出批量策略。
- - 连接与调度优化:连接选择从全量排序改为低复杂度策略,代理建连复用 transport,减少热点计时器分配。
- - 失败控制优化:非重试错误快速降级,重试路径引入指数退避+jitter+熔断冷却,抑制失败放大。
-- 补齐专项性能观测与压测验收:
- - `网关附加延迟 / TTFT / P95/P99 / CPU / allocs / WS 复用率 / 重试分布 / fallback rate` 全量纳入。
-- 明确发布策略:
- - 保持对外 API 不变,按账号灰度,阈值触发一键回退 HTTP。
-
-## Performance Targets
-
-- WSv2 流式请求网关附加延迟:P95 降低 >= 25%,P99 降低 >= 20%(相对基线)。
-- WSv2 热路径 CPU 时间:每千请求 CPU 降低 >= 20%。
-- WSv2 热路径内存分配:`allocs/op` 降低 >= 30%,`B/op` 降低 >= 25%。
-- 单请求平均 WS 尝试次数:<= 1.2;`retry_exhausted` 比例 <= 0.5%。
-- 连接池复用率:>= 75%;同账号建连速率峰值较基线下降 >= 30%。
-- 失败放大抑制:`close_status=1008` 场景不超过 1 次 WS 尝试后必须进入 HTTP 回退。
-
-## Scope / Constraints
-
-- 保持外部接口与协议兼容:客户端仍走 `POST /v1/responses`。
-- 本提案不引入新的外部基础设施(如新增 MQ)。
-- 保持“OpenAI Responses WebSocket 默认开启”策略,不在本提案中调整默认开关语义。
-
-## Capabilities
-
-### New Capabilities
-
-- `openai-ws-v2-performance`: 定义并约束 OpenAI Responses WebSocket v2 的性能目标、失败控制策略、连接调度策略与验收标准。
-
-### Modified Capabilities
-
-- (无)
-
-## Impact
-
-- Backend
- - `backend/internal/service/openai_ws_forwarder.go`
- - `backend/internal/service/openai_ws_pool.go`
- - `backend/internal/service/openai_ws_client.go`
- - `backend/internal/service/openai_gateway_service.go`
- - `backend/internal/service/openai_ws_state_store.go`
- - `backend/internal/config/config.go`
-- Tests / Perf
- - `backend/internal/service/*_test.go`(forwarder/pool/retry/fallback)
- - `tools/perf/*`(新增 WSv2 基线与回归脚本)
-- Ops
- - 监控与告警面板新增 WSv2 性能指标与重试分布。
-
-## Risks
-
-- 低开销日志与选择性解析改造可能引入可观测性盲区。
-- 重试收敛后,短时成功率可能下降但整体延迟与资源效率提升。
-- 连接池选择策略改造若实现不当可能导致局部热点。
-
-## Rollout
-
-1. 基线采样:先冻结当前指标与流量模型。
-2. 小流量灰度:按账号 allowlist 分批启用优化开关。
-3. 阈值守护:任何一项关键指标越界立即回退。
-4. 全量发布:达成验收指标后全量启用。
diff --git "a/openspec/changes/archive/2026-02-25-\303\247/specs/openai-ws-v2-performance/spec.md" "b/openspec/changes/archive/2026-02-25-\303\247/specs/openai-ws-v2-performance/spec.md"
deleted file mode 100644
index f4ebf6e98..000000000
--- "a/openspec/changes/archive/2026-02-25-\303\247/specs/openai-ws-v2-performance/spec.md"
+++ /dev/null
@@ -1,98 +0,0 @@
-## ADDED Requirements
-
-### Requirement: WSv2 转发热路径必须避免重复序列化与重复字段解析
-系统 MUST 在单次 WSv2 请求处理过程中避免可消除的 payload 重复序列化、重复字符串解析与重复大对象拷贝。
-
-#### Scenario: 单请求仅进行必要序列化
-- **WHEN** 网关处理一次合法的 OpenAI WSv2 请求
-- **THEN** payload 编码与字段提取 SHALL 采用单次快照策略
-- **AND** 系统 MUST NOT 在同一请求中重复执行可避免的全量 JSON 编码
-
-### Requirement: WSv2 日志必须受预算与采样控制
-系统 MUST 对 WSv2 热路径日志与 payload 统计执行预算控制,避免日志计算放大主流程开销。
-
-#### Scenario: 大 payload 场景日志成本受控
-- **WHEN** 请求包含大型 `tools` 或 `input` 字段
-- **THEN** 系统 SHALL 使用采样和截断策略记录诊断信息
-- **AND** 系统 MUST NOT 对所有字段每次都执行高开销序列化统计
-
-### Requirement: WS 事件循环必须最小化字节与字符串往返转换
-系统 SHALL 在 WS 事件处理循环中优先使用字节路径,降低 `[]byte <-> string` 的频繁转换成本。
-
-#### Scenario: 高频 token 事件下保持低分配
-- **WHEN** 流式请求持续输出高频 token 事件
-- **THEN** 事件处理路径 MUST 使用字节优先处理与选择性解析
-- **AND** 在不影响协议语义前提下 MUST 减少每事件的临时对象分配
-
-### Requirement: 连接池获取路径必须使用低复杂度连接选择策略
-系统 MUST 为账号连接池提供低复杂度连接选择机制,避免在每次 `Acquire` 上执行全量排序。
-
-#### Scenario: 账号连接数增加时获取开销受控
-- **WHEN** 同一账号连接池中连接数量上升
-- **THEN** `Acquire` 延迟 SHALL 维持稳定并接近 O(1)/O(log n) 复杂度
-- **AND** `preferred_conn_id` 命中时 MUST 走快速路径
-
-### Requirement: 代理建连必须复用 HTTP 传输资源
-系统 MUST 复用代理建连使用的 HTTP client/transport,避免按请求重复创建传输对象。
-
-#### Scenario: 同代理地址连续建连
-- **WHEN** 同一 `proxyURL` 在短时间内多次用于 WS 建连
-- **THEN** 系统 SHALL 复用同一传输资源池
-- **AND** 握手延迟与建连 CPU 开销 MUST 低于未复用基线
-
-### Requirement: WS 重试策略必须具备分类、退避与熔断能力
-系统 MUST 将 WS 失败分为可重试与不可重试两类,并对可重试路径应用退避与抖动策略。
-
-#### Scenario: 策略类失败快速回退
-- **WHEN** 上游返回策略违规类关闭状态(例如 `1008`)
-- **THEN** 系统 MUST 在一次失败后快速回退到 HTTP
-- **AND** 系统 MUST NOT 连续进行多次无效 WS 重试
-
-#### Scenario: 可重试失败执行退避
-- **WHEN** 发生可重试的瞬时错误(如网络抖动、上游 5xx)
-- **THEN** 系统 SHALL 使用指数退避并附加 jitter 控制重试节奏
-- **AND** 重试次数与等待时长 MUST 受配置上限约束
-
-### Requirement: 预热与扩容策略必须防抖并避免建连风暴
-系统 SHALL 对连接预热和扩容触发执行防抖控制,避免瞬时负载波动触发过量后台建连。
-
-#### Scenario: 高频 Acquire 下预热触发受控
-- **WHEN** 同账号在短窗口内出现大量 Acquire 调用
-- **THEN** 系统 MUST 保证同一账号预热线程/任务数量有界
-- **AND** 预热触发 MUST 受 cooldown 与失败率门限控制
-
-### Requirement: WSv2 性能优化不得改变“默认开启”产品策略
-系统 MUST 在性能优化实施后保持 OpenAI Responses WebSocket 的默认开启策略不变,不得通过性能提案将默认行为回退为关闭。
-
-#### Scenario: 配置默认值保持开启
-- **WHEN** 系统加载默认网关配置
-- **THEN** `gateway.openai_ws.enabled` MUST 保持为 `true`
-- **AND** 性能优化开关 MUST 只影响实现细节,不改变 WS 默认启用语义
-
-### Requirement: WSv2 性能优化发布必须满足量化验收与回滚保障
-系统 MUST 在 WSv2 性能优化上线前后提供统一口径基线对比,并具备阈值触发回滚能力。
-
-#### Scenario: 发布验收材料完整
-- **WHEN** 团队评审 WSv2 性能优化发布
-- **THEN** 材料 MUST 包含 `TTFT`、`P95/P99`、`CPU`、`allocs/op`、`retry_attempts`、`fallback_rate` 的前后对比
-
-### Requirement: WSv2 性能优化必须达到明确阈值
-系统 MUST 基于统一压测口径达到本提案定义的性能阈值,未达标不得全量发布。
-
-#### Scenario: 延迟与资源阈值达标
-- **WHEN** 在统一基线环境完成 WSv2 优化回归压测
-- **THEN** 网关附加延迟 `P95` MUST 至少降低 25%
-- **AND** 网关附加延迟 `P99` MUST 至少降低 20%
-- **AND** 热路径 `allocs/op` MUST 至少降低 30%
-- **AND** 热路径 `B/op` MUST 至少降低 25%
-
-#### Scenario: 重试与连接复用阈值达标
-- **WHEN** 在统一基线环境完成失败注入与稳态压测
-- **THEN** 单请求平均 `retry_attempts` MUST 小于等于 1.2
-- **AND** `retry_exhausted` 比例 MUST 小于等于 0.5%
-- **AND** 连接池复用率 MUST 大于等于 75%
-
-#### Scenario: 指标越界可快速回滚
-- **WHEN** 灰度阶段关键指标超出预设阈值
-- **THEN** 系统 MUST 支持按开关快速回滚到稳定路径
-- **AND** 回滚后行为 MUST 与回滚前基线兼容
diff --git "a/openspec/changes/archive/2026-02-25-\303\247/tasks.md" "b/openspec/changes/archive/2026-02-25-\303\247/tasks.md"
deleted file mode 100644
index 3d6e943b0..000000000
--- "a/openspec/changes/archive/2026-02-25-\303\247/tasks.md"
+++ /dev/null
@@ -1,50 +0,0 @@
-## 0. 签核门禁
-
-- [ ] 0.1 冻结 WSv2 当前基线(P50/P95/P99、TTFT、CPU、allocs、重试分布、fallback rate)
-- [ ] 0.2 确认并签字性能目标阈值与回滚阈值
-- [ ] 0.3 确认灰度账号清单与分批比例
-- [x] 0.4 执行 `openspec validate ç --strict` 并留档
-
-## 1. 热路径优化(forwarder)
-
-- [x] 1.1 重构 `forwardOpenAIWSV2` 为单次 payload 序列化快照,移除重复 `payloadAsJSON`
-- [x] 1.2 `previous_response_id/prompt_cache_key/type` 改为 map 直接读取,减少 `gjson.Get(string)`
-- [x] 1.3 `summarizeOpenAIWSPayloadKeySizes` 改为采样+预算模式,避免全字段 marshal
-- [x] 1.4 引入 `toolCorrector` 字节接口,减少 `[]byte <-> string` 转换
-- [x] 1.5 流式写出增加轻量批量 flush 策略(终态事件强制 flush)
-- [x] 1.6 增加基准测试:`BenchmarkOpenAIWSForwarderHotPath`
-
-## 2. 连接池与客户端优化(pool/client)
-
-- [x] 2.1 将连接选择从全量排序改为低复杂度增量结构(堆或等效策略)
-- [x] 2.2 `Acquire` 路径保留 `preferred_conn_id` O(1) 快速命中
-- [x] 2.3 引入 read/write timeout 上下文复用,减少热点 `WithTimeout` 分配
-- [x] 2.4 代理建连改造为 transport/client 复用缓存(含 TTL/LRU)
-- [x] 2.5 `ensureTargetIdleAsync` 增加账号级 cooldown 与失败率抑制
-- [x] 2.6 增加基准测试:`BenchmarkOpenAIWSPoolAcquire`
-
-## 3. 重试与降级优化(gateway)
-
-- [x] 3.1 完善重试分类:策略类/鉴权类/参数类失败标记为不可重试
-- [x] 3.2 对可重试错误引入指数退避+jitter(带最大上限)
-- [x] 3.3 对 `close_status=1008` 路径改为单次尝试后快速 HTTP fallback
-- [x] 3.4 增加账号级熔断冷却窗口,避免失败风暴期间反复打 WS
-- [x] 3.5 增加重试策略单测与故障注入测试
-
-## 4. 可观测性与压测
-
-- [x] 4.1 增加 WSv2 专项指标:`conn_pick_ms`、`retry_attempts`、`backoff_ms`、`transport_reuse_ratio`
-- [x] 4.2 增加日志采样配置与运行时校验,防止日志放大
-- [x] 4.3 补充压测脚本(短请求/长请求/错误注入/热点账号)
-- [ ] 4.4 输出优化前后对比报告并纳入发布评审材料
-- [ ] 4.5 按统一口径校验阈值:`P95`/`P99`/`allocs-op`/`B-op`/`retry_attempts`/`retry_exhausted`/`reuse_ratio`
-
-## 5. 回归与发布
-
-- [x] 5.1 HTTP/SSE 路径回归,确保无行为退化
-- [x] 5.2 WSv2 流式协议兼容回归(事件顺序、DONE、usage)
-- [ ] 5.3 按账号灰度发布并持续观测 24h
-- [ ] 5.4 达成阈值后全量,否则按开关回滚并复盘
-- [ ] 5.5 进行一次“阈值越界自动回滚”演练并记录结果
-- [x] 5.6 同步更新 `deploy/config.example.yaml` 与运行手册中的新配置项说明
-- [ ] 5.7 输出最终验收报告(含指标、风险、回滚演练结果)并归档到 `openspec/changes/.../final-acceptance-report.md`
diff --git "a/openspec/changes/archive/2026-02-25-\303\247/validation.md" "b/openspec/changes/archive/2026-02-25-\303\247/validation.md"
deleted file mode 100644
index 44ba46fa1..000000000
--- "a/openspec/changes/archive/2026-02-25-\303\247/validation.md"
+++ /dev/null
@@ -1,5 +0,0 @@
-# OpenSpec 校验留档
-
-- 时间:2026-02-25 20:23:32 +0800
-- 命令:`openspec validate ç --strict`
-- 结果:`Change 'ç' is valid`
diff --git a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/.openspec.yaml b/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/.openspec.yaml
deleted file mode 100644
index 85ae75c1f..000000000
--- a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/.openspec.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-schema: spec-driven
-created: 2026-02-26
diff --git a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/design.md b/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/design.md
deleted file mode 100644
index 1d3dde02b..000000000
--- a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/design.md
+++ /dev/null
@@ -1,174 +0,0 @@
-## Context
-
-当前代码中“类型”由 `stream` 和 `openai_ws_mode` 组合推导,且筛选只支持 `stream`。
-这导致类型语义分散在数据库、后端 DTO、前端展示、导出逻辑中,扩展和维护成本高。
-
-本设计目标是在不破坏现有接口的前提下,引入 `request_type` 作为主枚举字段,实现:
-
-- 单一事实源
-- 可扩展类型体系
-- 向前兼容升级
-- 可回滚
-
-## Goals / Non-Goals
-
-### Goals
-
-- 为 usage 记录建立统一枚举字段 `request_type`。
-- 兼容历史数据与旧客户端(旧字段、旧参数仍可用)。
-- 支持新筛选维度(列表/统计/趋势/模型/清理)。
-- 全链路灰度发布,避免中断与大回归。
-
-### Non-Goals
-
-- 本期不删除 `stream`、`openai_ws_mode` 字段。
-- 本期不强制所有调用方立刻改用新参数。
-
-## Data Model Design
-
-### 数据库
-
-`usage_logs` 新增字段:
-
-- `request_type SMALLINT NOT NULL DEFAULT 0`
-
-并新增:
-
-- `CHECK (request_type IN (0,1,2,3))`
-- 索引:`idx_usage_logs_request_type_created_at(request_type, created_at)`
-
-当前枚举编码:
-
-- `0=unknown`
-- `1=sync`
-- `2=stream`
-- `3=ws_v2`
-
-说明:本期 CHECK 仅覆盖已落地值。未来新增类型(如 `ws_v1`、`grpc`、`batch`)时,通过新迁移扩展 CHECK 与映射,不影响当前兼容策略。
-
-### 回填策略
-
-按批回填,避免长事务:
-
-- `openai_ws_mode=true` -> `3(ws_v2)`
-- `openai_ws_mode=false and stream=true` -> `2(stream)`
-- else -> `1(sync)`
-
-### 读写策略
-
-- 写入:后端双写(新字段 + 旧字段)。
-- 读取:优先读 `request_type`;若为 `unknown`,按旧字段推导。
-
-## API Compatibility Design
-
-### 响应字段
-
-新增:
-
-- `request_type`(字符串枚举:`sync`/`stream`/`ws_v2`/`unknown`)
-
-保留:
-
-- `stream`
-- `openai_ws_mode`
-
-### 请求参数
-
-新增:
-
-- `request_type`(字符串枚举)
-
-保留:
-
-- `stream`
-
-### 参数语义与校验
-
-- `request_type` 可选值为:`unknown`、`sync`、`stream`、`ws_v2`。
-- `request_type` 采用小写规范值;非法值 MUST 返回 `400 Bad Request`,并返回可接受值列表。
-- 当同时传入 `request_type` 与 `stream` 时,按 `request_type` 过滤,`stream` 仅用于兼容旧客户端。
-- 为避免接口行为漂移,`request_type` 参数解析在用户侧与管理侧复用同一校验逻辑。
-
-### 参数优先级
-
-- 当 `request_type` 存在时,优先按 `request_type` 过滤。
-- `stream` 继续支持,用于旧客户端。
-
-该策略确保新客户端可直接枚举筛选,旧客户端行为不变。
-
-### 过滤入口覆盖范围
-
-`request_type` 过滤能力覆盖以下入口:
-
-- 用户 usage 列表:`GET /api/v1/usage`
-- 管理员 usage 列表/统计:`GET /api/v1/admin/usage`、`GET /api/v1/admin/usage/stats`
-- 用户 dashboard 趋势/模型:`GET /api/v1/usage/dashboard/trend`、`GET /api/v1/usage/dashboard/models`
-- 管理员 dashboard 趋势/模型:`GET /api/v1/admin/dashboard/trend`、`GET /api/v1/admin/dashboard/models`
-- usage cleanup 任务:`POST /api/v1/admin/usage/cleanup-tasks`(过滤条件)
-
-### 兼容字段一致性
-
-为确保旧客户端口径稳定,响应中 `stream/openai_ws_mode` 与 `request_type` 的关系必须一致:
-
-- `request_type=ws_v2` -> `openai_ws_mode=true`
-- `request_type=stream` -> `openai_ws_mode=false && stream=true`
-- `request_type=sync` -> `openai_ws_mode=false && stream=false`
-- `request_type=unknown` -> 按历史旧字段回退推导,不强制改写存量值
-
-## Frontend Design
-
-### 展示
-
-- 类型徽标与文案优先使用 `request_type`。
-- 若响应没有 `request_type`(老后端),回退旧逻辑:
- - `openai_ws_mode ? ws : stream ? stream : sync`
-
-### 筛选
-
-管理端类型筛选改为枚举选项:
-
-- 全部
-- 同步(sync)
-- 流式(stream)
-- WS(ws_v2)
-
-### 导出
-
-CSV/Excel 导出与表格使用同一 `resolveRequestTypeLabel`,避免口径不一致。
-
-## Upgrade / Rollback Plan
-
-### Upgrade
-
-1. DB migration:加列 + 索引 + 回填。
-2. Backend:双写双读 + 新参数 + 新响应字段。
-3. Frontend:枚举渲染 + 枚举筛选 + 旧字段回退。
-
-### Rollback
-
-- 回滚前端:后端仍返回旧字段,不影响。
-- 回滚后端:数据库保留旧字段,系统可继续运行。
-- 由于不删旧字段,回滚无需数据修复。
-
-## Risks and Mitigations
-
-- 风险:新旧字段短期不一致。
- - 缓解:读取优先新字段,`unknown` 自动回退旧字段推导;增加一致性监控。
-- 风险:大表回填锁与性能波动。
- - 缓解:按 ID 批量回填,低峰执行,监控慢 SQL 与复制延迟。
-- 风险:多入口筛选遗漏。
- - 缓解:统一扩展过滤结构体,覆盖列表/统计/趋势/模型/清理所有入口测试。
-
-## Testing Strategy
-
-- 单元测试
- - `request_type <-> 旧字段` 映射
- - DTO 回退逻辑
- - handler 参数解析、非法值校验与参数优先级
-- 仓储集成测试
- - 插入/读取 `request_type`
- - 列表/统计/趋势/模型/清理按 `request_type` 过滤正确
-- 回归测试
- - 仅用旧参数 `stream` 的行为不变
- - 前端在无 `request_type` 响应时显示不变
- - 响应 `request_type` 与 `stream/openai_ws_mode` 一致性不回归
diff --git a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/proposal.md b/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/proposal.md
deleted file mode 100644
index 16f2bc65b..000000000
--- a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/proposal.md
+++ /dev/null
@@ -1,66 +0,0 @@
-## Why
-
-当前“使用记录 -> 类型”由两个布尔字段组合推导:`openai_ws_mode` 与 `stream`。该方案存在以下问题:
-
-- 语义分散:类型不是单一事实源,展示层与数据层容易出现分支漂移。
-- 扩展困难:新增协议类型(如 `ws_v1`、`grpc`、`batch`)时,需要在后端、前端、导出、筛选多处追加 if/else。
-- 回归风险高:任何链路漏赋值都会导致展示错误。
-- 筛选能力弱:当前筛选参数仅支持 `stream=true/false`,无法直接筛选 `WS`。
-
-## What Changes
-
-- 新增 usage 主枚举字段:`request_type`,作为“类型”唯一主事实源。
-- 保留兼容字段:`stream`、`openai_ws_mode`(至少保留 2 个版本周期,不做破坏性删除)。
-- 新增查询参数:`request_type`(列表、统计、趋势、模型统计、清理任务均支持)。
-- `request_type` 参数仅接受 `unknown/sync/stream/ws_v2`(小写),非法值返回 `400 Bad Request`。
-- 前端表格/导出/筛选升级为枚举驱动,同时保留旧字段回退逻辑。
-- 响应保持兼容字段一致性:`request_type` 与 `stream/openai_ws_mode` 映射保持稳定。
-- 提供历史数据回填与灰度升级方案,保证向前兼容、可回滚。
-
-## 枚举定义(建议)
-
-- `unknown` = 0
-- `sync` = 1
-- `stream` = 2
-- `ws_v2` = 3
-- 预留未来值:`ws_v1`、`grpc`、`batch`
-
-## 兼容映射规则
-
-- `openai_ws_mode=true` -> `ws_v2`
-- `openai_ws_mode=false && stream=true` -> `stream`
-- `openai_ws_mode=false && stream=false` -> `sync`
-
-该映射与当前线上展示逻辑保持一致,确保历史口径不变。
-
-## Capabilities
-
-### New Capabilities
-
-- `usage-request-type`: 使用记录类型由枚举统一建模,支持扩展与统一筛选。
-
-## Impact
-
-- Backend
- - 数据库:`usage_logs` 新增 `request_type` 列与索引、回填脚本。
- - 领域模型:`service.UsageLog`、DTO、查询过滤结构新增 `request_type`。
- - Repository:插入、读取、列表/统计/趋势/清理筛选支持 `request_type`。
- - Handler/API:新增 `request_type` 请求参数,保留 `stream` 参数兼容。
-- Frontend
- - 类型定义新增 `request_type`。
- - 管理端筛选“类型”改为枚举选项(可筛 `WS`),旧接口回退兼容。
- - 用户端/管理端表格与导出统一走枚举渲染。
-- Tests
- - 增加回填映射、双读回退、双参数兼容、筛选准确性测试。
-
-## Rollout
-
-1. 先发 DB 迁移(加列、索引、回填,不删旧字段)。
-2. 再发后端双写双读(响应新增 `request_type`,旧字段继续返回)。
-3. 最后发前端枚举化(带旧字段回退)。
-4. 观察稳定后再评估旧字段淘汰计划。
-
-## Rollback
-
-- 任何阶段回滚到旧后端/旧前端均可运行:旧字段仍在且语义不变。
-- 即使出现 `request_type=unknown` 历史写入,新版本读取也会按旧字段回退推导。
diff --git a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/specs/usage-request-type/spec.md b/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/specs/usage-request-type/spec.md
deleted file mode 100644
index d909d7960..000000000
--- a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/specs/usage-request-type/spec.md
+++ /dev/null
@@ -1,93 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 系统必须以 request_type 作为使用记录类型的主事实源
-系统 MUST 在 `usage_logs` 中持久化 `request_type` 枚举字段,并将其作为类型展示与筛选的主事实源。
-
-#### Scenario: 新增记录写入 request_type
-- **WHEN** 网关记录一条新的 usage 日志
-- **THEN** 系统 MUST 写入有效的 `request_type` 枚举值
-- **AND** 枚举值 MUST 在约束集合内(`unknown/sync/stream/ws_v2`)
-
-#### Scenario: 读取优先 request_type
-- **WHEN** 系统读取 usage 日志用于 API 返回
-- **THEN** 系统 MUST 优先使用 `request_type` 作为类型来源
-
-### Requirement: 系统必须保持与旧字段兼容
-系统 MUST 在迁移期保持 `stream` 与 `openai_ws_mode` 的向后兼容能力。
-
-#### Scenario: 旧字段仍保留
-- **WHEN** 新版本后端返回 usage 记录
-- **THEN** 响应 MUST 继续包含 `stream` 与 `openai_ws_mode`
-
-#### Scenario: request_type 缺失时回退
-- **WHEN** 历史记录 `request_type` 为 `unknown` 或不可用
-- **THEN** 系统 MUST 按旧字段推导类型
-- **AND** 推导规则 MUST 与既有展示口径一致
-
-#### Scenario: 响应字段保持兼容一致
-- **WHEN** 系统返回一条包含 `request_type` 的 usage 记录
-- **THEN** 响应中的 `stream` 与 `openai_ws_mode` MUST 与 `request_type` 保持一致映射
-- **AND** `request_type=ws_v2` MUST 对应 `openai_ws_mode=true`
-- **AND** `request_type=stream` MUST 对应 `openai_ws_mode=false && stream=true`
-- **AND** `request_type=sync` MUST 对应 `openai_ws_mode=false && stream=false`
-
-### Requirement: 系统必须支持 request_type 查询过滤并兼容 stream 参数
-系统 MUST 提供 `request_type` 过滤能力,并继续兼容历史 `stream` 参数。
-
-#### Scenario: 使用 request_type 过滤列表
-- **WHEN** 客户端请求携带 `request_type`
-- **THEN** 系统 MUST 按 `request_type` 执行过滤
-
-#### Scenario: request_type 参数非法值
-- **WHEN** 客户端请求携带非法 `request_type`(不在 `unknown/sync/stream/ws_v2` 中)
-- **THEN** 系统 MUST 返回 `400 Bad Request`
-- **AND** 错误信息 MUST 提示可接受枚举值
-
-#### Scenario: 旧客户端使用 stream 过滤
-- **WHEN** 客户端仅携带 `stream`
-- **THEN** 系统 MUST 保持历史过滤行为不变
-
-#### Scenario: 同时携带 request_type 与 stream
-- **WHEN** 请求同时携带 `request_type` 与 `stream`
-- **THEN** 系统 MUST 优先按 `request_type` 过滤
-
-#### Scenario: request_type 过滤覆盖所有 usage 入口
-- **WHEN** 客户端访问 usage 列表/统计/趋势/模型/清理任务入口并携带 `request_type`
-- **THEN** 系统 MUST 在对应入口应用一致的 `request_type` 过滤语义
-
-### Requirement: 历史数据迁移必须可在线执行且不破坏旧逻辑
-系统 MUST 提供可在线迁移的回填方案,使历史数据具备 `request_type`,且迁移前后展示口径一致。
-
-#### Scenario: 历史回填映射
-- **WHEN** 执行历史数据回填
-- **THEN** `openai_ws_mode=true` MUST 映射为 `ws_v2`
-- **AND** `openai_ws_mode=false && stream=true` MUST 映射为 `stream`
-- **AND** 其他情况 MUST 映射为 `sync`
-
-#### Scenario: 分批回填
-- **WHEN** 数据量较大
-- **THEN** 回填 MUST 支持分批执行以降低锁与性能风险
-
-### Requirement: 前端必须在新旧后端间保持显示一致
-前端 MUST 支持 `request_type` 优先展示,并在老后端响应中自动回退旧字段推导。
-
-#### Scenario: 新后端响应
-- **WHEN** 响应包含 `request_type`
-- **THEN** 前端 MUST 使用 `request_type` 渲染类型标签与样式
-
-#### Scenario: 老后端响应
-- **WHEN** 响应不包含 `request_type`
-- **THEN** 前端 MUST 使用旧字段推导类型
-- **AND** 渲染结果 MUST 与升级前一致
-
-### Requirement: 升级与回滚必须可独立进行
-系统 MUST 支持数据库、后端、前端分阶段升级与独立回滚,不要求一次性切换。
-
-#### Scenario: 后端先回滚
-- **WHEN** 新数据库已上线但后端回滚到旧版本
-- **THEN** 系统 MUST 继续可用
-- **AND** 旧字段语义 MUST 保持不变
-
-#### Scenario: 前端先升级
-- **WHEN** 前端升级但后端尚未返回 `request_type`
-- **THEN** 前端 MUST 通过回退逻辑保持功能正常
diff --git a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/tasks.md b/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/tasks.md
deleted file mode 100644
index 6e5249d0a..000000000
--- a/openspec/changes/archive/2026-02-26-add-usage-request-type-enum/tasks.md
+++ /dev/null
@@ -1,47 +0,0 @@
-## 1. 数据库迁移
-
-- [ ] 1.1 新增迁移:`usage_logs.request_type SMALLINT NOT NULL DEFAULT 0`
-- [ ] 1.2 增加 `request_type` 枚举值约束(CHECK)
-- [ ] 1.3 增加索引 `idx_usage_logs_request_type_created_at`
-- [ ] 1.4 编写批量回填脚本/SQL(按旧字段映射)
-- [ ] 1.5 补充迁移集成测试(列存在、默认值、约束)
-- [ ] 1.6 回填支持 dry-run 与分批参数(batch size/游标),并提供回填前后行数对账
-
-## 2. 后端模型与仓储
-
-- [ ] 2.1 在 `service.UsageLog` 增加 `RequestType` 字段与枚举类型
-- [ ] 2.2 `usage_log_repo` 的 insert/select/scan 增加 `request_type`
-- [ ] 2.3 写入链路实现双写:`request_type` + 旧字段
-- [ ] 2.4 读取链路实现双读回退:`request_type=unknown` 时由旧字段推导
-- [ ] 2.5 增加仓储集成测试覆盖 `request_type`
-
-## 3. 后端 API 与筛选
-
-- [ ] 3.1 DTO 新增 `request_type` 响应字段(保留 `stream`/`openai_ws_mode`)
-- [ ] 3.2 用户 usage 列表接口新增 `request_type` 查询参数
-- [ ] 3.3 管理员 usage 列表/统计接口新增 `request_type` 查询参数
-- [ ] 3.4 dashboard trend/model stats 新增 `request_type` 查询参数
-- [ ] 3.5 usage cleanup 过滤条件新增 `request_type`
-- [ ] 3.6 明确并实现参数优先级:`request_type` 优先于 `stream`
-- [ ] 3.7 统一 `request_type` 参数解析与校验(仅接受 `unknown/sync/stream/ws_v2`,非法值返回 400)
-- [ ] 3.8 响应层实现兼容字段一致性映射(`request_type` 与 `stream/openai_ws_mode`)
-- [ ] 3.9 补充 handler/service/repository 全链路测试(含非法参数、覆盖入口、优先级)
-
-## 4. 前端改造
-
-- [ ] 4.1 `frontend/src/types` 增加 `request_type` 类型定义
-- [ ] 4.2 管理端筛选组件 `UsageFilters` 将“类型”升级为枚举筛选(含 WS)
-- [ ] 4.3 用户端与管理端表格统一使用 `request_type` 渲染(保留旧字段回退)
-- [ ] 4.4 导出逻辑统一使用枚举映射函数(CSV/Excel 同口径)
-- [ ] 4.5 dashboard API 参数透传 `request_type`(用户端与管理员端 trend/models)
-- [ ] 4.6 清理任务弹窗 `UsageCleanupDialog` 支持按 `request_type` 创建任务
-- [ ] 4.7 实现老后端兼容回退(无 `request_type` 字段或不支持 `request_type` 查询参数时回退旧逻辑)
-- [ ] 4.8 更新中英文 i18n 文案
-
-## 5. 发布与观测
-
-- [ ] 5.1 制定灰度计划:先 DB,再后端,再前端
-- [ ] 5.2 增加一致性监控:`request_type` 与旧字段映射差异
-- [ ] 5.3 准备回滚手册(前后端独立回滚)
-- [ ] 5.4 上线前执行回填对账(按类型抽样比对 `request_type` 与旧字段映射)
-- [ ] 5.5 上线后验证:类型展示、筛选、导出、统计、清理任务口径一致
diff --git a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/design.md b/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/design.md
deleted file mode 100644
index 4dd28b6ad..000000000
--- a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/design.md
+++ /dev/null
@@ -1,61 +0,0 @@
-## Context
-
-本次改动聚焦“确认且可安全上线”的 WSv2 运行时优化,目标是降低热路径开销并减少连接脏状态复用风险。
-
-## Goals
-
-- 在不改外部协议的前提下,降低 WSv2 高并发场景 CPU/分配和首包抖动。
-- 提升连接生命周期一致性,减少死连接/脏连接进入复用池。
-- 提高异常可观测性,避免 silent failure。
-
-## Non-Goals
-
-- 不重构调度器主流程。
-- 不在本次引入新的外部依赖或持久化结构。
-- 不改变 WS 默认开启策略。
-
-## Decisions
-
-### 1) Forwarder 热路径与错误处理
-
-- 采用包级 `strings.Replacer`,消除每条日志重复构建开销。
-- `error` 事件统一 `MarkBroken`,避免不可回退分支把异常连接放回池。
-- 客户端断连后进入“最小处理”路径:跳过 model/tool 修正,仅保留必要状态推进。
-- usage 解析改为事件门控,减少高频 token 事件的无效 JSON 查找。
-- ingress WS 客户端断连后继续读上游直到 terminal,与 HTTP-SSE drain 语义对齐。
-
-### 2) Pool 并发模型与后台维护
-
-- 写操作超时统一继承父 `context`,避免断链请求占住写超时窗口。
-- 读写锁拆分,恢复一读一写并发能力。
-- 增加后台 worker:
- - ping worker:每 30s 探测空闲连接,失败即回收。
- - cleanup worker:每 30s 扫描全部账号池,执行过期/空闲清理。
-- 兜底队列上限下调,避免配置缺失时出现极端长排队。
-
-### 3) 入口与依赖保护
-
-- 降低 WS 读上限至 16MB,收敛异常消息内存风险。
-- 代理 transport 增加 TLS 握手超时,防止代理链路卡死。
-- 协议决策器对未知认证类型显式回退 HTTP。
-- Redis 状态读写统一包裹 3s 超时,避免长连接上下文下的阻塞外溢。
-
-## Validation
-
-- 单测新增/更新:
- - 协议决策未知认证回退。
- - StateStore Redis 独立超时。
- - Pool 写超时继承父 `context`。
- - Pool 读写并发与后台 sweep 行为。
- - Client TLSHandshakeTimeout。
-- 定向回归:
- - `go test ./internal/service -run "OpenAIWS|ProxyResponsesWebSocketFromClient|Forward_WSv2|ProtocolResolver|StateStore"`
- - `go test ./internal/handler -run "OpenAI|websocket|Responses"`
-
-## Risks & Mitigations
-
-- 风险:后台 worker 增加周期性开销。
- - 缓解:只处理空闲连接;失败处理走现有 `evict`。
-- 风险:断连后的最小处理路径影响日志可读性。
- - 缓解:保留关键断连日志与终态统计。
-
diff --git a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/proposal.md b/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/proposal.md
deleted file mode 100644
index f6b3466fe..000000000
--- a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/proposal.md
+++ /dev/null
@@ -1,69 +0,0 @@
-## Why
-
-针对 OpenAI WSv2 热路径与连接池,我们复核了 20 项审查点。当前代码中存在一组可直接证明的问题:
-
-- 高热路径存在可避免分配:`normalizeOpenAIWSLogValue` 每次创建 `strings.NewReplacer`。
-- 连接生命周期不一致:`error` 事件后在部分分支未标记连接损坏。
-- 写上游超时不继承父 `context`,客户端断连后仍可能阻塞到默认超时。
-- `BindResponseAccount` 错误被静默忽略,粘连异常缺少可观测性。
-- ingress WS 非首轮 turn 并发参数与首轮调度参数不一致。
-- 协议决策器对未知认证类型缺少显式回退。
-- 连接池缺少后台健康探测与后台清理,仅在 `Acquire` 被动触发。
-- 连接 I/O 读写共用同一互斥锁,限制全双工并发能力。
-- 代理 `Transport` 缺少 `TLSHandshakeTimeout`。
-- Redis 状态读写缺少独立短超时,异常时可能拖长请求。
-- 事件循环在客户端断连后仍执行非必要处理。
-- 入站 WS 在客户端断连时未继续 drain 上游,行为与 HTTP-SSE 不一致。
-- 消息读取上限过高(128MB),存在内存风险。
-
-## What Changes
-
-本变更统一落地以下优化:
-
-1. `forwarder` 热路径与连接安全
-- 将日志值归一化替换器提升为包级变量。
-- 收到 `error` 事件后一律 `MarkBroken()`。
-- `BindResponseAccount` 失败输出 `warn` 日志。
-- 客户端已断连后,跳过 model/tool 修正等非必要处理,仅保留必要解析。
-- usage 解析增加事件类型快速门控(仅 `response.completed`)。
-- ingress WS 客户端断连时继续 drain 上游至 terminal,不再立刻打断并污染连接复用。
-
-2. `pool` 并发与后台维护
-- `writeJSONWithTimeout` 支持继承父 `context`(新增 `WriteJSONWithContextTimeout`)。
-- 连接 I/O 锁拆分为 `readMu/writeMu`,支持并发一读一写。
-- 新增后台 ping worker(30s)探测所有空闲连接。
-- 新增后台 cleanup worker(30s)定期扫描所有账号池。
-- `queueLimitPerConn` 兜底默认值从 `256` 下调为 `16`。
-
-3. `client/handler/protocol/state_store` 可靠性与资源保护
-- 代理 `http.Transport` 增加 `TLSHandshakeTimeout: 10s`。
-- WS 读上限从 `128MB` 下调为 `16MB`(客户端与 ingress 入口一致)。
-- ingress 非首轮 turn 统一使用调度器确定的并发参数。
-- 协议决策器对未知认证类型显式回退 HTTP。
-- `OpenAIWSStateStore` 对 Redis `set/get/delete` 增加独立 3s 超时包装。
-
-## Deferred (已确认但本次不直接改)
-
-- terminal 后“尾包探测”方案:直接 probe read 会对 `coder/websocket` 连接状态产生副作用,需改为更安全机制后再落地。
-- prewarm `creating` 计数语义重构:涉及扩容/预热协同策略,需要独立压测验证。
-- `replaceOpenAIWSMessageModel` 的双 `sjson.SetBytes` 深度优化:需在正确性与性能之间进一步权衡。
-- `GetResponseAccount` Redis 命中后本地回填:需先定义“无陈旧读”一致性边界。
-
-## Capabilities
-
-### Modified Capabilities
-
-- `openai-ws-v2-performance`
-
-## Impact
-
-- 影响模块:
- - `backend/internal/service/openai_ws_forwarder.go`
- - `backend/internal/service/openai_ws_pool.go`
- - `backend/internal/service/openai_ws_client.go`
- - `backend/internal/service/openai_ws_state_store.go`
- - `backend/internal/service/openai_ws_protocol_resolver.go`
- - `backend/internal/handler/openai_gateway_handler.go`
-- 影响类型:热路径性能、连接池稳定性、异常可观测性。
-- 兼容性:外部 API 与协议保持不变。
-
diff --git a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/specs/openai-ws-v2-performance/spec.md b/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/specs/openai-ws-v2-performance/spec.md
deleted file mode 100644
index 79adae036..000000000
--- a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/specs/openai-ws-v2-performance/spec.md
+++ /dev/null
@@ -1,79 +0,0 @@
-## ADDED Requirements
-
-### Requirement: WSv2 error 事件后的连接必须不可复用
-系统 MUST 在收到上游 `type=error` 事件后将当前连接标记为损坏,避免回池复用。
-
-#### Scenario: error 事件触发统一损坏标记
-- **WHEN** 上游返回 `error` 事件
-- **THEN** 系统 MUST 执行连接损坏标记
-- **AND** 不得因“是否可回退”分支差异而漏标记
-
-### Requirement: WSv2 写上游超时必须继承父 context
-系统 MUST 在写上游 WS 时继承调用方父 `context`,避免客户端已断开时仍长时间阻塞。
-
-#### Scenario: 父 context 已取消
-- **WHEN** 父 `context` 已取消
-- **THEN** 写上游操作 MUST 立即感知取消并返回
-- **AND** MUST NOT 阻塞到默认写超时
-
-### Requirement: 连接池必须具备后台 ping 与后台清理
-系统 MUST 在 `Acquire` 之外提供后台连接维护能力。
-
-#### Scenario: 空闲连接后台心跳
-- **WHEN** 连接处于空闲状态
-- **THEN** 系统 SHALL 按周期对空闲连接执行 ping
-- **AND** ping 失败连接 MUST 被回收
-
-#### Scenario: 长时间无请求账号
-- **WHEN** 某账号长时间无新请求
-- **THEN** 系统 SHALL 仍执行后台清理
-- **AND** 过期/无效连接 MUST 被回收
-
-### Requirement: 连接 I/O 必须支持并发一读一写
-系统 MUST 避免将 WS 读写串行化到同一把锁上。
-
-#### Scenario: 读阻塞期间执行写/Ping
-- **WHEN** 读路径处于阻塞等待
-- **THEN** 写路径 SHOULD 仍可独立推进
-- **AND** 不得因单锁竞争导致心跳/写入长时间饥饿
-
-### Requirement: ingress WS 客户端断连后应继续 drain 上游
-系统 MUST 在 ingress WS 模式下对客户端断连采用“继续 drain 到 terminal”的策略。
-
-#### Scenario: 客户端中途断开
-- **WHEN** 向客户端写事件返回断连错误
-- **THEN** 系统 SHALL 继续读取上游直到 terminal
-- **AND** 连接不得因该断连被立即标记损坏
-
-### Requirement: 状态存储 Redis 操作必须有独立短超时
-系统 MUST 为 WS 状态存储的 Redis 操作设置独立短超时,避免长上下文阻塞。
-
-#### Scenario: Redis 网络异常
-- **WHEN** Redis 操作发生网络抖动/分区
-- **THEN** `set/get/delete` MUST 在短超时内返回
-- **AND** 不得无限依赖上层长连接 context
-
-### Requirement: 协议决策必须对未知认证类型显式回退 HTTP
-系统 MUST 在未知 OpenAI 认证类型下显式回退 HTTP。
-
-#### Scenario: 非 OAuth 且非 API Key 账号
-- **WHEN** 账号认证类型不在已知集合内
-- **THEN** 协议决策 MUST 返回 HTTP
-- **AND** MUST NOT 进入 WS 子开关判定路径
-
-### Requirement: WS 消息读取上限必须受控
-系统 MUST 对 ingress 与上游 WS 客户端统一设置合理读取上限,降低异常大包内存风险。
-
-#### Scenario: 默认读取上限
-- **WHEN** 系统创建 ingress/上游 WS 连接
-- **THEN** 读取上限 MUST 为受控值(16MB)
-- **AND** ingress 与上游配置 MUST 保持一致
-
-### Requirement: 粘连绑定失败必须可观测
-系统 MUST 对 `BindResponseAccount` 失败记录警告日志。
-
-#### Scenario: 粘连绑定异常
-- **WHEN** 状态存储返回绑定错误
-- **THEN** 系统 MUST 记录 `warn` 级别日志
-- **AND** 日志 MUST 包含 group/account/response 标识
-
diff --git a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/tasks.md b/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/tasks.md
deleted file mode 100644
index 81b8ee639..000000000
--- a/openspec/changes/archive/2026-02-26-optimize-openai-ws-v2-runtime-performance/tasks.md
+++ /dev/null
@@ -1,40 +0,0 @@
-## 1. Forwarder 热路径与错误语义
-
-- [x] 1.1 将 `normalizeOpenAIWSLogValue` 的 `strings.NewReplacer` 提升为包级变量
-- [x] 1.2 `error` 事件后统一 `lease.MarkBroken()`
-- [x] 1.3 `BindResponseAccount` 返回错误增加 `warn` 级日志(4 处调用点)
-- [x] 1.4 客户端断连后跳过非必要 model/tool 修正
-- [x] 1.5 usage 解析增加事件门控(`response.completed`)
-- [x] 1.6 ingress WS 客户端断连改为继续 drain 上游至 terminal
-
-## 2. Pool 并发与后台维护
-
-- [x] 2.1 写超时增加父 context 继承能力(`WriteJSONWithContextTimeout`)
-- [x] 2.2 连接锁拆分为 `readMu/writeMu`
-- [x] 2.3 增加后台 ping worker(30s)
-- [x] 2.4 增加后台 cleanup worker(30s)
-- [x] 2.5 `queueLimitPerConn` 兜底默认值下调为 16
-
-## 3. 入口与依赖保护
-
-- [x] 3.1 代理 Transport 增加 `TLSHandshakeTimeout: 10s`
-- [x] 3.2 WS 消息读取上限从 128MB 下调到 16MB(client + ingress)
-- [x] 3.3 ingress 非首轮 turn 统一使用调度器并发参数
-- [x] 3.4 协议决策器补齐未知认证类型回退 HTTP
-- [x] 3.5 Redis `set/get/delete` 增加独立 3s 超时
-
-## 4. 测试与回归
-
-- [x] 4.1 协议决策新增未知认证类型用例
-- [x] 4.2 StateStore 新增 Redis 超时上下文用例
-- [x] 4.3 Pool 新增父 context 写超时/读写并发/后台 sweep 用例
-- [x] 4.4 Client 新增 TLSHandshakeTimeout 配置用例
-- [x] 4.5 通过 service/handler 定向回归测试
-
-## 5. 延后项(已确认,待后续提案)
-
-- [ ] 5.1 terminal 后脏数据探测改为“无副作用”的安全方案
-- [ ] 5.2 prewarm `creating` 计数语义重构与压测验收
-- [ ] 5.3 `replaceOpenAIWSMessageModel` 双 `sjson.SetBytes` 深度优化
-- [ ] 5.4 `GetResponseAccount` Redis 命中后的本地回填一致性方案
-
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/.openspec.yaml b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/.openspec.yaml
deleted file mode 100644
index 85ae75c1f..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/.openspec.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-schema: spec-driven
-created: 2026-02-26
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/design.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/design.md
deleted file mode 100644
index b1d7c0bf5..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/design.md
+++ /dev/null
@@ -1,487 +0,0 @@
-# Sora 客户端功能完善 — 技术设计
-
-## Context
-
-### 现状
-
-系统当前的 Sora 功能是一个以透传为主的网关(`/sora/v1/chat/completions`),面向 API 开发者。现状仍会尝试将生成媒体落到本地磁盘(`/app/data/sora`),并由清理任务在 7 天后清理。没有面向终端用户的 Web 界面,没有 S3 对象存储支持,没有用户配额管理,Sora 平台账号类型仅支持 OAuth。
-
-### 目标架构
-
-新增一个面向终端用户的 Sora 客户端层,位于现有网关之上:
-
-```text
-┌─────────────────────────────────────────────────────────┐
-│ 前端 Sora 客户端 UI │
-│ /sora (生成页 + 作品库) │
-└────────────────────────┬────────────────────────────────┘
- │ /api/v1/sora/*
-┌────────────────────────▼────────────────────────────────┐
-│ Sora Client Handler (新增) │
-│ ┌─────────┐ ┌──────────┐ ┌───────────┐ ┌────────────┐ │
-│ │ 配额检查 │ │ 生成记录 │ │ S3 上传 │ │ 即时下载 │ │
-│ └─────────┘ └──────────┘ └───────────┘ └────────────┘ │
-└────────────────────────┬────────────────────────────────┘
- │ 内部调用
-┌────────────────────────▼────────────────────────────────┐
-│ 现有 Sora Gateway (不变) │
-│ /sora/v1/chat/completions │
-│ ┌──────────┐ ┌──────────┐ ┌───────────────────────┐ │
-│ │ 账号选择 │ │ SDK 直连 │ │ HTTP 透传(apikey 账号) │ │
-│ └──────────┘ └──────────┘ └───────────────────────┘ │
-└─────────────────────────────────────────────────────────┘
-```
-
-### 约束
-
-- 现有 `/sora/v1/chat/completions` API 必须完全向后兼容
-- S3 存储在系统设置中独立配置(使用 `aws-sdk-go-v2` 直连)
-- 前端技术栈:Vue 3 + TypeScript + TailwindCSS(与现有前端一致)
-- 暗色主题风格参考 sora.com 官方客户端
-
-## Goals / Non-Goals
-
-### Goals
-
-1. 管理员在"系统设置"中配置 S3 存储,勾选开放给 Sora 用户使用
-2. 每个用户有存储配额限制,防止无限使用
-3. 提供参考 Sora 官方客户端的 Web 界面(生成 + 作品库)
-4. Sora 平台支持 API Key 账号类型,实现 sub2api 级联部署
-5. API Key 直接调用不存储,直接返回 URL
-6. 客户端 UI 调用走存储 + 记录 + 配额管理
-
-### Non-Goals
-
-- 不实现 Sora 官方的社交功能(Explore 社区、关注、评论)
-- 不实现视频编辑功能(Re-cut、Remix、Blend、Loop)
-- 不实现角色(Characters)功能
-- 不实现 Storyboard 时间线编辑器(仅支持分镜提示词格式)
-- 不实现移动端 App
-- 不支持用户自行配置存储(由管理员统一配置)
-
-## Decisions
-
-### D1: 两条调用路径分离(核心决策)
-
-**决策**:API Key 调用(`/sora/v1/chat/completions`)和客户端 UI 调用(`/api/v1/sora/generate`)完全独立。
-
-**理由**:
-
-| 方案 | 优点 | 缺点 |
-|------|------|------|
-| **A: 统一路径,全部走存储** | 实现简单 | API 用户被强制消耗存储配额;破坏现有 API 兼容性 |
-| **B: 统一路径,参数控制** | 灵活 | API 复杂度增加;需要所有 API 用户适配新参数 |
-| **✅ C: 两条独立路径** | 零破坏性变更;职责清晰 | 两套代码路径需维护 |
-
-选择 C。`/sora/v1/chat/completions` 保持纯透传,`/api/v1/sora/generate` 在上层包装存储/记录/配额逻辑。
-
-**实现约束**:`SoraGatewayService.Forward()` 仅负责生成与上游交互,不在 API Key 直调路径执行媒体落盘;客户端路径在 `SoraClientHandler` 内调用独立存储服务完成 S3/本地/上游降级。
-
-### D2: S3 存储使用 aws-sdk-go-v2 直连
-
-**决策**:直接使用 `aws-sdk-go-v2` 连接 S3 兼容存储,不依赖现有数据管理的 gRPC 代理。
-
-**理由**:
-- 现有 gRPC 数据管理代理仅用于备份功能,没有通用文件上传 RPC
-- 新增 gRPC RPC 会增加对外部代理进程的依赖和复杂度
-- `aws-sdk-go-v2` 直连更灵活,支持流式上传、预签名 URL 等高级功能
-- S3 配置独立存储在系统设置表中,管理员在系统设置页面配置
-
-**S3 配置存储**:在 Settings 表中新增以下键值:
-- `sora_s3_enabled` — 是否启用 S3 存储
-- `sora_s3_endpoint` — S3 端点(如 `https://s3.amazonaws.com`)
-- `sora_s3_region` — 区域
-- `sora_s3_bucket` — 存储桶
-- `sora_s3_access_key_id` — 访问密钥 ID
-- `sora_s3_secret_access_key` — 访问密钥(加密存储)
-- `sora_s3_prefix` — 对象键前缀(如 `sora/`)
-- `sora_s3_force_path_style` — 是否强制路径模式
-- `sora_s3_cdn_url` — CDN 域名(可选,用于生成公开 URL)
-
-### D3: Sora API Key 账号类型存储为 `apikey`
-
-**决策**:Sora 平台的 API Key / 上游透传账号存储为 `type = "apikey"`(而非 `type = "upstream"`)。
-
-**理由**:
-- 复用 Antigravity 的 upstream 模式(UI 上显示为"上游透传",存储为 `apikey`)
-- `Account.GetBaseURL()` 已支持 `apikey` 类型返回 `base_url`
-- 后端验证已允许 `apikey` 类型
-- 无需新增数据库字段或常量
-
-### D4: 生成记录表设计
-
-**决策**:新建独立的 `sora_generations` 表,不扩展现有 `usage_logs` 表。
-
-```sql
-CREATE TABLE sora_generations (
- id BIGSERIAL PRIMARY KEY,
- user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE,
- api_key_id BIGINT,
-
- -- 生成参数
- model VARCHAR(64) NOT NULL,
- prompt TEXT NOT NULL DEFAULT '',
- media_type VARCHAR(16) NOT NULL DEFAULT 'video', -- video/image
-
- -- 结果
- status VARCHAR(16) NOT NULL DEFAULT 'pending', -- pending/generating/completed/failed/cancelled
- media_url TEXT NOT NULL DEFAULT '',
- media_urls JSONB, -- 多图时的 URL 数组
- file_size_bytes BIGINT NOT NULL DEFAULT 0,
- storage_type VARCHAR(16) NOT NULL DEFAULT 'none', -- s3/local/upstream/none
- s3_object_keys JSONB, -- S3 object key 数组(视频单元素,多图多个),删除时逐一清理
-
- -- 上游信息
- upstream_task_id VARCHAR(128) NOT NULL DEFAULT '',
- error_message TEXT NOT NULL DEFAULT '',
-
- -- 时间
- created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
- completed_at TIMESTAMPTZ,
-
- -- 索引列
- -- created_at 不做唯一约束,避免同一用户高并发写入时产生碰撞
-);
-
-CREATE INDEX idx_sora_gen_user_created ON sora_generations(user_id, created_at DESC);
-CREATE INDEX idx_sora_gen_user_status ON sora_generations(user_id, status);
-```
-
-**理由**:
-- `usage_logs` 是计费日志,混入媒体管理会增加复杂度
-- 独立表方便查询(按用户、按状态、按时间)和管理(删除联动 S3 清理)
-- `s3_object_keys` 为 JSONB 数组,支持单视频(`["key.mp4"]`)和多图(`["key1.jpg","key2.jpg"]`)场景
-
-### D5: 配额追踪方案
-
-**决策**:在 `users` 表新增字段(而非独立表)。
-
-```sql
-ALTER TABLE users ADD COLUMN sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0; -- 0 = 使用系统默认
-ALTER TABLE users ADD COLUMN sora_storage_used_bytes BIGINT NOT NULL DEFAULT 0;
-```
-
-系统默认配额通过 Settings 表存储(`sora_default_storage_quota_bytes`)。
-
-**理由**:
-- 配额是用户的核心属性,放在 users 表查询最高效
-- 不需要 JOIN,减少请求延迟
-- 分组级别的配额覆盖通过 `groups.sora_storage_quota_bytes` 字段实现
-
-**配额判断逻辑**:
-```text
-有效配额 = user.sora_storage_quota_bytes > 0
- ? user.sora_storage_quota_bytes
- : group.sora_storage_quota_bytes > 0
- ? group.sora_storage_quota_bytes
- : settings.sora_default_storage_quota_bytes
-```
-
-### D6: Sora apikey 账号的 HTTP 透传实现
-
-**决策**:在 `SoraGatewayService.Forward()` 开头加分支判断,apikey 类型走独立的 HTTP 透传方法。
-
-```text
-Forward(ctx, c, account, body, clientStream):
- if account.Type == "apikey" && account.GetBaseURL() != "":
- return s.forwardToUpstream(ctx, c, account, body, clientStream)
- // ... 现有 SDK 直连逻辑
-```
-
-`forwardToUpstream()` 的行为:
-- 创建/更新账号时校验 `base_url`:必填,且必须包含 `http://` 或 `https://` scheme
-- 构造 HTTP 请求到规范化路径 `{base_url}/sora/v1/chat/completions`(去除重复斜杠)
-- Header: `Authorization: Bearer `, `Content-Type: application/json`
-- 流式响应:逐字节透传 SSE 流
-- 非流式响应:读取完整响应后返回
-- 错误处理:复用现有的失败转移逻辑(`UpstreamFailoverError`)
-
-### D7: 前端 UI 架构(参考 Sora 官方客户端)
-
-**决策**:采用单页面 + Tab 切换的设计,两个主视图:
-
-**参考 sora.com 的设计特点**:
-- 暗色主题(#0D0D0D 背景,白色文字)
-- 底部创作栏(提示词输入在页面底部)
-- 网格作品库(卡片缩略图 + hover 预览)
-- 右上角生成队列状态指示
-
-**我们的适配**:
-
-```text
-┌────────────────────────────────────────────────────────┐
-│ [生成] [作品库] 2.1GB/5GB │ ← 页面内导航(无Logo/头像,由侧边栏提供)
-├────────────────────────────────────────────────────────┤
-│ │
-│ ┌──────────────────────────────────────────────┐ │
-│ │ ⏳ 生成中 sora-2 "一只猫在月球上..." │ │ ← 任务卡片 1(最新)
-│ │ 已等待 3:42 · 预计剩余 8 分钟 │ │
-│ │ [取消生成] │ │
-│ └──────────────────────────────────────────────┘ │
-│ │
-│ ┌──────────────────────────────────────────────┐ │
-│ │ ✅ 已完成 gpt-image-1 │ │ ← 任务卡片 2(自动保存成功)
-│ │ ✓ 已保存到云端 [📥 本地下载] │ │
-│ └──────────────────────────────────────────────┘ │
-│ │
-│ ┌──────────────────────────────────────────────┐ │
-│ │ ✅ 已完成 sora-2(无存储降级) │ │ ← 任务卡片 3(S3不可用)
-│ │ [📥 本地下载] ⏱ 剩余 11:28 可下载 │ │ ← 过期倒计时
-│ └──────────────────────────────────────────────┘ │
-│ │
-│ 无任务时: 欢迎使用 Sora · 输入提示词开始创作 │
-│ │
-├────────────────────────────────────────────────────────┤
-│ 模型: [sora2-landscape-10s ▼] 正在生成 1/3 │ ← 底部创作栏
-│ ┌──────────────────────────────────────────────────┐ │
-│ │ 描述你想要的视频或图片... │ │ ← 提示词输入框
-│ └──────────────────────────────────────────────────┘ │
-│ [横屏] [竖屏] [方形] [10s] [15s] [25s] [📎] [生成▶] │ ← 设置栏
-│ ⚠ 存储未配置,生成后请立即下载 │ ← 无存储提示(条件显示)
-└────────────────────────────────────────────────────────┘
-```
-
-**完成后按钮状态说明**:
-- S3 自动保存成功:`✓ 已保存到云端` + `📥 本地下载`(作品自动进入作品库)
-- 降级本地存储:`✓ 已保存到本地` + `📥 本地下载`
-- S3 不可用(upstream):`📥 本地下载` + 15 分钟过期倒计时
-- S3 后续启用时,历史 upstream 记录可手动 `☁️ 保存到存储`
-
-**作品库页**:
-
-```text
-┌────────────────────────────────────────────────────────┐
-│ [生成] [作品库] 2.1GB/5GB │
-├────────────────────────────────────────────────────────┤
-│ [全部] [视频] [图片] 搜索... │ ← 筛选栏
-├────────────────────────────────────────────────────────┤
-│ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ │
-│ │ 🎬 │ │ 🖼️ │ │ 🎬 │ │ 🎬 │ │
-│ │ │ │ │ │ │ │ │ │ ← 4列网格
-│ │ sora2.. │ │ gpt-img │ │ sora2.. │ │ sora2.. │ │
-│ │ 3分钟前 │ │ 1小时前 │ │ 昨天 │ │ 2天前 │ │
-│ └─────────┘ └─────────┘ └─────────┘ └─────────┘ │
-│ ┌─────────┐ ┌─────────┐ │
-│ │ 🖼️ │ │ 🎬 │ │
-│ │ │ │ │ │
-│ └─────────┘ └─────────┘ │
-└────────────────────────────────────────────────────────┘
-```
-
-**组件拆分**:
-
-```text
-views/user/SoraView.vue 主页面容器
-components/sora/
-├── SoraNavBar.vue 页面内导航(tab切换 + 配额,无Logo/头像)
-├── SoraGeneratePage.vue 生成页
-│ ├── SoraPromptBar.vue 底部创作栏
-│ ├── SoraModelSelector.vue 模型选择下拉
-│ ├── SoraProgressCard.vue 生成进度卡片
-│ └── SoraNoStorageWarning.vue 无存储提示
-├── SoraLibraryPage.vue 作品库页
-│ ├── SoraLibraryGrid.vue 网格布局
-│ ├── SoraMediaCard.vue 单个作品卡片
-│ └── SoraEmptyState.vue 空状态
-├── SoraMediaPreview.vue 作品详情/预览弹窗
-├── SoraQuotaBar.vue 配额展示条
-└── SoraDownloadDialog.vue 即时下载弹窗
-```
-
-### D8: Sora S3 配置独立于数据管理
-
-**决策**:在系统设置(Settings 表)中独立存储 Sora S3 配置,不修改现有数据管理的 gRPC Proto 或 S3 Profile。
-
-**理由**:
-- 现有数据管理使用 gRPC 代理架构,修改 Proto 需要同步更新外部 Agent 进程
-- Sora S3 存储与备份 S3 存储的用途不同,独立配置更清晰
-- 系统设置表已有成熟的读写机制,新增配置项成本最低
-
-**前端**:在系统设置页面(SettingsView.vue)新增"Sora S3 存储配置"区域,包含:
-- 启用开关
-- S3 连接信息表单(endpoint、region、bucket、access_key_id、secret_access_key)
-- 高级选项(prefix、force_path_style、CDN URL)
-- "测试连接"按钮
-
-**运行时获取**:Sora S3 Storage Service 启动时从 Settings 表读取配置,缓存 S3 客户端实例。配置变更时刷新缓存。
-
-### D9: 客户端生成异步流程与自动存储
-
-**决策**:客户端生成采用"异步生成 + 自动存储 + 前端轮询"模式,解决手动保存与后端自动上传的矛盾。
-
-**核心流程**:
-
-```text
-POST /api/v1/sora/generate
- │
- ├─ 1. 配额预检查(有效配额 > 0 时检查)
- ├─ 2. 创建 sora_generations 记录(status=pending)
- ├─ 3. 立即返回 { generation_id } 给前端(异步继续)
- │
- ├─ 4. 后台:调用 Forward() 获取上游媒体 URL
- ├─ 5. 后台:自动上传到 S3(若可用且配额足够)
- │ └─ S3 失败 → 降级本地 → 再失败 → 保留上游 URL
- ├─ 6. 后台:更新 sora_generations 记录(status/media_url/storage_type 等)
- └─ 7. 后台:累加存储配额(仅 S3/本地存储时)
-```
-
-**前端状态同步**:
-- 前端通过轮询 `GET /api/v1/sora/generations/:id` 获取状态更新
-- 轮询策略(递减频率避免服务端压力):
- - 0-2 分钟:每 3 秒
- - 2-10 分钟:每 10 秒
- - 10-30 分钟:每 30 秒
- - 超过 30 分钟:提示"生成时间异常,可能已超时"
-- 页面加载时自动调用 `GET /api/v1/sora/generations?status=pending,generating` 恢复所有进行中任务
-
-**完成后的 UI 状态**:
-- `storage_type = 's3'`(自动保存成功):显示 "✓ 已保存到云端" + "📥 本地下载"
-- `storage_type = 'local'`(降级本地):显示 "✓ 已保存到本地" + "📥 本地下载"
-- `storage_type = 'upstream'`(S3 不可用):显示 "📥 本地下载" + 过期倒计时(15 分钟)
- - 若 S3 后续启用,用户可手动点击 "☁️ 保存到存储"(调用 `POST /api/v1/sora/generations/:id/save`)
-
-**多任务并发**:
-- 允许用户同时发起最多 3 个生成任务(超出排队等待)
-- 生成页中间区域以时间线方式纵向排列所有活跃任务卡片,最新在最上方
-- 底部创作栏显示当前活跃任务数(如 "正在生成 2/3")
-
-**理由**:
-- 异步返回让用户立即得到反馈(不用同步等 5-20 分钟)
-- 自动存储避免上游 URL 过期导致作品丢失(用户无需手动操作)
-- 不想要的作品可在作品库中删除以释放配额
-- 轮询而非 WebSocket,因为生成间隔以秒计,轮询足够且实现简单
-
-### D10: 取消生成与浏览器通知
-
-**决策**:支持取消进行中的生成任务,并在任务完成时通过浏览器通知告知用户。
-
-**取消生成**:
-- 进度卡片显示"取消"按钮(带二次确认)
-- 后端 `POST /api/v1/sora/generations/:id/cancel`:标记状态为 `cancelled`
-- 若上游任务已提交但无法真正取消,后端仍标记为 cancelled 并忽略后续结果
-- 取消的任务不累加存储配额
-
-**浏览器通知**:
-- 首次使用时请求 Notification API 权限
-- 任务完成时发送桌面通知:"您的视频/图片已生成完成!"
-- 任务失败时通知:"生成失败:{简短原因}"
-- 标签页 title 闪烁提示(如 "(1) ✓ 生成完成 - Sora")
-
-### D11: Sora 页面嵌入全局侧边栏布局
-
-**决策**:Sora 客户端页面作为普通用户页面嵌入全局侧边栏布局内,不独立接管全屏。
-
-**问题**:Sora 官方客户端有独立的顶部导航栏(Logo + Tab + 配额 + 头像),如果在系统内也实现独立导航栏,会与全局侧边栏重复(Logo、头像),且用户无法通过侧边栏快速切换到其他页面。
-
-**方案对比**:
-
-| 方案 | 优点 | 缺点 |
-|------|------|------|
-| A: 全屏接管(隐藏侧边栏) | 沉浸式体验 | 导航不一致;需实现返回按钮;与现有页面体验割裂 |
-| **✅ B: 嵌入侧边栏布局** | 与所有页面导航一致 | 可用宽度稍少(侧边栏 64/256px) |
-
-选择 B。Sora 页面内仅保留 Tab 切换("生成"/"作品库")+ 配额进度条,去掉独立 Logo 和头像。
-
-**侧边栏菜单项**:
-- 条件显示:通过公共设置 `sora_client_enabled`(后端根据是否有活跃 Sora 账号推断)控制显示
-- 图标:使用 Heroicons 线性 Sparkles 图标(与现有菜单图标同为 stroke 风格)
-- 位置:普通用户在 Dashboard 之后;管理员"我的账户"在 API 密钥之后
-- 简单模式:`hideInSimpleMode: true`
-- 双菜单同步:同时添加到 `userNavItems` 和 `personalNavItems`
-
-**公共设置**:后端新增 `sora_client_enabled` 到公共设置 API 响应中。判断逻辑:`sora_client_enabled = (活跃 Sora 账号数 > 0)`。前端通过 `appStore.cachedPublicSettings?.sora_client_enabled` 条件渲染菜单项。
-
-## Risks / Trade-offs
-
-### R1: aws-sdk-go-v2 新依赖引入
-
-**风险**:引入 `aws-sdk-go-v2` 会增加项目的依赖体积。
-
-**缓解**:仅引入必要的子模块(`s3`、`config`、`credentials`),不引入完整 SDK。Go 模块系统可以精确控制依赖粒度。
-
-### R2: 大文件上传到 S3 的内存压力
-
-**风险**:视频文件可达 200MB,如果先下载到内存再上传 S3,会导致内存峰值。
-
-**缓解**:使用流式管道 — 从 Sora 上游下载的同时流式上传到 S3(`io.Pipe`),避免全量缓存到内存。
-
-### R3: 配额计算的并发一致性
-
-**风险**:多个并发请求同时检查配额可能导致超额。
-
-**缓解**:
-- 先计算 `effective_quota`(用户 > 分组 > 系统默认),再执行原子更新:
- `UPDATE users SET sora_storage_used_bytes = sora_storage_used_bytes + :delta WHERE id = :id AND (:effective_quota = 0 OR sora_storage_used_bytes + :delta <= :effective_quota)`
-- 若原子更新失败则回滚新上传文件,并返回配额错误,不接受超额落盘
-
-### R4: apikey 账号透传的响应格式差异
-
-**风险**:上游 sub2api 返回的响应格式可能与本地 SDK 直连的格式不完全一致。
-
-**缓解**:透传模式下不做任何响应体改写,原样返回。客户端需要兼容两种来源的响应。
-
-### R5: 无存储模式下的用户体验
-
-**风险**:上游临时 URL 过期时间不确定(可能几分钟到几小时),用户可能来不及下载。
-
-**缓解**:
-- 正常模式(S3 可用):D9 的自动存储彻底消除此风险
-- 降级模式(S3 不可用):
- - 进度卡片显示醒目的 **15 分钟倒计时进度条**("剩余 12:34 可下载")
- - 剩余 5 分钟时通过浏览器通知提醒用户
- - 剩余 2 分钟时卡片边框变为红色警告态
- - 过期后显示"链接已过期,作品无法恢复"(灰色不可操作)
- - 用户离开页面时触发 `beforeunload` 警告:"您有未下载的生成结果,离开后可能丢失"
-
-### R6: 多任务并发的资源控制
-
-**风险**:允许多任务并发(D9)可能导致同一用户占用过多上游并发槽位。
-
-**缓解**:
-- 客户端并发上限 3 个(超出返回 429 + "请等待当前任务完成")
-- 复用现有 `SoraGatewayHandler` 的用户级并发控制(`userSlots`)
-- 后端限制同一用户 pending+generating 状态的记录不超过 3 条
-
-### R7: 预签名 URL 24 小时过期导致作品库碎图
-
-**风险**:S3 预签名 URL 24 小时后过期,如果前端缓存了旧 URL,作品库中的缩略图和视频会变成 broken image。
-
-**缓解**:
-- 作品库列表 API 每次返回时由后端**动态生成新的预签名 URL**(有效期 24 小时)
-- 前端不缓存媒体 URL,每次打开作品库/详情都请求新数据
-- 配置 CDN URL 时无此问题(CDN URL 永久有效)
-
-## Migration Plan
-
-### 阶段一:基础设施(可独立部署)
-
-1. 数据库迁移:`sora_generations` 表 + `users` 表新增配额字段
-2. 系统设置新增默认配额键值 + Sora S3 配置键值
-3. 引入 `aws-sdk-go-v2` 依赖
-
-### 阶段二:后端能力
-
-4. Sora apikey 账号类型 — 前端账号创建表单 + 后端 HTTP 透传
-5. S3 上传服务 — 读取系统设置中的 S3 配置,使用 aws-sdk-go-v2 实现文件上传
-6. 生成记录 CRUD + 配额管理 Service
-7. Sora Client Handler + 路由注册
-
-### 阶段三:前端客户端
-
-8. Sora 客户端 UI — 生成页 + 作品库 + 进度 + 配额
-9. 无存储即时下载模式
-10. 侧边栏菜单 + 路由 + 国际化
-
-### 回滚策略
-
-- 所有数据库迁移都是 additive(新增表/字段),不修改现有结构
-- 前端新页面独立于现有页面,删除路由即可回滚
-- `/sora/v1/chat/completions` 不做任何修改,API 用户无感知
-- Sora S3 配置独立于现有备份 S3 Profile,不影响现有备份功能
-
-## Open Questions
-
-1. **视频缩略图如何生成?** — ✅ 已决策:前端用 `` 标签截第一帧(无需服务端依赖)
-2. **配额单位是否需要更细粒度?** — ✅ 已决策:第一版只做字节数,不做次数限制
-3. **是否需要管理员查看所有用户的 Sora 生成记录?** — 🔜 第二版再做,后续按需添加
-4. **生成状态推送方式?** — ✅ 已决策(D9):使用前端轮询(递减频率),不引入 WebSocket
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/proposal.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/proposal.md
deleted file mode 100644
index 4818ff674..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/proposal.md
+++ /dev/null
@@ -1,356 +0,0 @@
-# Sora 客户端功能完善
-
-## Why
-
-当前 Sora 功能存在五个核心问题:
-
-1. **容量瓶颈** — 视频文件单个可达 200MB+,本地磁盘存储会快速耗尽,且默认 7 天清理导致用户资产丢失。
-2. **缺乏客户端体验** — 没有面向用户的 Sora 界面,用户只能通过 API 调用,无法像 Sora 官方客户端那样浏览作品、管理任务。
-3. **无存储时的体验断裂** — 如果管理员未配置 S3 存储,用户生成完毕后只能拿到一个临时 URL,离开对话后就丢失了。
-4. **不支持级联部署** — 当前 Sora 平台账号只支持 OAuth 直连 OpenAI,无法实现 `sub2api(API Key) → sub2api(OAuth) → OpenAI` 的两层桥接架构,限制了多站点分发和权限隔离场景。
-5. **账户管理缺少 Sora API Key 类型** — 当前"账户管理 > 创建账号"中 Sora 平台被硬编码为仅支持 OAuth(`CreateAccountModal.vue` 第 2597-2601 行强制设置 `form.type = 'oauth'`)。其他平台如 Anthropic、OpenAI、Gemini 都支持 API Key 类型(可配置 `base_url` + `api_key` 指向自定义上游),Antigravity 还支持 upstream 类型。Sora 缺少这个选项,导致无法实现级联部署。
-
-## What Changes
-
-### 一、存储层改造(管理员配置 S3 供用户使用)
-
-- 在系统设置中新增独立的 Sora S3 存储配置(endpoint、bucket、region、access_key、secret_key 等)
-- 管理员配置并启用后,系统将生成的媒体直接通过 `aws-sdk-go-v2` 上传到 S3 兼容存储,返回 CDN/签名 URL 给用户
-- 不依赖现有数据管理的 gRPC 代理,S3 配置独立管理
-- 保留本地存储作为回退方案(S3 不可用时降级)
-
-### 二、两种调用路径的存储策略差异
-
-系统有两种调用入口,存储行为完全不同:
-
-**路径 A:API Key 直接调用**(`/sora/v1/chat/completions`,开发者/程序调用)
-
-- **不存储媒体** — 生成完成后直接返回上游 URL,由调用方自行下载
-- **不记录生成历史** — 不写入 `sora_generations` 表(保持现有行为,纯透传网关)
-- **不检查存储配额** — 仅检查余额/计费(现有逻辑)
-- **理由**:API 用户是开发者,有能力自行处理媒体下载和存储,不应强制消耗系统 S3 空间
-
-**路径 B:Sora 客户端 UI 调用**(`/api/v1/sora/generate`,Web 界面用户)
-
-- **异步生成** — 立即返回 generation_id,后台异步完成生成(用户无需同步等待 5-20 分钟)
-- **自动存储到 S3** — 生成完成后自动上传到管理员配置的 S3,返回永久 URL
-- **记录生成历史** — 写入 `sora_generations` 表,用户可在作品库浏览
-- **检查存储配额** — 生成前检查,超限拒绝并引导用户释放空间
-- **支持取消** — 用户可取消进行中的生成任务
-- **理由**:Web 用户需要持久化作品、浏览历史、管理生成记录
-
-**无存储场景的降级**(仅影响路径 B):
-
-当管理员**未配置 S3 存储**时,路径 B 的行为:
-
-- **生成完成后**:直接将上游临时 URL 透传给客户端
-- **客户端提示**:显示醒目提示 —— "存储未配置,请立即下载,链接将在短时间内过期"
-- **自动下载触发**:生成完成后自动触发浏览器下载(或提供一键下载按钮)
-- **生成记录保留**:仍记录生成元数据(提示词、模型、状态),`storage_type='upstream'`,`media_url` 为上游临时 URL
-- **降级链**:`S3 存储(优先)→ 本地磁盘存储(回退)→ 上游临时 URL 透传(最终回退,即生即下载)`
-
-**对照总结**:
-
-```
- 路径 A: API Key 直接调用 路径 B: Sora 客户端 UI
- /sora/v1/chat/completions /api/v1/sora/generate
- ─────────────────────────────────────────────────────────────────────────
- 使用者 开发者 / 程序 Web 界面用户
- 存储到 S3 ❌ 不存储 ✅ 上传到 S3
- 生成记录 ❌ 不记录 ✅ 写入 sora_generations
- 配额检查 ❌ 不检查存储配额 ✅ 检查
- 返回内容 上游临时 URL(自行下载) S3 永久 URL(可浏览历史)
- 无存储时 正常返回临时 URL(无影响) 降级为即生即下载 + 提示
-```
-
-### 三、用户存储配额
-
-- 新增用户级别的 Sora 存储配额字段(默认值由管理员在系统设置中配置)
-- 管理员可为单个用户或分组设置不同的存储配额上限
-- 每次生成完成并上传后累计用户已用空间,超出配额则拒绝新的生成请求
-- 提供配额查询 API,用户可在客户端中查看已用/剩余空间
-
-### 四、账户管理新增 Sora API Key / 上游透传 账号类型
-
-**现状分析**:
-
-当前 Sora 平台在账号创建界面中被硬编码为仅支持 OAuth 类型:
-
-```typescript
-// CreateAccountModal.vue 第 2597-2601 行
-if (newPlatform === 'sora') {
- accountCategory.value = 'oauth-based'
- addMethod.value = 'oauth'
- form.type = 'oauth'
-}
-```
-
-而其他平台已支持多种账号类型:
-
-| 平台 | OAuth | Setup Token | API Key | Upstream |
-|------|-------|-------------|---------|----------|
-| Anthropic | ✅ | ✅ | ✅ (`base_url` + `api_key`) | — |
-| OpenAI | ✅ | — | ✅ (`base_url` + `api_key`) | — |
-| Gemini | ✅ (3种OAuth方式) | — | ✅ (`base_url` + `api_key`) | — |
-| Antigravity | ✅ | — | — | ✅ (实际存为`apikey`类型, `base_url` + `api_key`) |
-| **Sora** | **✅ (唯一)** | **—** | **❌ 缺失** | **❌ 缺失** |
-
-**设计方案**:
-
-为 Sora 平台新增 "API Key / 上游透传" 账号类型选项,复用 Antigravity 的 upstream 模式(实际存储为 `apikey` 类型):
-
-**前端变更**(`CreateAccountModal.vue`):
-
-- 取消 Sora 平台的硬编码 OAuth 限制
-- Sora 平台展示两个账号类别选项卡:
- - **OAuth 认证**(现有功能,连接 OpenAI Sora 官方)
- - **API Key / 上游透传**(新增,连接另一个 sub2api 或兼容 API)
-- 选择"API Key / 上游透传"时,显示以下表单字段:
- - `Base URL`(必填)— 上游 Sora 服务地址,默认占位符 `https://your-upstream-sub2api.com`
- - `API Key`(必填)— 上游服务的 API Key,占位符 `sk-...`
-- 表单提交时,`form.type = 'apikey'`,credentials 包含 `{ base_url, api_key }`
-
-**前端 UI 交互设计**:
-
-```
-┌─ 创建 Sora 账号 ─────────────────────────────────┐
-│ │
-│ 平台: [Sora ▼] │
-│ │
-│ 账号类型: │
-│ ┌──────────────┐ ┌──────────────────────┐ │
-│ │ ● OAuth 认证 │ │ ○ API Key / 上游透传 │ │
-│ └──────────────┘ └──────────────────────┘ │
-│ │
-│ ── 选择 OAuth 认证时(现有流程)── │
-│ [发起 OpenAI OAuth 授权] │
-│ │
-│ ── 选择 API Key / 上游透传时(新增)── │
-│ Base URL: [https://upstream.example.com ] │
-│ API Key: [sk-•••••••••••••••• ] │
-│ │
-│ 提示: 适用于连接另一个 sub2api 实例或兼容的 │
-│ Sora API 服务。请求将以 API Key 认证 │
-│ 透传到上游的 /sora/v1/chat/completions │
-│ │
-│ [测试连接] [创建账号] │
-└────────────────────────────────────────────────────┘
-```
-
-**后端变更**:
-
-- 后端验证逻辑无需修改(已允许 `oneof=oauth setup-token apikey upstream`)
-- `SoraGatewayService.Forward()` 新增分支:当 `account.Type == "apikey"` 且 `account.GetBaseURL() != ""` 时,不走 `SoraSDKClient`,而是将请求 HTTP 透传到规范化后的上游地址(`{base_url}/sora/v1/chat/completions`),Header 中携带 `Authorization: Bearer `
-- Sora apikey 账号创建/编辑时强制校验 `base_url`(必填,需包含 `http://` 或 `https://` scheme)
-- 复用现有 `Account.GetBaseURL()` 方法(已支持 `apikey` 类型返回 `base_url`)
-- 响应直接透传回客户端(流式/非流式均兼容)
-
-**编辑账号**(`EditAccountModal.vue`):
-
-- 现有 Sora OAuth 账号的编辑功能不变
-- 新增 API Key 类型 Sora 账号的编辑支持(可修改 `base_url` 和 `api_key`)
-- 账号测试(`AccountTestModal.vue`):API Key 类型发送一个轻量级请求到上游验证连通性
-
-### 五、sub2api 二级桥接(基于第四节的 API Key 账号类型实现)
-
-**场景**:分站 A(面向终端用户)需要通过总站 B(拥有 Sora OAuth 账号)来访问 OpenAI Sora。
-
-```
-终端用户 → sub2api-A(分站) → sub2api-B(总站) → OpenAI Sora
- │ │
- │ Sora apikey 账号 │ Sora oauth 账号
- │ base_url=总站地址 │ access_token=OpenAI
- │ api_key=总站Key │
- ▼ ▼
- HTTP 透传请求 SDK 直连 OpenAI
-```
-
-**实现方式**:
-
-级联部署不需要额外的代码能力,完全基于第四节新增的 "API Key / 上游透传" 账号类型:
-
-1. **总站 B**:创建 Sora OAuth 账号(现有功能),正常连接 OpenAI
-2. **总站 B**:创建 API Key(`/sora` 端点的 Key),供分站使用
-3. **分站 A**:创建 Sora API Key 账号(第四节新增),填写:
- - `base_url` = 总站 B 的地址(如 `https://main-site.example.com`)
- - `api_key` = 总站 B 发放的 API Key
-4. **分站 A** 的 Sora 网关检测到 `apikey` 类型账号,将请求透传到总站 B 的 `/sora/v1/chat/completions`
-5. **总站 B** 收到请求后用 OAuth 账号走 SDK 连接 OpenAI,返回结果
-6. **分站 A** 收到响应后,由自己决定是否存储到 S3(存储层完全独立)
-
-**为什么需要级联**:
-
-- **权限隔离**:OAuth 账号是敏感资产,集中管理在总站,分站只需 API Key
-- **多站点分发**:一个总站可服务多个分站,每个分站独立管理用户和计费
-- **运维简化**:OAuth Token 刷新、Cloudflare 防护等复杂逻辑只需总站处理
-
-**注意**:级联不是独立的能力,而是 `sora-upstream-bridge` 能力的一个使用场景。
-
-### 六、Sora 客户端界面(参考官方客户端)
-
-- 在前端新增 Sora 客户端页面(`/sora`),面向普通用户
-- **嵌入全局布局**:Sora 页面保留在全局侧边栏布局内渲染(不独立接管全屏),页面内仅保留 Tab 切换 + 配额进度条,去掉独立 Logo 和头像(由侧边栏提供)
-- **条件显示**:侧边栏 Sora 菜单项仅在管理员配置了活跃 Sora 账号时显示(`sora_client_enabled`)
-- **生成页面**:输入提示词、选择模型(视频/图片/分辨率/时长)、上传参考图、发起生成
-- **作品库页面**:网格展示历史生成作品(缩略图/视频预览),支持下载、删除
-- **生成进度**:实时显示当前生成任务的进度状态(排队中/生成中/完成/失败)
-- **配额展示**:展示当前用户的存储用量和剩余配额
-- **无存储提醒**:未配置存储时,显示即时下载提示和自动下载功能
-
-### 七、后端 Sora 生成记录
-
-- 新增 `sora_generations` 表,记录每次生成的元数据(用户、模型、提示词、媒体 URL、文件大小、状态、存储方式等)
-- 提供生成记录的 CRUD API(列表、详情、删除)
-- 删除作品时同步清理 S3 中的文件并释放配额
-- 无存储模式下仍记录生成历史(但媒体 URL 会标记过期)
-
-### 八、管理员配置
-
-- 在系统设置中新增独立的 Sora S3 存储配置(endpoint、bucket、region、access_key_id、secret_access_key、prefix、force_path_style 等),使用 `aws-sdk-go-v2` 直连
-- 在系统设置中新增 Sora 默认存储配额配置
-- 在用户管理 / 分组管理中可覆盖单个用户 / 分组的配额
-
-## Capabilities
-
-### New Capabilities
-
-- `sora-s3-media-storage`: 将 Sora 生成的媒体文件上传到管理员在系统设置中配置的 S3 兼容存储(S3/R2/OSS/MinIO),使用 `aws-sdk-go-v2` 直连,替代本地磁盘存储,支持 CDN/签名 URL 分发。无存储时降级为即生即下载模式。
-
-- `sora-user-storage-quota`: 用户级别的 Sora 存储配额管理,包括配额设置(系统默认值 + 用户/分组覆盖)、用量追踪(每次上传累计)、超限拒绝、配额查询 API。
-
-- `sora-generation-history`: Sora 生成记录的持久化存储与管理,包括 `sora_generations` 表、CRUD API、作品删除与 S3 文件清理联动、无存储模式下的历史保留与过期标记。
-
-- `sora-client-ui`: 面向用户的 Sora 客户端前端界面,参考 Sora 官方客户端设计,包括生成页面(提示词输入、模型选择、参考图上传、多任务并发展示)、作品库(网格展示、下载、删除)、异步生成+前端轮询、自动保存到 S3、取消生成、生成完成浏览器通知、无存储时 15 分钟倒计时提醒、页面刷新后任务恢复。
-
-- `sora-account-apikey`: 为 Sora 平台新增 "API Key / 上游透传" 账号类型。前端:取消 Sora 平台的 OAuth 硬编码限制,新增 API Key 选项卡(`base_url` + `api_key` 表单);后端:`Forward()` 中检测 `apikey` 类型账号时走 HTTP 透传而非 SDK 直连,请求发到 `base_url/sora/v1/chat/completions`。该能力同时实现了 sub2api 二级桥接(分站 API Key → 总站 OAuth → OpenAI)。
-
-### Modified Capabilities
-
-- `sora-generation-gateway`: 现有 Sora 网关转发逻辑保持不变 — `/sora/v1/chat/completions` 继续作为纯透传网关,不存储媒体、不记录历史、不检查存储配额(仅保留现有的计费和并发控制)。存储/历史/配额逻辑全部由新的 `sora-client-ui`(`/api/v1/sora/generate`)在上层处理。注:apikey 账号的 HTTP 透传由 `sora-account-apikey` 能力负责。
-
-- `sora-s3-settings`: 系统设置新增独立的 Sora S3 存储配置区域(不依赖数据管理的 gRPC S3 Profile),包含完整的 S3 连接参数和测试连接功能。
-
-## Impact
-
-### 数据库变更
-
-- 新增 `sora_generations` 表(生成记录:用户ID、模型、提示词、媒体URL、文件大小、存储方式、`s3_object_keys` JSONB 数组、状态、创建时间)
-- 在 `users` 表新增 `sora_storage_quota_bytes`、`sora_storage_used_bytes` 字段(配额上限、已用空间)
-- 系统设置新增 Sora S3 配置键值(`sora_s3_endpoint`、`sora_s3_bucket`、`sora_s3_region` 等)
-- 系统设置新增 `sora_default_storage_quota_bytes` 键值
-- 公共设置 API 新增 `sora_client_enabled` 字段(后端根据活跃 Sora 账号数推断,供前端条件显示 Sora 菜单项)
-- 分组表新增 `sora_storage_quota_bytes` 字段(可选覆盖)
-
-### 后端代码变更
-
-- `service/sora_gateway_service.go` — `Forward()` 方法新增 apikey 账号 HTTP 透传分支,并将 API Key 直调路径保持为“不落盘、不记录”的纯透传语义
-- `service/sora_s3_storage.go` — 新增 S3 上传能力(使用 `aws-sdk-go-v2` 直连,读取系统设置中的 S3 配置)
-- 新增 `service/sora_generation_service.go` — 生成记录 CRUD + S3 文件清理
-- 新增 `service/sora_quota_service.go` — 配额管理
-- 新增 `service/sora_upstream_forwarder.go` — apikey 类型 Sora 账号的 HTTP 透传逻辑(请求转发 + 流式响应代理)
-- 新增 `handler/sora_client_handler.go` — 用户端 Sora API(异步生成、历史、配额、取消、手动保存、存储状态查询)
-- `handler/admin/setting_handler.go` — 系统设置新增 Sora S3 配置接口
-- `server/routes/` — 新增用户端 Sora 路由
-
-### 前端代码变更
-
-- `components/account/CreateAccountModal.vue` — **核心变更**:取消 Sora 平台 OAuth 硬编码限制,新增"API Key / 上游透传"选项卡和表单(`base_url` + `api_key`)
-- `components/account/EditAccountModal.vue` — 支持编辑 Sora apikey 类型账号的 `base_url` 和 `api_key`
-- `components/account/credentialsBuilder.ts` — 新增 Sora apikey 类型的 credentials 构建逻辑
-- 新增 `views/user/SoraView.vue` — Sora 客户端主页面
-- 新增 `components/sora/` — 生成表单、作品库、进度条、配额展示、即时下载等组件
-- `api/` — 新增 Sora 客户端 API 调用
-- `router/index.ts` — 新增 `/sora` 路由
-- `components/layout/AppSidebar.vue` — 新增 Sora 菜单项
-- `i18n/` — 新增国际化文本(含 Sora API Key 账号相关翻译)
-- 管理端"系统设置"页面 — 新增 Sora S3 存储配置区域
-
-### API 变更
-
-- 新增 Sora 客户端 API(路径 B,供 Web UI 使用):
- - `POST /api/v1/sora/generate` — 发起生成(异步:立即返回 generation_id,后台完成生成 + 自动 S3 上传)
- - `GET /api/v1/sora/generations` — 生成历史列表(支持按状态筛选,用于恢复进行中任务)
- - `GET /api/v1/sora/generations/:id` — 生成详情(前端轮询获取状态更新)
- - `POST /api/v1/sora/generations/:id/save` — 手动保存到存储(仅针对 storage_type='upstream' 的记录,S3 后续启用时使用)
- - `POST /api/v1/sora/generations/:id/cancel` — 取消进行中的生成任务
- - `DELETE /api/v1/sora/generations/:id` — 删除作品(联动 S3 清理 + 释放配额)
- - `GET /api/v1/sora/quota` — 查询配额用量
- - `GET /api/v1/sora/models` — 可用模型列表
- - `GET /api/v1/sora/storage-status` — 存储状态查询(S3 是否可用,供前端决定 UI 展示)
-- 现有 API(路径 A,不变):
- - `/sora/v1/chat/completions` — 保持纯透传,不存储、不记录,直接返回 URL 给 API 调用方自行下载
-- 扩展管理端 API:系统设置新增 Sora S3 配置读写接口
-
-### 依赖
-
-- 需要 S3 兼容的 Go SDK(`github.com/aws/aws-sdk-go-v2`),直连 S3 存储
-- 现有 `go-sora2api v1.1.0` SDK 无需变更
-
-## 现有架构参考(代码分析摘要)
-
-### 当前 Sora 网关转发流程
-
-```
-POST /sora/v1/chat/completions
- │
- ├─ 并发控制(用户级 + 账号级)
- ├─ 账号选择 + 失败转移(最多切 3 个账号)
- │
- ▼
-SoraGatewayService.Forward()
- ├─ 解析请求(模型、提示词、图片/视频输入)
- ├─ 预检查(PreflightCheck 验证额度)
- ├─ 创建任务(CreateImage/Video/StoryboardTask)
- ├─ 轮询等待完成(poll 2s 间隔,最多 600 次 = 20 分钟)
- ├─ [可选] 去水印处理
- ├─ 下载到本地存储 / 回退上游 URL
- └─ 返回 Chat Completions 格式响应
-```
-
-### Sora S3 存储方案
-
-- 独立于现有数据管理的 gRPC S3 Profile 体系
-- 在系统设置(Settings 表)中新增 Sora S3 配置项
-- 后端使用 `aws-sdk-go-v2` 直接连接 S3 兼容存储
-- 前端在系统设置页面新增 Sora S3 配置区域(含测试连接按钮)
-
-### 现有账号类型(各平台支持情况)
-
-| Type | 说明 | Anthropic | OpenAI | Gemini | Antigravity | Sora (现状) | Sora (本次新增) |
-|------|------|-----------|--------|--------|-------------|-------------|----------------|
-| `oauth` | OAuth 认证 | ✅ | ✅ | ✅ (3种方式) | ✅ | ✅ 唯一支持 | ✅ 保留 |
-| `setup-token` | Setup Token | ✅ | — | — | — | ❌ | — |
-| `apikey` | API Key + base_url | ✅ | ✅ | ✅ | ✅ (作为upstream) | ❌ | **✅ 新增** |
-| `upstream` | 上游透传 | — | — | — | — | ❌ | — |
-
-**关键限制代码**(本次需修改):
-- `CreateAccountModal.vue:2597-2601` — Sora 平台强制 `form.type = 'oauth'`
-- `Account.GetBaseURL()` — 仅 `apikey` 类型返回 `base_url`(无需修改,已兼容)
-- `handler/admin/account_handler.go:94` — 验证已支持 `apikey`(无需修改)
-
-### 关键文件路径
-
-| 模块 | 文件 |
-|------|------|
-| **Sora 后端** | |
-| Sora 网关服务 | `service/sora_gateway_service.go` (1484 行) |
-| Sora SDK 客户端 | `service/sora_sdk_client.go` |
-| Sora 媒体存储 | `service/sora_media_storage.go` |
-| Sora 模型配置 | `service/sora_models.go` |
-| Sora HTTP Handler | `handler/sora_gateway_handler.go` |
-| Sora 路由注册 | `server/routes/gateway.go` (第 104-124 行) |
-| **账号管理后端** | |
-| 账号类型常量 | `domain/constants.go` |
-| 账号 Service | `service/account.go` (`GetBaseURL()` 第 521 行) |
-| 账号 Handler | `handler/admin/account_handler.go` (验证规则第 94 行) |
-| **账号管理前端** | |
-| 创建账号对话框 | `components/account/CreateAccountModal.vue` (2100+ 行, **Sora 硬编码在第 2597 行**) |
-| 编辑账号对话框 | `components/account/EditAccountModal.vue` |
-| Credentials 构建 | `components/account/credentialsBuilder.ts` |
-| **系统设置** | |
-| 系统设置 Handler | `handler/admin/setting_handler.go` |
-| 前端系统设置页 | `views/admin/SettingsView.vue`(新增 Sora S3 配置区域) |
-| **其他前端** | |
-| 前端路由 | `router/index.ts` |
-| 前端侧边栏 | `components/layout/AppSidebar.vue` |
-| 前端类型定义 | `types/index.ts` (`AccountPlatform`, `AccountType` 第 485 行) |
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/review-rounds.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/review-rounds.md
deleted file mode 100644
index b115da4d9..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/review-rounds.md
+++ /dev/null
@@ -1,43 +0,0 @@
-## sora-client-s3-storage 三轮审核记录
-
-### 第 1 轮:一致性审核(proposal/design/tasks/specs)
-
-发现问题:
-- 关键术语不一致:`四个核心问题` 实际列了 5 条。
-- 能力名不一致:`sora-gateway` 与现有规格能力 `sora-generation-gateway` 不一致。
-- 配额键名不一致:`sora_default_quota_bytes` 与 `sora_default_storage_quota_bytes` 混用。
-- 数据模型表述歧义:`user_sora_quotas 表或 users 字段` 两种方案并存。
-- 路径/文件名不一致:`handler/admin/settings_handler.go` 与仓库实际 `handler/admin/setting_handler.go` 不一致。
-
-修复动作:
-- 统一 proposal 术语、能力名、配置键名、字段命名与文件路径。
-- 删除数据库方案歧义,明确采用 `users` 与 `groups` 字段方案。
-
-### 第 2 轮:可实施性审核(结构与迁移)
-
-发现问题:
-- `sora_generations` 使用 `(user_id, created_at)` 联合唯一约束,存在高并发写入冲突风险。
-- 分组配额字段命名未带单位后缀(`sora_storage_quota`),与 `*_bytes` 体系不一致。
-- 任务清单缺少“路径 A 不落盘”的明确改造任务,无法保证 `/sora/v1/chat/completions` 纯透传目标。
-
-修复动作:
-- 将联合唯一约束改为普通索引 `(user_id, created_at DESC)`。
-- 统一分组字段为 `groups.sora_storage_quota_bytes`。
-- 在 `tasks.md` 增加路径 A 不落盘任务与对应验证项。
-
-### 第 3 轮:鲁棒性审核(边界与运维)
-
-发现问题:
-- apikey 透传 URL 拼接写法为字符串拼接,存在双斜杠与非法 base_url 风险。
-- S3 访问 URL 策略未收敛(CDN 与预签名都提到,但缺少决策规则)。
-- 配额并发控制表述中默认容忍超额,与严格配额目标冲突。
-
-修复动作:
-- 在 proposal/design/spec 中补充 `base_url` 校验(必填 + scheme)与规范化拼接要求。
-- 在 `sora-s3-media-storage/spec.md` 明确“CDN 优先,预签名兜底”策略。
-- 将配额并发策略改为“原子更新失败即回滚文件并报错”,取消容忍超额表述。
-
-### 结论
-
-- 已完成 3 轮审核并修复全部已发现问题。
-- 当前提案在一致性、可实施性和鲁棒性三个维度均已收敛,且已同步更新 `proposal.md`、`design.md`、`tasks.md` 与相关 specs。
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/sora-client-mockup.html b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/sora-client-mockup.html
deleted file mode 100644
index 95f654313..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/sora-client-mockup.html
+++ /dev/null
@@ -1,3088 +0,0 @@
-
-
-
-
-
-Sora 客户端 — 嵌入全局侧边栏布局
-
-
-
-
-
-
-
-
-
-
-
-
-
- ⚠️
- 管理员未开通云存储,生成完成后请使用"本地下载"保存文件,否则将会丢失。
-
-
-
-
-
-
-
-
-
-
-
-
将你的想象力变成视频
-
输入一段描述,Sora 将为你创作逼真的视频或图片。尝试以下示例开始创作。
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
一只金色的柴犬在东京涩谷街头散步,镜头跟随,电影感画面,4K 高清
-
-
-
- 已等待 3:42 · 预计剩余 8 分钟
-
-
-
-
-
-
-
-
-
赛博朋克风格的未来城市,霓虹灯倒映在雨后积水中,夜景,电影级色彩
-
-
-
-
-
-
-
水墨画风格,一叶扁舟在山水间漂泊,薄雾缭绕,中国古典意境
-
-
-
- ✓ 已保存到云端
-
-
-
-
- ⏱ 剩余 11:28 可下载
-
-
-
-
-
-
-
-
-
无人机航拍视角,冰岛极光下的冰川湖面反射绿色光芒,慢速推进
-
-
- ✓ 已保存到云端
-
-
-
-
-
-
-
-
赛博朋克风格的机械战士在城市废墟中战斗,火焰和爆炸特效
-
⛔ 内容策略限制
-
提示词包含暴力相关内容,触发了安全策略。请修改提示词中的相关描述后重试。
-
-
-
-
-
-
-
-
-
-
-
-
日式庭园,秋天的红叶铺满石板路,细雨飘落,禅意氛围
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
共 8 个作品
-
-
-
-
-
-
-
-
-
-
VIDEO
-
-
-
-
-
▶
-
0:10
-
-
-
sora2-landscape-10s
-
2 分钟前
-
-
-
-
-
-
-
-
-
-
-
VIDEO
-
-
-
-
-
▶
-
0:15
-
-
-
sora2-portrait-15s
-
1 小时前
-
-
-
-
-
-
-
-
VIDEO
-
-
-
-
-
▶
-
0:25
-
-
-
sora2-landscape-25s
-
2 小时前
-
-
-
-
-
-
-
-
-
-
-
VIDEO
-
-
-
-
-
▶
-
0:10
-
-
-
sora2-landscape-10s
-
昨天
-
-
-
-
-
-
-
-
-
-
-
VIDEO
-
-
-
-
-
▶
-
0:15
-
-
-
sora2-portrait-15s
-
2 天前
-
-
-
-
-
-
-
-
🎬
-
还没有任何作品
-
你的创作成果将会展示在这里。前往生成页,开始你的第一次创作吧。
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- ▼
-
-
⚠ 存储未配置
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
📥
-
立即下载
-
当前未配置远程存储,文件仅临时保存。请立即下载以免丢失。
-
-
-
-
-
-
-
-
-
-
-
☁️
-
保存到存储
-
-
-
-
-
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-account-apikey/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-account-apikey/spec.md
deleted file mode 100644
index e322009b8..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-account-apikey/spec.md
+++ /dev/null
@@ -1,82 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Sora 平台支持 API Key 账号类型
-系统 SHALL 为 Sora 平台新增 "API Key / 上游透传" 账号类型,取消现有 OAuth 硬编码限制。
-
-#### Scenario: 前端创建 Sora API Key 账号
-- **WHEN** 管理员在账号创建对话框中选择 Sora 平台
-- **THEN** 系统 SHALL 显示两个账号类别选项卡:"OAuth 认证"和"API Key / 上游透传"
-- **AND** 选择"API Key / 上游透传"时 SHALL 显示 `Base URL`(必填)和 `API Key`(必填)表单字段
-- **AND** 提交时 `form.type` SHALL 设置为 `'apikey'`
-
-#### Scenario: Base URL 字段校验
-- **WHEN** 管理员创建或编辑 `platform=sora, type=apikey` 账号
-- **THEN** `base_url` SHALL 为必填
-- **AND** `base_url` SHALL 以 `http://` 或 `https://` 开头
-- **AND** 不满足校验时 SHALL 拒绝保存并提示明确错误
-
-#### Scenario: 取消 Sora OAuth 硬编码
-- **WHEN** 用户选择 Sora 平台
-- **THEN** 系统 SHALL 不再强制设置 `form.type = 'oauth'`
-- **AND** SHALL 允许用户选择 OAuth 或 API Key 类型
-
-### Requirement: Sora API Key 账号编辑
-系统 SHALL 支持编辑 Sora API Key 类型账号的 `base_url` 和 `api_key`。
-
-#### Scenario: 编辑 Sora API Key 账号
-- **WHEN** 管理员编辑一个 `platform=sora, type=apikey` 的账号
-- **THEN** 编辑界面 SHALL 显示 `Base URL` 和 `API Key` 可编辑字段
-- **AND** 保存时 SHALL 更新 `credentials` 中的 `base_url` 和 `api_key`
-
-### Requirement: Sora API Key 账号连通性测试
-系统 SHALL 支持 Sora API Key 账号的连通性测试。
-
-#### Scenario: 测试连通性成功
-- **WHEN** 管理员点击"测试连接"
-- **AND** 上游 `base_url` 可达且 `api_key` 有效
-- **THEN** 系统 SHALL 发送轻量级请求到上游验证连通性
-- **AND** 返回测试成功结果
-
-#### Scenario: 测试连通性失败
-- **WHEN** 上游不可达或认证失败
-- **THEN** 系统 SHALL 返回明确的错误信息(如"连接超时"、"认证失败")
-
-### Requirement: Sora apikey 账号 HTTP 透传
-系统 SHALL 对 `type=apikey` 的 Sora 账号执行 HTTP 透传,而非 SDK 直连。
-
-#### Scenario: apikey 账号走 HTTP 透传
-- **WHEN** `SoraGatewayService.Forward()` 检测到 `account.Type == "apikey"` 且 `account.GetBaseURL() != ""`
-- **THEN** 系统 SHALL 调用 `forwardToUpstream()` 方法
-- **AND** SHALL 不使用 `SoraSDKClient` 直连
-
-#### Scenario: HTTP 透传请求构造
-- **WHEN** 系统执行 `forwardToUpstream()`
-- **THEN** 系统 SHALL 构造 HTTP POST 请求到规范化拼接的 `{base_url}/sora/v1/chat/completions`
-- **AND** Header SHALL 包含 `Authorization: Bearer ` 和 `Content-Type: application/json`
-- **AND** 请求体 SHALL 原样透传客户端请求体
-
-#### Scenario: 流式响应透传
-- **WHEN** 上游返回流式 SSE 响应
-- **THEN** 系统 SHALL 逐字节透传 SSE 流到客户端
-- **AND** SHALL 不缓存完整响应
-
-#### Scenario: 非流式响应透传
-- **WHEN** 上游返回非流式 JSON 响应
-- **THEN** 系统 SHALL 读取完整响应后原样返回客户端
-
-#### Scenario: 上游错误触发失败转移
-- **WHEN** 上游返回 401/403/429/5xx 错误
-- **THEN** 系统 SHALL 复用现有的 `UpstreamFailoverError` 机制触发账号切换
-
-### Requirement: sub2api 二级桥接
-系统 SHALL 通过 API Key 账号类型天然支持 sub2api 级联部署。
-
-#### Scenario: 分站通过 API Key 连接总站
-- **WHEN** 分站创建 Sora API Key 账号,`base_url` 指向总站地址
-- **THEN** 分站的 Sora 请求 SHALL 通过 HTTP 透传到总站的 `/sora/v1/chat/completions`
-- **AND** 总站 SHALL 使用自己的 OAuth 账号连接 OpenAI
-
-#### Scenario: 级联中的存储独立性
-- **WHEN** 分站收到总站返回的生成结果
-- **THEN** 分站 SHALL 根据自己的 S3 配置决定是否存储
-- **AND** 存储行为与总站无关(完全独立)
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-client-ui/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-client-ui/spec.md
deleted file mode 100644
index bd3466f55..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-client-ui/spec.md
+++ /dev/null
@@ -1,305 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Sora 客户端路由与菜单
-系统 SHALL 在前端新增 Sora 客户端页面,可通过侧边栏菜单访问。菜单项的显示须与现有侧边栏风格一致,并遵循条件显示、简单模式、双菜单同步等现有模式。
-
-#### Scenario: 路由注册
-- **WHEN** 前端路由初始化
-- **THEN** 系统 SHALL 注册 `/sora` 路由,加载 `SoraView.vue` 页面
-- **AND** 路由 meta SHALL 设置 `requiresAuth: true, requiresAdmin: false`
-
-#### Scenario: 侧边栏菜单项(条件显示)
-- **WHEN** 用户登录后查看侧边栏
-- **AND** 公共设置 `sora_client_enabled` 为 true(后端根据是否存在活跃 Sora 账号自动推断)
-- **THEN** 侧边栏 SHALL 显示"Sora"菜单项
-- **AND** 菜单项 SHALL 使用 Heroicons 线性风格的 Sparkles 图标(与现有侧边栏图标统一为 stroke 风格,`h-5 w-5`)
-- **AND** 点击后 SHALL 跳转到 `/sora` 页面
-
-#### Scenario: 菜单项在管理员未启用 Sora 时隐藏
-- **WHEN** 公共设置 `sora_client_enabled` 为 false(无活跃 Sora 账号)
-- **THEN** 侧边栏 SHALL 不显示"Sora"菜单项
-- **AND** 用户直接访问 `/sora` 时 SHALL 显示功能未启用提示页
-
-#### Scenario: 菜单位置与双菜单同步
-- **WHEN** Sora 菜单项显示
-- **THEN** 对于普通用户(`userNavItems`),Sora SHALL 位于"Dashboard"之后、"API 密钥"之前
-- **AND** 对于管理员"我的账户"区域(`personalNavItems`),Sora SHALL 位于"API 密钥"之后、"使用记录"之前
-- **AND** 两个菜单列表 SHALL 同步添加(确保管理员和普通用户均可访问)
-
-#### Scenario: 简单模式隐藏
-- **WHEN** 系统处于简单模式(`isSimpleMode = true`)
-- **THEN** Sora 菜单项 SHALL 隐藏(`hideInSimpleMode: true`)
-
-### Requirement: Sora 客户端页面内导航
-系统 SHALL 在 Sora 客户端页面顶部显示页面内导航栏,仅包含 Tab 切换和配额信息。Sora 页面嵌入在全局侧边栏布局内,不独立展示 Logo 或用户头像(这些已由全局侧边栏提供)。
-
-#### Scenario: 页面内导航栏显示
-- **WHEN** 用户进入 Sora 客户端页面
-- **THEN** 页面顶部 SHALL 显示页面内导航栏,包含"生成"/"作品库" Tab 切换
-- **AND** 右侧 SHALL 显示配额进度条(如 "2.1GB / 5GB")
-- **AND** 导航栏 SHALL 不包含 Logo 和用户头像(避免与全局侧边栏重复)
-- **AND** Sora 页面 SHALL 保留在全局侧边栏布局内渲染(用户可通过侧边栏随时切换到其他页面)
-
-#### Scenario: Tab 切换
-- **WHEN** 用户点击"生成"或"作品库" Tab
-- **THEN** 页面 SHALL 切换到对应视图,不刷新页面
-
-### Requirement: 生成页面 - 底部创作栏
-系统 SHALL 在生成页底部固定显示创作栏,用于输入提示词和配置生成参数。
-
-#### Scenario: 提示词输入
-- **WHEN** 用户在创作栏输入提示词
-- **THEN** 输入框 SHALL 支持多行文本,自动扩展高度
-- **AND** 支持 Ctrl/Cmd + Enter 快捷键触发生成
-
-#### Scenario: 模型选择
-- **WHEN** 用户点击模型选择器
-- **THEN** 系统 SHALL 从 `GET /api/v1/sora/models` 获取可用模型列表
-- **AND** 下拉菜单 SHALL 按视频模型和图片模型分组显示
-
-#### Scenario: 视频参数配置
-- **WHEN** 用户选择视频模型
-- **THEN** 创作栏 SHALL 显示方向选择(横屏/竖屏/方形)和时长选择(10s/15s/25s)
-
-#### Scenario: 图片模型隐藏视频参数
-- **WHEN** 用户选择图片模型(如 gpt-image-1)
-- **THEN** 创作栏 SHALL 隐藏方向选择和时长选择
-
-#### Scenario: 参考图上传
-- **WHEN** 用户点击图片上传按钮
-- **THEN** 系统 SHALL 允许上传参考图片,作为生成输入的 `image_url`
-
-### Requirement: 生成页面 - 发起生成
-系统 SHALL 通过底部创作栏的"生成"按钮发起 Sora 生成请求。
-
-#### Scenario: 发起视频生成
-- **WHEN** 用户填写提示词并点击"生成"按钮
-- **AND** 当前选择视频模型
-- **THEN** 系统 SHALL 发送 `POST /api/v1/sora/generate`,包含 `prompt`、`model`、`media_type=video`、方向、时长参数
-- **AND** 页面 SHALL 显示新的进度卡片(pending 状态)
-
-#### Scenario: 发起图片生成
-- **WHEN** 用户填写提示词并选择图片模型后点击"生成"
-- **THEN** 系统 SHALL 发送生成请求,`media_type=image`
-- **AND** 页面 SHALL 显示新的进度卡片
-
-#### Scenario: 配额不足预防与提示
-- **WHEN** 用户配额使用率超过 90%
-- **THEN** 配额进度条 SHALL 变为黄色警告色,提示"存储空间即将用完"
-- **AND** 配额使用率达 100% 时,"生成"按钮 SHALL 禁用并显示 tooltip "存储配额已满"
-
-#### Scenario: 配额不足错误引导
-- **WHEN** 生成请求返回 HTTP 429(配额不足)
-- **THEN** 页面 SHALL 弹出配额不足对话框,包含:
- - 当前配额使用详情(已用 / 总配额)
- - 引导文案"您可以在作品库中删除不需要的作品来释放存储空间"
- - "前往作品库"按钮(直接跳转到作品库页面)
-
-### Requirement: 生成页面 - 进度展示
-系统 SHALL 在生成页中间区域实时展示当前生成任务的进度状态。
-
-#### Scenario: 排队中状态
-- **WHEN** 生成记录 `status = 'pending'`
-- **THEN** 进度卡片 SHALL 显示"排队中"状态、灰色状态指示、提示词摘要(前 50 字)
-- **AND** SHALL 显示"取消"按钮
-
-#### Scenario: 生成中状态
-- **WHEN** 生成记录 `status = 'generating'`
-- **THEN** 进度卡片 SHALL 显示"生成中"动画、提示词预览
-- **AND** SHALL 显示已等待时长(如"已等待 3:42")和预估剩余时间(如"预计剩余 8 分钟")
-- **AND** SHALL 显示"取消"按钮
-- **AND** 超过 20 分钟未完成时 SHALL 显示"生成时间异常,建议取消重试"
-
-#### Scenario: 生成完成 - 自动保存成功
-- **WHEN** 生成记录 `status = 'completed'` 且 `storage_type = 's3'`
-- **THEN** 进度卡片 SHALL 显示生成结果预览(视频播放器或图片缩略图)
-- **AND** SHALL 显示 "✓ 已保存到云端" 状态标识
-- **AND** SHALL 提供"📥 本地下载"按钮
-- **AND** 作品自动出现在作品库中
-
-#### Scenario: 生成完成 - 降级本地存储
-- **WHEN** 生成记录 `status = 'completed'` 且 `storage_type = 'local'`
-- **THEN** 进度卡片 SHALL 显示 "✓ 已保存到本地" 状态标识
-- **AND** SHALL 提供"📥 本地下载"按钮
-
-#### Scenario: 生成完成 - 无存储(upstream)
-- **WHEN** 生成记录 `status = 'completed'` 且 `storage_type = 'upstream'`
-- **THEN** 进度卡片 SHALL 显示"📥 本地下载"按钮
-- **AND** SHALL 显示 15 分钟过期倒计时进度条(基于 `completed_at` 计算)
-- **AND** 若 S3 当前可用,SHALL 显示可点击的"☁️ 保存到存储"按钮
-- **AND** 若 S3 不可用,"☁️ 保存到存储"按钮 SHALL 禁用并 tooltip "管理员未开通云存储"
-- **AND** 倒计时结束后 SHALL 禁用所有按钮并显示"链接已过期"
-
-#### Scenario: 生成失败状态
-- **WHEN** 生成记录 `status = 'failed'`
-- **THEN** 进度卡片 SHALL 显示分类错误信息:
- - 上游服务错误 → "服务暂时不可用,建议稍后重试"
- - 内容审核不通过 → "提示词包含不支持的内容,请修改后重试"
- - 超时 → "生成超时,建议降低分辨率或时长后重试"
-- **AND** SHALL 提供"重试"按钮(一键以相同参数重新发起)
-- **AND** SHALL 提供"编辑后重试"按钮(将参数回填到创作栏)
-- **AND** SHALL 提供"删除"按钮
-
-#### Scenario: 任务取消状态
-- **WHEN** 生成记录 `status = 'cancelled'`
-- **THEN** 进度卡片 SHALL 显示"已取消"灰色状态
-- **AND** SHALL 提供"重新生成"和"删除"按钮
-
-### Requirement: 生成页面 - 多任务管理与状态恢复
-系统 SHALL 支持多个并发生成任务的展示和页面刷新后的状态恢复。
-
-#### Scenario: 多任务并发展示
-- **WHEN** 用户有多个进行中或刚完成的生成任务
-- **THEN** 生成页中间区域 SHALL 以时间线方式纵向排列所有任务卡片,最新在最上方
-- **AND** 底部创作栏 SHALL 显示当前活跃任务数(如"正在生成 2/3")
-- **AND** 超过并发上限(3 个)时,"生成"按钮 SHALL 禁用并提示"请等待当前任务完成"
-
-#### Scenario: 页面刷新后恢复任务
-- **WHEN** 用户刷新页面或重新进入 Sora 客户端
-- **THEN** 系统 SHALL 调用 `GET /api/v1/sora/generations?status=pending,generating` 获取进行中任务
-- **AND** SHALL 自动恢复所有进度卡片的显示
-- **AND** SHALL 继续对进行中任务执行轮询
-
-#### Scenario: 前端轮询策略
-- **WHEN** 存在 pending 或 generating 状态的任务
-- **THEN** 前端 SHALL 按递减频率轮询 `GET /api/v1/sora/generations/:id`:
- - 0-2 分钟:每 3 秒
- - 2-10 分钟:每 10 秒
- - 10-30 分钟:每 30 秒
-- **AND** 每次轮询结果 SHALL 更新卡片显示
-- **AND** 卡片上 SHALL 显示"最后更新:N 秒前"以确认数据实时性
-
-#### Scenario: 浏览器通知
-- **WHEN** 生成任务完成或失败
-- **AND** 浏览器标签页不在前台
-- **THEN** 系统 SHALL 通过 Notification API 发送桌面通知
-- **AND** 标签页 title SHALL 闪烁提示(如"(1) ✓ 生成完成 - Sora")
-
-### Requirement: 生成页面 - 无存储提醒
-系统 SHALL 在未配置存储时显示醒目提示。
-
-#### Scenario: 无存储警告
-- **WHEN** 用户进入生成页
-- **AND** S3 和本地存储均未配置
-- **THEN** 创作栏 SHALL 显示警告标签"存储未配置,生成后请立即下载"
-
-#### Scenario: S3 可用时自动保存(正常模式)
-- **WHEN** 管理员已开通 S3 存储
-- **AND** 用户存储配额未超限
-- **THEN** 生成完成后系统 SHALL 自动上传到 S3
-- **AND** 卡片 SHALL 显示"✓ 已保存到云端"
-
-#### Scenario: S3 不可用时的降级提示
-- **WHEN** 管理员未开通 S3 存储(`sora_s3_enabled = false`)
-- **THEN** 生成完成后卡片 SHALL 仅显示"📥 本地下载"按钮
-- **AND** "☁️ 保存到存储"按钮 SHALL 禁用并 tooltip "管理员未开通云存储"
-
-#### Scenario: 手动保存到存储(仅 upstream 记录)
-- **WHEN** 生成记录 `storage_type = 'upstream'` 且 S3 当前可用
-- **THEN** "☁️ 保存到存储"按钮 SHALL 可点击
-- **AND** 点击后 SHALL 调用 `POST /api/v1/sora/generations/:id/save`
-- **AND** 上传过程中按钮 SHALL 显示 loading 状态
-- **AND** 上传成功后按钮 SHALL 变为"✓ 已保存"
-- **AND** 上传失败 SHALL 显示错误信息并允许重试
-
-#### Scenario: 无存储生成完成自动提示下载
-- **WHEN** 生成完成且 `storage_type = 'upstream'`
-- **THEN** 客户端 SHALL 弹出提醒弹窗"文件仅临时保存,请在 15 分钟内下载"
-- **AND** SHALL 显示 15 分钟倒计时
-
-#### Scenario: 离开页面未下载警告
-- **WHEN** 存在 `storage_type = 'upstream'` 且未过期的完成记录
-- **AND** 用户尝试离开或关闭页面
-- **THEN** 系统 SHALL 触发 `beforeunload` 事件警告"您有未下载的生成结果,离开后可能丢失"
-
-### Requirement: 作品库页面 - 网格展示
-系统 SHALL 在作品库页面以网格布局展示用户的历史生成作品。
-
-#### Scenario: 作品网格显示
-- **WHEN** 用户切换到"作品库" Tab
-- **THEN** 系统 SHALL 从 `GET /api/v1/sora/generations?storage_type=s3,local` 获取已保存记录
-- **AND** SHALL 以响应式网格展示作品卡片(桌面 4 列、平板 3 列、移动端 1-2 列)
-- **AND** `storage_type = 'upstream'` 或 `'none'` 的记录 SHALL 不在作品库中显示
-- **AND** S3 作品的 URL SHALL 由后端每次请求时动态生成(避免预签名过期)
-
-#### Scenario: 作品卡片信息
-- **WHEN** 作品卡片渲染
-- **THEN** 每张卡片 SHALL 显示:缩略图/视频预览、类型角标(VIDEO/IMAGE)、模型名称、生成时间
-- **AND** 视频卡片 SHALL 显示播放按钮和时长标签
-
-#### Scenario: 卡片 hover 操作
-- **WHEN** 用户 hover 作品卡片
-- **THEN** SHALL 显示"📥 本地下载"和"🗑 删除"操作按钮
-- **AND** 缩略图 SHALL 轻微放大效果(scale 1.05,transition 0.2s)
-
-### Requirement: 作品库页面 - 筛选
-系统 SHALL 支持按类型筛选作品。
-
-#### Scenario: 全部/视频/图片筛选
-- **WHEN** 用户点击筛选按钮(全部/视频/图片)
-- **THEN** 作品网格 SHALL 只显示对应类型的记录
-- **AND** SHALL 更新显示作品数量
-
-#### Scenario: 空状态
-- **WHEN** 筛选结果为空或用户无任何生成记录
-- **THEN** 页面 SHALL 显示空状态引导(图标 + "暂无作品" + "开始创作"按钮)
-
-### Requirement: 作品详情与操作
-系统 SHALL 支持查看作品详情和执行下载、删除操作。
-
-#### Scenario: 查看作品详情
-- **WHEN** 用户点击作品卡片
-- **THEN** 系统 SHALL 弹出预览弹窗,显示完整的媒体内容、提示词、模型信息、生成时间
-
-#### Scenario: 本地下载作品
-- **WHEN** 用户点击"本地下载"按钮
-- **THEN** 系统 SHALL 触发浏览器下载对应媒体文件
-
-#### Scenario: 保存作品到存储
-- **WHEN** 用户点击"保存到存储"按钮
-- **AND** 管理员已开通 S3 存储
-- **THEN** 系统 SHALL 将媒体文件上传到 S3
-- **AND** 更新生成记录的 `storage_type`、`s3_object_keys`
-- **AND** 累加用户存储配额
-
-#### Scenario: 删除作品
-- **WHEN** 用户点击删除按钮
-- **THEN** 系统 SHALL 弹出确认对话框
-- **AND** 确认后调用 `DELETE /api/v1/sora/generations/:id`
-- **AND** 删除成功后 SHALL 从网格中移除卡片并更新配额显示
-
-### Requirement: 暗色主题设计
-系统 SHALL 采用参考 Sora 官方客户端的暗色主题设计。
-
-#### Scenario: 暗色主题样式
-- **WHEN** 用户访问 Sora 客户端页面
-- **THEN** 页面背景 SHALL 为深黑色(`#0D0D0D`)
-- **AND** 文字 SHALL 为白色/浅灰色
-- **AND** 卡片和输入框 SHALL 使用多层次灰色(`#1A1A1A`、`#242424`、`#2A2A2A`)
-- **AND** 导航栏 SHALL 有毛玻璃效果(`backdrop-filter: blur`)
-
-### Requirement: 响应式布局
-系统 SHALL 支持不同屏幕尺寸下的自适应布局。
-
-#### Scenario: 桌面端布局
-- **WHEN** 屏幕宽度 > 1200px
-- **THEN** 作品网格 SHALL 显示 4 列
-
-#### Scenario: 平板端布局
-- **WHEN** 屏幕宽度 900px - 1200px
-- **THEN** 作品网格 SHALL 调整为 3 列
-
-#### Scenario: 移动端布局
-- **WHEN** 屏幕宽度 < 600px
-- **THEN** 作品网格 SHALL 调整为 1-2 列
-
-### Requirement: 国际化支持
-系统 SHALL 为 Sora 客户端所有文案提供中英文国际化支持。
-
-#### Scenario: 中文环境
-- **WHEN** 系统语言设置为中文
-- **THEN** 所有 Sora 客户端文案 SHALL 显示中文
-
-#### Scenario: 英文环境
-- **WHEN** 系统语言设置为英文
-- **THEN** 所有 Sora 客户端文案 SHALL 显示英文
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-gateway/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-gateway/spec.md
deleted file mode 100644
index b6574ab28..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-gateway/spec.md
+++ /dev/null
@@ -1,129 +0,0 @@
-## MODIFIED Requirements
-
-### Requirement: Sora 生成网关入口
-系统 SHALL 提供 `POST /v1/chat/completions` 作为 Sora 生成入口(仅限 `platform=sora` 分组)。
-
-#### Scenario: Sora 分组调用 /v1/chat/completions
-- **WHEN** 请求的 API Key 分组平台为 `sora`
-- **AND** 请求体包含 `model` 与 `messages`
-- **THEN** 网关按 Sora 规则处理并返回流式或非流式结果
-- **AND** 若生成需要流式,网关 SHALL 强制 `stream=true` 或返回明确提示
-
-#### Scenario: Sora 专用路由调用 /sora/v1/chat/completions
-- **WHEN** 客户端请求 `POST /sora/v1/chat/completions`
-- **THEN** 网关 SHALL 强制使用 `platform=sora` 的调度与生成逻辑
-
-#### Scenario: 非流式请求策略
-- **WHEN** 客户端请求 `stream=false`
-- **THEN** 网关 SHALL 选择"强制流式并聚合"或"返回明确错误",并在文档中一致说明
-- **AND** 默认策略 SHALL 为"强制流式并聚合"
-
-#### Scenario: 非 Sora 分组调用 /v1/chat/completions
-- **WHEN** 请求的 API Key 分组平台不为 `sora`
-- **THEN** 网关 SHALL 返回 4xx 并提示不支持该平台
-
-#### Scenario: API Key 直接调用不存储不记录
-- **WHEN** 请求通过 `/sora/v1/chat/completions`(API Key 直接调用路径)
-- **THEN** 网关 SHALL 不将媒体文件上传到 S3
-- **AND** SHALL 不执行本地磁盘媒体落盘
-- **AND** SHALL 不写入 `sora_generations` 表
-- **AND** SHALL 不检查存储配额
-- **AND** SHALL 直接返回上游 URL(保持现有行为)
-
-### Requirement: Sora 调度与失败切换
-系统 SHALL 对 Sora 账号执行调度、并发控制、失败切换,与 OpenAI 调度一致。
-
-#### Scenario: 账号可用时成功调度
-- **WHEN** 至少存在一个可调度的 Sora 账号
-- **THEN** 选择优先级最高且最近未使用的账号,并在完成后刷新 LRU
-
-#### Scenario: 上游失败触发切换
-- **WHEN** 上游返回 401/403/429/5xx
-- **THEN** 网关 SHALL 切换账号并重试,直到达到最大切换次数
-
-#### Scenario: apikey 类型账号调度到 HTTP 透传
-- **WHEN** 调度选中的 Sora 账号 `type = 'apikey'` 且 `base_url` 非空
-- **THEN** 网关 SHALL 调用 `forwardToUpstream()` 执行 HTTP 透传
-- **AND** SHALL 不使用 `SoraSDKClient` 直连
-
-## ADDED Requirements
-
-### Requirement: Sora 客户端生成入口
-系统 SHALL 提供 `POST /api/v1/sora/generate` 作为客户端 UI 专用生成入口。
-
-#### Scenario: 客户端 UI 调用生成
-- **WHEN** 用户通过 Sora 客户端 UI 发起生成请求
-- **THEN** 系统 SHALL 接受请求并内部调用现有 `SoraGatewayService.Forward()` 完成生成
-- **AND** 在上层包装存储/记录/配额逻辑
-
-#### Scenario: 客户端生成流程(异步)
-- **WHEN** `POST /api/v1/sora/generate` 收到请求
-- **THEN** 系统 SHALL 按以下顺序执行:
- 1. 检查存储配额(有效配额 > 0 时)
- 2. 检查用户当前 pending+generating 任务数不超过 3
- 3. 创建 `sora_generations` 记录(status=pending)
- 4. **立即返回** `{ generation_id, status: "pending" }` 给前端
- 5. 后台异步:内部调用 `SoraGatewayService.Forward()` 获取上游媒体 URL(不在该步骤落盘)
- 6. 后台异步:自动上传媒体到 S3(若可用),否则降级到本地/上游 URL
- 7. 后台异步:更新生成记录(status、media_url、storage_type、file_size 等)
- 8. 后台异步:累加存储配额(仅 S3/本地存储时)
-
-#### Scenario: 前端轮询生成状态
-- **WHEN** 前端需要获取生成任务最新状态
-- **THEN** 系统 SHALL 通过 `GET /api/v1/sora/generations/:id` 返回完整记录
-- **AND** 前端 SHALL 按递减频率轮询(3s → 10s → 30s)
-
-#### Scenario: 并发生成上限
-- **WHEN** 用户 pending+generating 状态的任务已达 3 个
-- **THEN** 系统 SHALL 返回 HTTP 429 + "请等待当前任务完成后再发起新任务"
-
-### Requirement: Sora 可用模型列表 API
-系统 SHALL 提供 `GET /api/v1/sora/models` 供客户端 UI 获取可用模型。
-
-#### Scenario: 获取可用 Sora 模型
-- **WHEN** 用户请求 `GET /api/v1/sora/models`
-- **THEN** 系统 SHALL 返回系统内置的 Sora 模型列表
-- **AND** 每个模型 SHALL 包含 `id`、`name`、`media_type`(video/image)、`description`
-
-### Requirement: 手动保存到存储
-系统 SHALL 提供 `POST /api/v1/sora/generations/:id/save` 供用户将未自动保存的作品手动上传到 S3。
-
-#### Scenario: 手动保存 upstream 记录到 S3
-- **WHEN** 用户请求 `POST /api/v1/sora/generations/:id/save`
-- **AND** 该记录 `storage_type = 'upstream'` 且 `media_url` 未过期
-- **AND** S3 存储当前可用
-- **THEN** 系统 SHALL 从 `media_url` 下载媒体并上传到 S3
-- **AND** 更新记录 `storage_type = 's3'`、`s3_object_keys`、`file_size_bytes`
-- **AND** 累加用户存储配额
-
-#### Scenario: 手动保存时 URL 已过期
-- **WHEN** 上游 URL 已过期(下载返回 403/404)
-- **THEN** 系统 SHALL 返回 HTTP 410 + "媒体链接已过期,无法保存"
-
-#### Scenario: 手动保存时 S3 不可用
-- **WHEN** S3 存储未启用或配置不完整
-- **THEN** 系统 SHALL 返回 HTTP 503 + "云存储未配置,请联系管理员"
-
-### Requirement: 取消生成任务
-系统 SHALL 提供 `POST /api/v1/sora/generations/:id/cancel` 供用户取消进行中的生成任务。
-
-#### Scenario: 取消 pending/generating 状态的任务
-- **WHEN** 用户请求 `POST /api/v1/sora/generations/:id/cancel`
-- **AND** 该记录 `status` 为 `pending` 或 `generating`
-- **THEN** 系统 SHALL 将记录状态更新为 `cancelled`
-- **AND** SHALL 不累加任何存储配额
-- **AND** 若上游任务已提交,后续返回的结果 SHALL 被忽略
-
-#### Scenario: 取消非活跃状态的任务
-- **WHEN** 该记录 `status` 为 `completed`、`failed` 或 `cancelled`
-- **THEN** 系统 SHALL 返回 HTTP 409 + "任务已结束,无法取消"
-
-### Requirement: 存储状态查询
-系统 SHALL 提供 `GET /api/v1/sora/storage-status` 供前端查询当前存储可用性。
-
-#### Scenario: 查询存储状态
-- **WHEN** 用户请求 `GET /api/v1/sora/storage-status`
-- **THEN** 系统 SHALL 返回 `{ s3_enabled, s3_healthy, local_enabled }`
-- **AND** `s3_enabled` 表示管理员是否启用 S3
-- **AND** `s3_healthy` 表示 S3 客户端是否初始化成功
-- **AND** `local_enabled` 表示本地存储是否可用
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-history/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-history/spec.md
deleted file mode 100644
index 5a36554c1..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-generation-history/spec.md
+++ /dev/null
@@ -1,138 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 生成记录数据模型
-系统 SHALL 新建 `sora_generations` 表存储每次 Sora 客户端 UI 生成的元数据。
-
-#### Scenario: 数据库表创建
-- **WHEN** 数据库迁移执行
-- **THEN** 系统 SHALL 创建 `sora_generations` 表,包含以下字段:
- - `id` (BIGSERIAL PRIMARY KEY)
- - `user_id` (BIGINT NOT NULL, FK → users.id ON DELETE CASCADE)
- - `api_key_id` (BIGINT, 可空)
- - `model` (VARCHAR(64) NOT NULL)
- - `prompt` (TEXT NOT NULL DEFAULT '')
- - `media_type` (VARCHAR(16) NOT NULL DEFAULT 'video')
- - `status` (VARCHAR(16) NOT NULL DEFAULT 'pending')
- - `media_url` (TEXT NOT NULL DEFAULT '')
- - `media_urls` (JSONB, 多图 URL 数组)
- - `file_size_bytes` (BIGINT NOT NULL DEFAULT 0)
- - `storage_type` (VARCHAR(16) NOT NULL DEFAULT 'none')
- - `s3_object_keys` (JSONB, S3 object key 数组)
- - `upstream_task_id` (VARCHAR(128) NOT NULL DEFAULT '')
- - `error_message` (TEXT NOT NULL DEFAULT '')
- - `created_at` (TIMESTAMPTZ NOT NULL DEFAULT NOW())
- - `completed_at` (TIMESTAMPTZ)
-- **AND** SHALL 创建 `(user_id, created_at DESC)` 普通索引(非唯一)
-- **AND** SHALL 创建 `(user_id, status)` 索引
-
-### Requirement: 创建生成记录
-系统 SHALL 在客户端 UI 发起生成时创建记录,并在生成过程中更新状态。
-
-#### Scenario: 发起生成时创建 pending 记录
-- **WHEN** 用户通过 `POST /api/v1/sora/generate` 发起生成
-- **THEN** 系统 SHALL 在 `sora_generations` 中创建一条 `status = 'pending'` 的记录
-- **AND** 记录 SHALL 包含 `user_id`、`model`、`prompt`、`media_type`
-
-#### Scenario: 上游开始处理时更新为 generating
-- **WHEN** 上游开始处理生成任务
-- **THEN** 系统 SHALL 更新记录 `status = 'generating'`
-- **AND** 记录 `upstream_task_id`
-
-#### Scenario: 生成成功时更新为 completed
-- **WHEN** 生成完成且媒体文件存储成功
-- **THEN** 系统 SHALL 更新记录 `status = 'completed'`
-- **AND** 更新 `media_url`、`media_urls`、`file_size_bytes`、`storage_type`、`s3_object_keys`、`completed_at`
-
-#### Scenario: 生成失败时更新为 failed
-- **WHEN** 生成过程中发生错误
-- **THEN** 系统 SHALL 更新记录 `status = 'failed'`
-- **AND** 记录 `error_message`
-
-#### Scenario: 用户取消生成
-- **WHEN** 用户通过 `POST /api/v1/sora/generations/:id/cancel` 取消任务
-- **AND** 记录状态为 `pending` 或 `generating`
-- **THEN** 系统 SHALL 更新记录 `status = 'cancelled'`
-- **AND** SHALL 不累加配额
-
-#### Scenario: 手动保存到存储后更新
-- **WHEN** 用户对 `storage_type = 'upstream'` 的记录手动触发保存
-- **AND** S3 上传成功
-- **THEN** 系统 SHALL 更新 `storage_type = 's3'`、`s3_object_keys`、`file_size_bytes`
-- **AND** 累加存储配额
-
-### Requirement: 查询生成历史列表
-系统 SHALL 提供分页查询用户生成历史的 API。
-
-#### Scenario: 获取用户生成历史
-- **WHEN** 用户请求 `GET /api/v1/sora/generations`
-- **THEN** 系统 SHALL 返回当前用户的生成记录列表,按 `created_at DESC` 排序
-- **AND** 支持分页参数 `page`(默认 1)和 `page_size`(默认 20,最大 100)
-
-#### Scenario: 按媒体类型筛选
-- **WHEN** 请求携带 `media_type=video` 或 `media_type=image`
-- **THEN** 系统 SHALL 只返回对应类型的记录
-
-#### Scenario: 按状态筛选
-- **WHEN** 请求携带 `status=completed`
-- **THEN** 系统 SHALL 只返回对应状态的记录
-
-#### Scenario: 按存储类型筛选(作品库专用)
-- **WHEN** 请求携带 `storage_type=s3,local`
-- **THEN** 系统 SHALL 返回已持久化存储(S3 或本地)的记录
-- **AND** 作品库页面默认 SHALL 使用 `storage_type=s3,local` 筛选,展示所有已保存的作品
-- **AND** `storage_type='upstream'` 和 `'none'` 的记录 SHALL 不在作品库中显示
-
-#### Scenario: 预签名 URL 动态生成
-- **WHEN** 返回 `storage_type = 's3'` 的记录列表
-- **AND** 未配置 CDN URL
-- **THEN** 系统 SHALL 为每条记录动态生成新的 S3 预签名 URL(24 小时有效)
-- **AND** 前端 SHALL 不缓存媒体 URL
-
-#### Scenario: 恢复进行中的任务
-- **WHEN** 请求携带 `status=pending,generating`
-- **THEN** 系统 SHALL 返回用户所有进行中的生成任务
-- **AND** 前端页面加载时 SHALL 调用此接口恢复任务进度显示
-
-### Requirement: 查询生成详情
-系统 SHALL 提供查询单条生成记录详情的 API。
-
-#### Scenario: 获取生成详情
-- **WHEN** 用户请求 `GET /api/v1/sora/generations/:id`
-- **AND** 该记录属于当前用户
-- **THEN** 系统 SHALL 返回完整的生成记录详情
-
-#### Scenario: 访问他人记录返回 404
-- **WHEN** 用户请求的生成记录不属于当前用户
-- **THEN** 系统 SHALL 返回 HTTP 404
-
-### Requirement: 删除生成记录
-系统 SHALL 提供删除生成记录的 API,并联动清理存储文件和配额。
-
-#### Scenario: 删除单条记录
-- **WHEN** 用户请求 `DELETE /api/v1/sora/generations/:id`
-- **AND** 该记录属于当前用户
-- **THEN** 系统 SHALL 删除数据库记录
-- **AND** 若 `storage_type = 's3'`,SHALL 删除 S3 文件
-- **AND** 若 `storage_type = 'local'`,SHALL 删除本地文件
-- **AND** SHALL 释放对应的存储配额
-
-#### Scenario: 删除不存在的记录
-- **WHEN** 记录不存在或不属于当前用户
-- **THEN** 系统 SHALL 返回 HTTP 404
-
-### Requirement: 无存储模式下保留生成历史
-系统 SHALL 在无存储可用时仍记录生成元数据。
-
-#### Scenario: 无存储时记录元数据
-- **WHEN** S3 和本地存储均不可用
-- **AND** 客户端 UI 生成完成
-- **THEN** 系统 SHALL 创建生成记录,`storage_type = 'upstream'`
-- **AND** `media_url` 为上游临时 URL
-- **AND** 系统 SHALL 不累加存储配额
-
-#### Scenario: 过期 URL 标记与倒计时
-- **WHEN** 生成记录的 `storage_type = 'upstream'`
-- **THEN** 客户端 SHALL 显示 15 分钟倒计时进度条(基于 `completed_at` 计算剩余时间)
-- **AND** 剩余 5 分钟时 SHALL 通过浏览器通知提醒用户
-- **AND** 剩余 2 分钟时卡片边框 SHALL 变为红色警告态
-- **AND** 超过 15 分钟后 SHALL 显示"链接已过期,作品无法恢复",禁用下载和保存按钮
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-media-storage/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-media-storage/spec.md
deleted file mode 100644
index 6d226c62c..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-media-storage/spec.md
+++ /dev/null
@@ -1,104 +0,0 @@
-## ADDED Requirements
-
-### Requirement: S3 媒体存储服务初始化
-系统 SHALL 在启动时从系统设置(Settings 表)读取 Sora S3 配置,使用 `aws-sdk-go-v2` 初始化 S3 客户端。
-
-#### Scenario: Sora S3 已启用且配置完整
-- **WHEN** 系统启动或 S3 配置变更
-- **AND** Settings 中 `sora_s3_enabled = true` 且必填字段(endpoint、bucket、access_key_id、secret_access_key)均已配置
-- **THEN** 系统 SHALL 使用 `aws-sdk-go-v2` 初始化 S3 客户端
-- **AND** 系统 SHALL 缓存 S3 客户端实例,标记 S3 存储为可用
-
-#### Scenario: Sora S3 未启用或配置不完整
-- **WHEN** 系统启动或 S3 配置变更
-- **AND** `sora_s3_enabled = false` 或缺少必填配置
-- **THEN** 系统 SHALL 标记 S3 存储为不可用
-- **AND** 客户端 UI 调用路径 SHALL 降级为本地存储或即生即下载模式
-
-### Requirement: 媒体文件上传到 S3
-系统 SHALL 将 Sora 客户端 UI 生成的媒体文件流式上传到 S3 兼容存储。
-
-#### Scenario: 视频文件上传成功
-- **WHEN** Sora 客户端 UI 调用路径生成完成,返回上游媒体 URL
-- **AND** S3 存储可用
-- **THEN** 系统 SHALL 使用流式管道(`io.Pipe`)从上游 URL 下载并同时上传到 S3
-- **AND** S3 object key 格式 SHALL 为 `sora/{user_id}/{YYYY/MM/DD}/{uuid}.{ext}`
-- **AND** 上传完成后 SHALL 返回 S3 访问 URL(签名 URL 或 CDN URL)
-- **AND** 系统 SHALL 记录 `s3_object_keys` 数组到生成记录中(视频为单元素数组)
-
-#### Scenario: 图片文件上传成功
-- **WHEN** Sora 客户端 UI 生成图片完成
-- **AND** S3 存储可用
-- **THEN** 系统 SHALL 使用与视频相同的上传流程将图片上传到 S3
-- **AND** 支持多图场景(`media_urls` 数组中每个 URL 都上传)
-
-#### Scenario: S3 上传失败降级
-- **WHEN** S3 上传过程中发生错误(网络超时、权限错误等)
-- **THEN** 系统 SHALL 降级到本地磁盘存储(复用现有 `SoraMediaStorage`)
-- **AND** 若本地存储也失败,SHALL 降级为返回上游临时 URL
-- **AND** 生成记录的 `storage_type` SHALL 反映实际存储位置
-
-#### Scenario: 大文件流式上传避免内存溢出
-- **WHEN** 上游媒体文件大于 50MB
-- **THEN** 系统 SHALL 使用流式管道上传,不将完整文件缓存到内存
-- **AND** 内存峰值 SHALL 不超过 16MB 缓冲区
-
-### Requirement: S3 文件删除
-系统 SHALL 在用户删除生成记录时同步删除 S3 中对应的文件。
-
-#### Scenario: 删除 S3 文件(含多图)
-- **WHEN** 用户通过作品库删除一条生成记录
-- **AND** 该记录的 `storage_type = 's3'` 且 `s3_object_keys` 非空
-- **THEN** 系统 SHALL 遍历 `s3_object_keys` 数组,逐一调用 S3 DeleteObject 删除所有文件
-- **AND** 释放对应的存储配额(`sora_storage_used_bytes` 减去 `file_size_bytes`)
-
-#### Scenario: S3 删除失败不阻塞记录删除
-- **WHEN** S3 DeleteObject 调用失败(部分或全部)
-- **THEN** 系统 SHALL 仍然删除数据库中的生成记录
-- **AND** 系统 SHALL 记录告警日志,包含失败的 `s3_object_keys` 以便后续清理
-
-### Requirement: 三层降级链
-系统 SHALL 支持 S3 → 本地磁盘 → 上游临时 URL 的三层存储降级。
-
-#### Scenario: S3 可用时优先使用 S3
-- **WHEN** 客户端 UI 生成完成
-- **AND** S3 存储可用
-- **THEN** 系统 SHALL 使用 S3 存储,`storage_type = 's3'`
-
-#### Scenario: S3 不可用时降级到本地
-- **WHEN** 客户端 UI 生成完成
-- **AND** S3 存储不可用但本地存储启用
-- **THEN** 系统 SHALL 使用本地存储,`storage_type = 'local'`
-
-#### Scenario: 均不可用时透传上游 URL
-- **WHEN** 客户端 UI 生成完成
-- **AND** S3 和本地存储均不可用
-- **THEN** 系统 SHALL 直接返回上游临时 URL,`storage_type = 'upstream'`
-- **AND** 客户端 SHALL 显示即时下载提示
-
-### Requirement: S3 访问 URL 生成策略
-系统 SHALL 为 S3 中的媒体文件按配置生成可访问 URL(CDN 优先,预签名兜底)。
-
-#### Scenario: 配置 CDN URL 时返回 CDN 地址
-- **WHEN** 系统设置中配置了 `sora_s3_cdn_url`
-- **THEN** 系统 SHALL 返回基于 `sora_s3_cdn_url + object_key` 的访问地址
-- **AND** SHALL 不额外生成预签名 URL
-
-#### Scenario: 未配置 CDN URL 时生成预签名 URL
-- **WHEN** 系统未配置 `sora_s3_cdn_url`
-- **THEN** 系统 SHALL 生成 S3 预签名 URL,有效期 SHALL 为 24 小时
-- **AND** URL SHALL 支持直接在浏览器中播放/查看
-
-### Requirement: 预签名 URL 动态刷新
-系统 SHALL 在返回 S3 媒体记录时动态生成访问 URL,避免预签名过期导致作品库碎图。
-
-#### Scenario: 列表 API 动态生成 URL
-- **WHEN** `GET /api/v1/sora/generations` 返回 `storage_type = 's3'` 的记录
-- **AND** 未配置 CDN URL
-- **THEN** 后端 SHALL 为每条记录的 `s3_object_keys` 动态生成新的预签名 URL 填充到 `media_url` / `media_urls`
-- **AND** 前端 SHALL 不缓存这些 URL
-
-#### Scenario: 详情 API 动态生成 URL
-- **WHEN** `GET /api/v1/sora/generations/:id` 返回 `storage_type = 's3'` 的记录
-- **THEN** 后端 SHALL 动态生成预签名 URL
-- **AND** 批量签名性能 SHALL 不影响列表加载速度(使用并发签名或缓存短期 URL)
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-settings/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-settings/spec.md
deleted file mode 100644
index da9aea93f..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-s3-settings/spec.md
+++ /dev/null
@@ -1,39 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Sora S3 存储配置
-系统 SHALL 在系统设置中提供独立的 Sora S3 存储配置,使用 `aws-sdk-go-v2` 直连 S3 兼容存储,不依赖现有数据管理的 gRPC 代理。
-
-#### Scenario: 系统设置新增 Sora S3 配置项
-- **WHEN** 管理员访问系统设置页面
-- **THEN** 页面 SHALL 显示"Sora S3 存储配置"区域
-- **AND** 包含以下配置项:
- - 启用开关(`sora_s3_enabled`)
- - S3 端点(`sora_s3_endpoint`)
- - 区域(`sora_s3_region`)
- - 存储桶(`sora_s3_bucket`)
- - 访问密钥 ID(`sora_s3_access_key_id`)
- - 访问密钥(`sora_s3_secret_access_key`,加密存储,界面显示为密码框)
- - 对象键前缀(`sora_s3_prefix`,可选)
- - 强制路径模式(`sora_s3_force_path_style`,可选)
- - CDN 域名(`sora_s3_cdn_url`,可选)
-
-#### Scenario: 保存 Sora S3 配置
-- **WHEN** 管理员填写 S3 配置并点击保存
-- **THEN** 系统 SHALL 将配置保存到 Settings 表
-- **AND** `sora_s3_secret_access_key` SHALL 加密存储
-- **AND** Sora S3 Storage Service SHALL 刷新缓存的 S3 客户端配置
-
-#### Scenario: 测试 S3 连接
-- **WHEN** 管理员点击"测试连接"按钮
-- **THEN** 系统 SHALL 使用当前表单中的配置创建临时 S3 客户端
-- **AND** 执行 `HeadBucket` 或 `PutObject` + `DeleteObject` 测试连通性
-- **AND** 返回测试结果(成功/失败 + 错误信息)
-
-#### Scenario: 禁用 Sora S3 存储
-- **WHEN** 管理员关闭 `sora_s3_enabled` 开关
-- **THEN** Sora 客户端 UI 的生成结果 SHALL 降级到本地存储或上游 URL 透传
-
-#### Scenario: S3 配置不完整
-- **WHEN** `sora_s3_enabled = true` 但缺少必填字段(endpoint/bucket/access_key_id/secret_access_key)
-- **THEN** 系统 SHALL 视为 S3 存储不可用
-- **AND** SHALL 在日志中记录配置不完整的警告
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-user-storage-quota/spec.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-user-storage-quota/spec.md
deleted file mode 100644
index ae899c87a..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/specs/sora-user-storage-quota/spec.md
+++ /dev/null
@@ -1,91 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 用户存储配额字段
-系统 SHALL 在 `users` 表新增 Sora 存储配额字段,用于追踪每个用户的配额和用量。
-
-#### Scenario: 用户表新增配额字段
-- **WHEN** 数据库迁移执行
-- **THEN** `users` 表 SHALL 新增 `sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0` 字段(0 表示使用系统默认)
-- **AND** `users` 表 SHALL 新增 `sora_storage_used_bytes BIGINT NOT NULL DEFAULT 0` 字段
-
-### Requirement: 系统默认配额设置
-系统 SHALL 提供全局默认 Sora 存储配额设置,管理员可在系统设置中配置。
-
-#### Scenario: 管理员设置全局默认配额
-- **WHEN** 管理员在系统设置中设置 `sora_default_storage_quota_bytes`
-- **THEN** 系统 SHALL 将该值保存到 Settings 表
-- **AND** 所有未单独设置配额的用户 SHALL 使用该默认值
-
-#### Scenario: 未设置全局默认配额
-- **WHEN** `sora_default_storage_quota_bytes` 未设置或为 0
-- **THEN** 系统 SHALL 不限制用户存储空间(即无配额限制)
-
-### Requirement: 配额优先级判断
-系统 SHALL 按用户级 → 分组级 → 系统默认的优先级计算有效配额。
-
-#### Scenario: 用户级配额优先
-- **WHEN** 用户 `sora_storage_quota_bytes > 0`
-- **THEN** 有效配额 SHALL 为用户级配额值
-
-#### Scenario: 分组级配额次优先
-- **WHEN** 用户 `sora_storage_quota_bytes = 0`(未单独设置)
-- **AND** 用户所属分组 `sora_storage_quota_bytes > 0`
-- **THEN** 有效配额 SHALL 为分组级配额值
-
-#### Scenario: 系统默认配额兜底
-- **WHEN** 用户和分组的配额均未设置(均为 0)
-- **THEN** 有效配额 SHALL 为 `settings.sora_default_storage_quota_bytes`
-
-### Requirement: 生成前配额检查
-系统 SHALL 在客户端 UI 调用路径发起生成前检查存储配额。
-
-#### Scenario: 配额充足允许生成
-- **WHEN** 用户发起 Sora 客户端生成请求
-- **AND** `sora_storage_used_bytes < 有效配额`
-- **THEN** 系统 SHALL 允许生成请求继续
-
-#### Scenario: 配额不足拒绝生成
-- **WHEN** 用户发起 Sora 客户端生成请求
-- **AND** `sora_storage_used_bytes >= 有效配额`
-- **AND** 有效配额 > 0
-- **THEN** 系统 SHALL 返回 HTTP 429 错误
-- **AND** 响应 SHALL 包含 `{ quota_bytes, used_bytes, message: "存储配额已满,请删除不需要的作品释放空间" }`
-- **AND** 响应 SHALL 包含 `guide: "delete_works"` 字段,前端据此显示引导对话框
-
-#### Scenario: 无配额限制时不检查
-- **WHEN** 有效配额 = 0(系统默认也未设置)
-- **THEN** 系统 SHALL 跳过配额检查,允许生成
-
-### Requirement: 配额原子更新
-系统 SHALL 使用原子操作更新用户已用存储空间,防止并发超额。
-
-#### Scenario: 生成完成后累加用量
-- **WHEN** 媒体文件上传到 S3/本地存储成功
-- **THEN** 系统 SHALL 在计算出 `effective_quota` 后执行原子 SQL:`UPDATE users SET sora_storage_used_bytes = sora_storage_used_bytes + :file_size WHERE id = :id AND (:effective_quota = 0 OR sora_storage_used_bytes + :file_size <= :effective_quota)`
-- **AND** 若原子更新失败(超额),系统 SHALL 删除已上传的文件并返回配额错误
-
-#### Scenario: 删除作品后释放配额
-- **WHEN** 用户删除一条生成记录
-- **AND** 该记录 `file_size_bytes > 0`
-- **THEN** 系统 SHALL 执行 `UPDATE users SET sora_storage_used_bytes = sora_storage_used_bytes - file_size WHERE id = ?`
-- **AND** `sora_storage_used_bytes` SHALL 不低于 0
-
-### Requirement: 配额查询 API
-系统 SHALL 提供配额查询接口,用户可查看当前用量和剩余空间。
-
-#### Scenario: 查询用户 Sora 配额
-- **WHEN** 用户请求 `GET /api/v1/sora/quota`
-- **THEN** 系统 SHALL 返回 `{ quota_bytes, used_bytes, available_bytes, quota_source }`
-- **AND** `quota_source` SHALL 标明配额来源("user" / "group" / "system" / "unlimited")
-
-### Requirement: 管理员配额管理
-管理员 SHALL 可以在用户管理和分组管理中设置 Sora 存储配额。
-
-#### Scenario: 管理员设置单个用户配额
-- **WHEN** 管理员在用户编辑页面设置 Sora 存储配额
-- **THEN** 系统 SHALL 更新 `users.sora_storage_quota_bytes`
-
-#### Scenario: 管理员设置分组配额
-- **WHEN** 管理员在分组管理中设置 Sora 存储配额
-- **THEN** 系统 SHALL 更新 `groups.sora_storage_quota_bytes` 字段
-- **AND** 该分组下所有未单独设置配额的用户 SHALL 使用分组配额
diff --git a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/tasks.md b/openspec/changes/archive/2026-02-27-sora-client-s3-storage/tasks.md
deleted file mode 100644
index ae6d3597e..000000000
--- a/openspec/changes/archive/2026-02-27-sora-client-s3-storage/tasks.md
+++ /dev/null
@@ -1,150 +0,0 @@
-## 1. 数据库迁移
-
-- [x] 1.1 创建 `sora_generations` 表迁移脚本(含 `s3_object_keys JSONB` 数组字段、所有索引、外键约束)
-- [x] 1.2 `users` 表新增 `sora_storage_quota_bytes` 和 `sora_storage_used_bytes` 字段
-- [x] 1.3 `groups` 表新增 `sora_storage_quota_bytes` 字段
-- [x] 1.4 系统设置新增 `sora_default_storage_quota_bytes` 键值
-
-## 2. Sora S3 存储配置(系统设置)
-
-- [x] 2.1 后端:Settings 表新增 Sora S3 配置键值(sora_s3_enabled、sora_s3_endpoint、sora_s3_region、sora_s3_bucket、sora_s3_access_key_id、sora_s3_secret_access_key、sora_s3_prefix、sora_s3_force_path_style、sora_s3_cdn_url)
-- [x] 2.2 后端:系统设置 API 新增 Sora S3 配置读写接口(含 secret_access_key 加密存储)
-- [x] 2.3 后端:新增 Sora S3 连接测试接口(HeadBucket 验证连通性)
-- [x] 2.4 前端:系统设置页面新增"Sora S3 存储配置"区域(启用开关 + S3 连接表单 + 测试连接按钮)
-
-## 3. Sora API Key 账号类型(sora-account-apikey)
-
-- [x] 3.1 前端 `CreateAccountModal.vue`:取消 Sora 平台 OAuth 硬编码限制(第 2597-2601 行)
-- [x] 3.2 前端 `CreateAccountModal.vue`:新增 Sora 平台的"API Key / 上游透传"选项卡和表单(base_url + api_key)
-- [x] 3.3 前端 `EditAccountModal.vue`:支持编辑 Sora apikey 类型账号
-- [x] 3.4 前端 `credentialsBuilder.ts`:新增 Sora apikey 类型的 credentials 构建逻辑
-- [x] 3.5 后端 `sora_gateway_service.go`:`Forward()` 方法新增 apikey 类型分支判断
-- [x] 3.6 后端新增 `sora_upstream_forwarder.go`:实现 `forwardToUpstream()` HTTP 透传方法(流式+非流式)
-- [x] 3.7 后端 apikey 透传错误处理:复用 `UpstreamFailoverError` 机制实现失败转移
-- [x] 3.8 前端/后端:Sora apikey 账号连通性测试支持
-- [x] 3.9 前端/后端:Sora apikey 账号 `base_url` 校验(必填 + scheme 合法)与上游 URL 规范化拼接
-
-## 4. S3 媒体存储服务(sora-s3-media-storage)
-
-- [x] 4.1 引入 `aws-sdk-go-v2` 依赖;新增 `service/sora_s3_storage.go`:从 Settings 表读取 S3 配置,初始化 aws-sdk-go-v2 S3 客户端并缓存
-- [x] 4.2 实现流式上传方法:从上游 URL 下载并通过 `io.Pipe` 流式上传到 S3
-- [x] 4.3 实现 S3 object key 命名规则:`sora/{user_id}/{YYYY/MM/DD}/{uuid}.{ext}`,多图生成多个 key 存入 `s3_object_keys` JSONB 数组
-- [x] 4.4 实现 S3 访问 URL 策略:CDN URL 优先,否则动态生成 24h 预签名 URL(列表/详情接口每次请求时重新签名)
-- [x] 4.5 实现 S3 文件删除方法(遍历 `s3_object_keys` 数组逐一删除)
-- [x] 4.6 实现三层降级链逻辑:S3 → 本地(复用 SoraMediaStorage)→ 上游临时 URL
-- [x] 4.7 系统设置中 S3 配置变更时自动刷新缓存的 S3 客户端
-
-## 5. 用户存储配额管理(sora-user-storage-quota)
-
-- [x] 5.1 新增 `service/sora_quota_service.go`:配额优先级判断逻辑(用户 → 分组 → 系统默认)
-- [x] 5.2 实现配额检查方法:生成前检查存储是否超限
-- [x] 5.3 实现配额原子更新:上传成功后累加用量,删除后释放用量
-- [x] 5.4 实现配额查询 API:`GET /api/v1/sora/quota` 返回配额信息
-
-## 6. 生成记录管理(sora-generation-history)
-
-- [x] 6.1 新增 `service/sora_generation_service.go`:生成记录 CRUD 方法
-- [x] 6.2 实现创建记录(pending → generating → completed/failed 状态流转)
-- [x] 6.3 实现查询历史列表(分页 + 按类型/状态筛选 + 按创建时间倒序)
-- [x] 6.4 实现查询详情(权限校验:只能查看自己的记录)
-- [x] 6.5 实现删除记录(联动 S3/本地文件清理 + 配额释放)
-- [x] 6.6 无存储模式下记录元数据(storage_type='upstream',不累加配额)
-
-## 7. Sora 客户端 Handler 与路由(sora-generation-gateway)
-
-- [x] 7.1 新增 `handler/sora_client_handler.go`:客户端 API Handler
-- [x] 7.2 实现 `POST /api/v1/sora/generate`(异步):配额检查 → 并发数检查(≤3) → 创建 pending 记录 → **立即返回 generation_id** → 后台异步(Forward → 自动上传S3/降级 → 更新记录 → 累加配额)
-- [x] 7.3 实现 `GET /api/v1/sora/generations`:历史列表接口(支持 status/storage_type/media_type 筛选;S3 记录动态生成预签名 URL)
-- [x] 7.4 实现 `GET /api/v1/sora/generations/:id`:详情接口(动态预签名 URL)
-- [x] 7.5 实现 `DELETE /api/v1/sora/generations/:id`:删除接口
-- [x] 7.6 实现 `GET /api/v1/sora/quota`:配额查询接口
-- [x] 7.7 实现 `GET /api/v1/sora/models`:可用模型列表接口
-- [x] 7.8 注册路由:`server/routes/` 新增 `/api/v1/sora/*` 路由组
-- [x] 7.9 调整 `/sora/v1/chat/completions` 直调路径:保持纯透传,不执行本地/S3 媒体落盘
-- [x] 7.10 实现 `POST /api/v1/sora/generations/:id/save`:手动保存到 S3(仅 upstream 记录,含 URL 过期检测)
-- [x] 7.11 实现 `POST /api/v1/sora/generations/:id/cancel`:取消生成任务(标记 cancelled,忽略后续结果)
-- [x] 7.12 实现 `GET /api/v1/sora/storage-status`:返回 { s3_enabled, s3_healthy, local_enabled }
-
-## 8. 管理员配额管理界面
-
-- [x] 8.1 系统设置页面:新增"Sora 默认存储配额"设置项 — 集成在 Sora S3 存储配置卡片中
-- [x] 8.2 用户管理页面:用户编辑表单新增"Sora 存储配额"字段
-- [x] 8.3 分组管理页面:分组编辑表单新增"Sora 存储配额"字段
-- [x] 8.4 后端 API 适配:用户/分组的创建和更新接口支持新增字段
-
-## 9. Sora 客户端前端 - 基础框架(sora-client-ui)
-
-- [x] 9.1 新增 `views/user/SoraView.vue`:Sora 客户端主页面容器(暗色主题)
-- [x] 9.2 新增 `components/sora/SoraNavBar.vue`:页面内导航栏(仅 Tab 切换 + 配额条,不含 Logo/头像)— Tab 导航集成在 SoraView.vue 中
-- [x] 9.3 前端路由注册:`router/index.ts` 新增 `/sora` 路由(`requiresAuth: true, requiresAdmin: false`)
-- [x] 9.4 侧边栏菜单:`AppSidebar.vue` 新增 Sora 菜单项(Sparkles 线性图标,`hideInSimpleMode: true`),同时添加到 `userNavItems`(Dashboard 之后)和 `personalNavItems`(API 密钥之后),条件显示 `sora_client_enabled`
-- [x] 9.5 API 模块:新增 `api/sora.ts`,封装所有 Sora 客户端 API 调用
-- [x] 9.6 后端公共设置 API:新增 `sora_client_enabled` 字段到公共设置响应(根据活跃 Sora 账号数 > 0 推断)
-- [x] 9.7 功能未启用提示页:用户直接访问 `/sora` 但 `sora_client_enabled = false` 时显示提示
-
-## 10. Sora 客户端前端 - 生成页
-
-- [x] 10.1 新增 `components/sora/SoraGeneratePage.vue`:生成页主容器(多任务时间线布局)
-- [x] 10.2 新增 `components/sora/SoraPromptBar.vue`:底部创作栏(提示词输入 + 参数选择 + 生成按钮 + 活跃任务计数)
-- [x] 10.3 新增 `components/sora/SoraModelSelector.vue`:模型选择下拉(视频/图片分组)— 集成在 SoraPromptBar 中
-- [x] 10.4 新增 `components/sora/SoraProgressCard.vue`:生成进度卡片(6 种状态:pending/generating/completed-s3/completed-upstream/failed/cancelled)
- - pending/generating:显示已等待时长 + 预估剩余 + 取消按钮
- - completed-s3:显示"✓ 已保存到云端" + 本地下载
- - completed-upstream:显示 15 分钟倒计时 + 本地下载 + 保存到存储
- - failed:分类错误信息 + 重试/编辑后重试/删除
- - cancelled:已取消 + 重新生成/删除
-- [x] 10.5 新增 `components/sora/SoraNoStorageWarning.vue`:无存储提示组件
-- [x] 10.6 实现 Ctrl/Cmd + Enter 快捷键触发生成
-- [x] 10.7 实现图片模型切换时隐藏方向/时长选择
-- [x] 10.8 实现参考图上传功能
-- [x] 10.9 实现前端轮询机制:递减频率(3s→10s→30s)轮询 GET /api/v1/sora/generations/:id
-- [x] 10.10 实现页面加载时恢复进行中任务(GET /api/v1/sora/generations?status=pending,generating)
-- [x] 10.11 实现浏览器通知(Notification API):任务完成/失败时通知 + 标签页 title 闪烁
-- [x] 10.12 实现 beforeunload 警告:存在未下载的 upstream 记录时阻止离开
-- [x] 10.13 实现 upstream 记录的 15 分钟倒计时 UI(进度条 + 红色警告态)
-- [x] 10.14 实现取消生成功能:调用 POST /api/v1/sora/generations/:id/cancel + 二次确认
-
-## 11. Sora 客户端前端 - 作品库页
-
-- [x] 11.1 新增 `components/sora/SoraLibraryPage.vue`:作品库页主容器(请求 storage_type=s3,local 筛选已保存作品)
-- [x] 11.2 新增 `components/sora/SoraLibraryGrid.vue`:响应式网格布局(CSS Grid auto-fill, 4→3→2→1 列)— 集成在 SoraLibraryPage 中
-- [x] 11.3 新增 `components/sora/SoraMediaCard.vue`:作品卡片(缩略图 + 类型角标 + hover 显示下载和删除)— 集成在 SoraLibraryPage 中
-- [x] 11.4 新增 `components/sora/SoraEmptyState.vue`:空状态引导(图标 + "暂无作品" + "开始创作"按钮)— 集成在 SoraLibraryPage 中
-- [x] 11.5 实现全部/视频/图片筛选功能
-- [x] 11.6 实现分页加载(滚动加载或按钮加载)
-
-## 12. Sora 客户端前端 - 弹窗与辅助组件
-
-- [x] 12.1 新增 `components/sora/SoraMediaPreview.vue`:作品详情预览弹窗
-- [x] 12.2 新增 `components/sora/SoraQuotaBar.vue`:配额进度条组件
-- [x] 12.3 新增 `components/sora/SoraDownloadDialog.vue`:即时下载弹窗(无存储模式)
-- [x] 12.4 实现视频缩略图:前端用 `` 标签截取第一帧 — 使用 hover 播放方式实现
-
-## 13. 国际化
-
-- [x] 13.1 添加 Sora 客户端中文翻译文本(生成页、作品库、配额、错误提示等)
-- [x] 13.2 添加 Sora 客户端英文翻译文本
-- [x] 13.3 添加 Sora API Key 账号相关的中英文翻译文本 — 之前已有
-- [x] 13.4 添加 Sora S3 存储设置相关的中英文翻译文本
-
-## 14. 集成测试与验证
-
-- [x] 14.1 验证 API Key 直接调用路径 (`/sora/v1/chat/completions`) 保持完全向后兼容
-- [x] 14.2 验证客户端 UI 异步生成完整流程:generate → 立即返回 → 轮询 → 自动 S3 上传 → 记录 → 配额
-- [x] 14.3 验证三层降级链:S3 → 本地 → 上游 URL
-- [x] 14.4 验证配额超限拒绝、引导对话框和释放逻辑
-- [x] 14.5 验证 Sora apikey 账号 HTTP 透传和 sub2api 级联部署
-- [x] 14.6 验证无存储模式下的倒计时 + 即时下载 + beforeunload 警告
-- [x] 14.7 验证数据库迁移脚本的向后兼容性(additive only)
-- [x] 14.8 验证 Sora apikey `base_url` 非法输入拦截和 URL 规范化拼接(避免双斜杠)
-- [x] 14.9 验证 `/sora/v1/chat/completions` 路径不再创建本地媒体文件
-- [x] 14.10 验证取消生成功能:pending/generating 可取消,cancelled 不累加配额
-- [x] 14.11 验证手动保存到存储:upstream 记录 → 点击保存 → S3 上传 → 状态更新 → 配额累加
-- [x] 14.12 验证页面刷新恢复:刷新后自动恢复所有进行中任务卡片
-- [x] 14.13 验证多任务并发:同时 3 个任务正常运行,第 4 个被拒绝
-- [x] 14.14 验证预签名 URL 动态刷新:作品库每次打开获取新 URL,不出现碎图
-- [x] 14.15 验证浏览器通知:任务完成/失败时桌面通知 + 标签页 title 闪烁
-- [x] 14.16 验证 Sora 菜单条件显示:无 Sora 账号时侧边栏不显示 Sora 入口;添加 Sora 账号后自动出现
-- [x] 14.17 验证双菜单同步:普通用户和管理员"我的账户"均能看到 Sora 菜单项
-- [x] 14.18 验证简单模式:开启 simpleMode 后 Sora 菜单项隐藏
-- [x] 14.19 验证 Sora 页面嵌入布局:Sora 页面在全局侧边栏内渲染,侧边栏可正常切换其他页面
diff --git a/openspec/changes/backend-performance-optimization/.openspec.yaml b/openspec/changes/backend-performance-optimization/.openspec.yaml
deleted file mode 100644
index 85ae75c1f..000000000
--- a/openspec/changes/backend-performance-optimization/.openspec.yaml
+++ /dev/null
@@ -1,2 +0,0 @@
-schema: spec-driven
-created: 2026-02-26
diff --git a/openspec/changes/backend-performance-optimization/design.md b/openspec/changes/backend-performance-optimization/design.md
deleted file mode 100644
index 1f9244033..000000000
--- a/openspec/changes/backend-performance-optimization/design.md
+++ /dev/null
@@ -1,148 +0,0 @@
-## Context
-
-本项目是一个多平台 AI API 网关(支持 Claude/OpenAI/Gemini/Sora/Antigravity),使用 Go + Gin + Ent ORM + Redis + PostgreSQL 技术栈。经全量代码审计确认 34 个性能问题,涉及网关热路径(每请求必经)、数据库查询、Redis 缓存策略、中间件开销、日志系统和构建产物等层面。
-
-当前系统处理链路:HTTP 请求 → 全局中间件(Recovery/Logger/CORS/SecurityHeaders)→ 认证中间件(API Key Auth)→ Handler → Service(调度/网关/计费)→ Repository(DB/Redis/HTTP 上游)→ 响应。
-
-约束条件:
-- 所有优化必须为内部实现变更,不改变外部 API 接口
-- 数据库索引变更必须使用 `CONCURRENTLY` 避免阻塞
-- 不修改配置级参数(连接池大小、TTL 等已由团队调优)
-
-## Goals / Non-Goals
-
-**Goals:**
-- 降低网关热路径的每请求内存分配和 CPU 开销
-- 消除数据库层的冗余查询和缺失索引
-- 优化 Redis 访问模式(消除重复查询、串行批量化)
-- 减少中间件在 API 路由上的无效开销
-- 提升日志系统在高并发下的吞吐效率
-- 将 Docker 运行时镜像稳定在 `<30MB` 且保持可运维能力(健康检查/时区/TLS 证书)
-
-**Non-Goals:**
-- 不做架构层面重构(保持现有分层结构不变)
-- 不修改配置级参数(连接池、TTL、Worker 数量等)
-- 不引入新的外部依赖(使用项目已有的 gjson/sjson/xxhash 等)
-- 不修改 API 接口行为或响应格式
-- 不引入新的业务字段语义变更;仅允许补齐与现有数据库已存在列一致的 Ent schema 声明
-
-## Decisions
-
-### D1: WS 消息解析策略 — gjson 按需提取 vs 定义结构体
-
-**选择**: gjson 按需提取
-
-**理由**: WS 消息 payload 字段多且随 API 版本变化,定义完整结构体维护成本高。`gjson.GetBytes` 零分配提取所需字段(`type`、`model`、`prompt_cache_key` 等),仅在需要修改 payload 时才退回到 `json.Unmarshal`。项目已广泛使用 gjson(`openai_gateway_service.go` 中有 50+ 处),模式一致。
-
-**替代方案**: `jsoniter` 或 `sonic` — 需引入新依赖,收益不如 gjson 按需提取大。
-
-### D2: DNS 查询缓存实现 — sync.Map + TTL vs 独立缓存库
-
-**选择**: `sync.Map` + `time.Time` TTL(30 秒)
-
-**理由**: 缓存条目极少(仅上游 API 主机,如 `api.anthropic.com`、`api.openai.com` 等不超过 10 个),用 `sync.Map` 实现最简、无额外依赖。定期过期通过写入时间戳 + 读取时判断实现,无需后台清理 goroutine。
-
-**替代方案**: ristretto / go-cache — 项目已有这两个依赖,但对 <10 个条目的场景引入 LRU 过度设计。
-
-### D3: 全局日志无锁化 — atomic.Pointer vs sync.Once
-
-**选择**: `atomic.Pointer[zap.Logger]` + `atomic.Pointer[zap.SugaredLogger]`
-
-**理由**: 日志对象在启动后基本不变(仅热重载时变更),`atomic.Pointer` 的 Load 操作是单次 CPU 指令,比 `RWMutex.RLock/RUnlock` 高效一个数量级。`Reconfigure` 时通过 `Store` 原子替换,保证线程安全。`sync.Once` 不支持后续的 Reconfigure 场景。
-
-### D4: 索引策略 — Ent schema 定义 vs 手写迁移 SQL
-
-**选择**: 手写迁移 SQL(`backend/migrations/*.sql`)优先 + Ent schema 对齐
-
-**理由**: 当前项目使用内置 migration runner 执行 `backend/migrations` SQL,并依赖 checksum 保证不可变性。对于 `PARTIAL INDEX` + `CONCURRENTLY` 这类在线索引场景,手写 SQL 更可控、风险更低。Ent schema 可在后续补充普通索引定义以保持模型可读性,但不作为线上索引落地主路径。
-
-**注意**: migration runner 当前按文件事务执行,`CREATE INDEX CONCURRENTLY` 不能直接放在默认事务迁移中;需要先提供非事务迁移能力(或等效的独立执行流程)。对生产库新增索引统一采用 `CREATE INDEX CONCURRENTLY`;回滚采用 `DROP INDEX CONCURRENTLY`。
-
-### D5: 请求体增量 patch — sjson vs 手动拼接
-
-**选择**: `sjson.SetBytes` / `sjson.DeleteBytes`
-
-**理由**: 项目已引入 `sjson`,API 稳定。对于仅修改/删除少量字段的场景(如删除 `max_output_tokens`、修改 `model`),sjson 直接操作原始 `[]byte`,避免全量 `json.Marshal(map[string]any)` 的分配开销。
-
-### D6: Dockerfile 多阶段构建目标镜像
-
-**选择**: `alpine:3.21` 作为运行时基础镜像
-
-**理由**: 需要 `ca-certificates`(TLS 连接上游 API)和 `tzdata`(时区支持),`scratch` 镜像不含这些。`distroless` 也可行但 alpine 调试更方便。使用 `CGO_ENABLED=0` 确保静态链接。为避免新增 `curl` 依赖,healthcheck 改用 BusyBox 自带 `wget`。
-
-### D7: 会话哈希切换兼容策略 — 直接替换 vs 双读双写过渡
-
-**选择**: 双读双写过渡
-
-**理由**: 粘性会话 key 由 `openai:` 组成。若从 SHA-256 直接替换为 xxhash,会导致滚动发布期间新旧实例命中不同 key,出现短期粘性失效。采用“读新回退旧、写新同时写旧(兼容窗口)”可平滑过渡,不影响在线请求。
-
-### D8: context 元数据注入演进 — 一次性替换 vs 兼容桥接
-
-**选择**: 兼容桥接分阶段替换
-
-**理由**: 当前大量逻辑直接读取 `ctxkey.IsMaxTokensOneHaikuRequest`、`ctxkey.ThinkingEnabled`、`ctxkey.PrefetchedStickyAccountID` 等旧键。若一次性替换为 `RequestMetadata`,存在行为回归风险。先“新结构体注入 + 旧键保留写入/回退读取”,待全链路切换后再移除旧键,风险最低。
-
-### D9: 兼容开关与下线门禁 — 一次切换 vs 分阶段收敛
-
-**选择**: 分阶段开关 + 指标门禁
-
-**理由**: 会话哈希与 context 键演进都涉及滚动发布期间的新旧版本共存。一次性关闭旧路径会放大回滚风险。采用显式开关并绑定观测阈值,可实现“可回退发布”:
-- `session_hash_read_old_fallback`(默认 `true`):新 key 未命中时是否回退读旧 key
-- `session_hash_dual_write_old`(默认 `true`):写入时是否同时写旧 key
-- `metadata_bridge_enabled`(默认 `true`):是否保留旧 `ctxkey.*` 兼容注入/读取桥接
-
-**下线顺序**:
-1. 先关闭 `session_hash_dual_write_old`(保留旧读回退)
-2. 观测稳定后再关闭 `session_hash_read_old_fallback`
-3. 最后关闭 `metadata_bridge_enabled`
-
-**门禁条件**: 旧 key 回退命中率连续 7 天 `< 0.1%` 且无兼容性告警,方可进入下一步下线。
-
-**回滚策略**: 任一步骤出现粘性失配或行为回归,立即重新开启对应开关并回退到上一步。
-
-## Risks / Trade-offs
-
-**[Risk] gjson 按需提取可能遗漏需要处理的字段** → 通过代码审查确认所有需提取的字段列表,并在修改处保留 fallback 到全量 Unmarshal 的路径
-
-**[Risk] DNS 缓存可能导致 IP 变更延迟感知** → TTL 设为 30 秒,足够短以跟随 DNS 变更;安全校验失败时立即清除对应缓存条目
-
-**[Risk] atomic.Pointer 替换全局日志后 Reconfigure 时的短暂不一致** → Reconfigure 本身就是低频操作(管理员触发),Store 是原子操作,不一致窗口为纳秒级
-
-**[Risk] CONCURRENTLY 创建索引在高写入负载下可能耗时较长** → 在低峰期执行迁移;索引创建不阻塞读写,仅消耗额外 I/O
-
-**[Risk] sjson 修改嵌套字段时路径语法与 gjson 不完全一致** → 仅对顶层字段使用 sjson(如 `model`、`max_output_tokens`),嵌套修改仍走全量 Marshal
-
-**[Risk] singleflight 在余额缓存中可能导致一个慢查询阻塞同 key 所有请求** → 对 singleflight 调用设置独立的 context 超时(3 秒),超时后放弃等待直接回源
-
-**[Risk] 运行时镜像移除 curl 后健康检查失效** → 将 Docker `HEALTHCHECK` 命令改为 `wget -q -O -` 或显式保留 curl(二选一并在任务中固定)
-
-**[Risk] 并发槽位 requestID 改为纯原子递增会跨实例碰撞** → 使用“进程随机前缀 + 原子递增”组合,保证跨实例/重启场景下仍具备足够唯一性
-
-**[Risk] 兼容路径过早下线导致滚动升级抖动** → 兼容开关默认开启,按“关旧写→关旧读→关桥接”顺序执行,并用命中率阈值做门禁
-
-## Migration Plan
-
-1. **Phase 1 — 无风险纯代码优化**(无需迁移)
- - P2/P3 级别的代码优化(日志无锁化、对象池化、正则预编译、Debug 日志守卫等)
- - Dockerfile 多阶段构建
- - 与兼容桥接无关的改动可直接合并;涉及兼容开关默认值调整的改动需灰度
-
-2. **Phase 2 — 热路径优化**(需要充分测试)
- - WS 消息 gjson 解析、请求体缓存回写、会话哈希改 xxhash
- - DNS 查询缓存、IP 规则预编译
- - Google 认证中间件对齐
- - 会话哈希/metadata 兼容开关默认保持开启,按门禁分阶段下线旧路径
- - 需要完整的集成测试覆盖后合并
-
-3. **Phase 3 — 数据库索引**(需要低峰期执行)
- - 先完成 migration runner 非事务迁移能力(或明确独立执行机制)
- - 在 `backend/migrations` 新增 `*_notx.sql` 迁移文件,使用 `CREATE INDEX CONCURRENTLY` 在线创建
- - `*_notx.sql` 使用 `IF NOT EXISTS`/`IF EXISTS` 保证幂等,不与事务迁移语句混用
- - 视需要在 Ent schema 对齐普通索引定义(不依赖 Ent 自动迁移落地)
- - 验证查询计划(`EXPLAIN ANALYZE`)确认索引被使用
-
-**回滚策略**: 所有变更按 Phase 分批提交为独立 commit,任一 Phase 出现问题可独立回滚。索引回滚通过 `DROP INDEX CONCURRENTLY` 执行。
-
-## Open Questions
-
-- Q1: `openai_ws_forwarder.go` 中 WS 消息修改场景(需要全量 Unmarshal 的情况)的具体触发条件和频率,以评估 gjson-only 路径的覆盖率。
diff --git a/openspec/changes/backend-performance-optimization/proposal.md b/openspec/changes/backend-performance-optimization/proposal.md
deleted file mode 100644
index 39874fee8..000000000
--- a/openspec/changes/backend-performance-optimization/proposal.md
+++ /dev/null
@@ -1,92 +0,0 @@
-## Why
-
-后端代码经全量性能审计发现 30+ 个已确认的性能问题,覆盖网关热路径、数据库查询、Redis 缓存、中间件开销、日志系统等关键模块。部分问题(如每请求 DNS 查询、WS 消息全量反序列化、缺失复合索引)直接影响请求延迟和吞吐量,在高并发场景下会成为瓶颈。需要系统性修复以提升整体性能。
-
-## What Changes
-
-### P0 — 关键热路径优化
-
-- **[P0-1] WS 消息反序列化优化**:`openai_ws_forwarder.go:1890` 每条 WS 消息 `json.Unmarshal` 到 `map[string]any`,改用 `gjson.GetBytes` 按需提取只读字段,仅在需要修改 payload 时才全量解析
-- **[P0-2] HTTP 请求体解析缓存回写**:`openai_gateway_service.go:3615-3629` `getOpenAIRequestBodyMap` 首次解析后缺少 `c.Set` 回写 gin context 缓存,导致同一请求内可能多次 `json.Unmarshal`
-- **[P0-3] Google 认证中间件订阅验证对齐**:`backend/internal/server/middleware/api_key_auth_google.go:79-90` 仍使用旧的 4 次同步调用(`ValidateSubscription`、`CheckAndActivateWindow`、`CheckAndResetWindows`、`CheckUsageLimits`),需更新为 `ValidateAndCheckLimits` 合并 + 异步维护模式
-- **[P0-4] CSP nonce 全局中间件优化**:`security_headers.go` 的 CSP nonce 生成(`crypto/rand` 系统调用)作为全局中间件对 API 路由(`/v1/*`、`/v1beta/*`)无意义执行,需限制为仅前端路由
-- **[P0-5] 数据库复合索引补充**:`accounts` 表调度热路径缺少 `(platform, priority) WHERE deleted_at IS NULL AND status='active' AND schedulable=true` 复合部分索引;`user_subscriptions` 缺少 `(user_id, status, expires_at) WHERE deleted_at IS NULL` 复合部分索引;`usage_logs` 缺少 `(group_id, created_at)` 复合索引
-- **[P0-6] Dockerfile 运行时镜像精简**:当前已是多阶段构建,但运行时镜像仍包含 `curl` 且基础镜像版本偏旧(`alpine:3.20`),需在不破坏健康检查的前提下精简依赖并升级基线
-
-### P1 — 高优先级优化
-
-- **[P1-1] 请求体预分配读取统一化**:`gateway_handler.go:115` 等多处使用 `io.ReadAll` 无预分配,而 `openai_gateway_handler.go` 已有 `readRequestBodyWithPrealloc` 优化方案,需提升到公共层统一使用
-- **[P1-2] 双重 Redis 粘性会话查询消除**:`openai_gateway_service.go` 中 `SelectAccountWithLoadAwareness` 和 `selectBySessionHash` 对同一 key 执行两次 `GetSessionAccountID` Redis 查询
-- **[P1-3] 并发查询批量化**:`backend/internal/service/concurrency_service.go:323-335` `GetAccountConcurrencyBatch` 名为 Batch 但实现为串行 N 次 Redis GET,需在 `ConcurrencyCache` 增加批量接口并在 `backend/internal/repository/concurrency_cache.go` 用 Pipeline 实现
-- **[P1-4] 会话哈希算法降级(含兼容过渡)**:`openai_gateway_service.go:858` 使用 SHA-256 做会话映射,项目已引入 xxhash,对非密码学场景改用 `xxhash.Sum64String` 可提速 10-20 倍;需提供“新 hash 优先读取 + 旧 SHA-256 回退读取 + 兼容期双写”机制,避免升级瞬间粘性会话失配
-- **[P1-5] IP 规则匹配预编译**:`ip/ip.go:105-126` `MatchesPattern` 每次都重新 `net.ParseCIDR`/`net.ParseIP`,需在 API Key 创建时预编译为 `[]*net.IPNet`
-- **[P1-6] 余额缓存防击穿保护**:`billing_cache.go:86` 热点用户缓存失效瞬间并发穿透到数据库,需添加 `singleflight` 合并回源
-- **[P1-7] ResponseHeaders 预编译**:`backend/internal/util/responseheaders/responseheaders.go:44` `FilterHeaders` 每请求重建 allowed map(~20 条目),需在 service 初始化时预构建
-- **[P1-8] accounts 表多余查询消除**:`account_repo.go:1386-1426` `loadTempUnschedStates` 仅取 2 列却对 accounts 表做第二次完整查询,可在首次 ORM 查询时一并 Select
-- **[P1-9] 仪表盘 SQL 查询合并**:`usage_log_repo.go:500-587` `fillDashboardUsageStatsFromUsageLogs` 4 次独立 SQL 扫描 usage_logs,可合并为 1-2 个 CTE 查询
-- **[P1-10] 每请求 DNS 查询缓存**:`httpclient/pool.go:155-165` `validatedTransport.RoundTrip` 对每个 HTTP 请求执行 `ValidateResolvedIP` 完整 DNS 查询,需添加带 TTL 的已验证主机缓存
-
-### P2 — 中优先级优化
-
-- **[P2-1] slog Handler 临时 logger 消除**:`slog_handler.go:51` `h.logger.With(fields...)` 每次创建临时 logger 实例(2 次堆分配),改为直接传 fields 调用对应级别方法
-- **[P2-2] os.Getenv 缓存**:`gateway_service.go:130,135` `debugModelRoutingEnabled`/`debugClaudeMimicEnabled` 每次调用 `os.Getenv`(单请求路径可能触发 16 次),改为 `atomic.Bool` 初始化时读取
-- **[P2-3] 全局日志无锁化**:`logger.go:171-186` `L()`/`S()` 每次获取 `mu.RLock` 全局读锁,改为 `atomic.Pointer[zap.Logger]`
-- **[P2-4] opsCaptureWriter 对象池化**:`ops_error_logger.go:308-336` 每请求堆分配 `opsCaptureWriter`(含 `bytes.Buffer`),改用 `sync.Pool` 复用
-- **[P2-5] generateRequestID 轻量化(跨实例唯一)**:`concurrency_service.go:47-53` 内部 slot ID 使用 `crypto/rand`,改为“进程随机前缀 + 原子递增计数器”,在保留跨实例唯一性的同时降低热路径开销
-- **[P2-6] body 二次 Unmarshal 消除**:`gateway_helper.go:38-45` `SetClaudeCodeClientContext` 对已解析的 body 再做一次完整 `json.Unmarshal`,复用首次解析结果
-- **[P2-7] 请求体增量 patch**:`openai_gateway_service.go:1543` `bodyModified` 时 `json.Marshal` 全量重序列化,改用 `sjson.SetBytes`/`sjson.DeleteBytes` 做增量修改
-- **[P2-8] context.WithValue 合并(兼容桥接)**:`gateway_handler.go:144-266` Messages() 中 5+ 次 `context.WithValue` 链式调用,合并为单个请求属性结构体一次注入;兼容期保留旧 key 注入与读取回退,避免现有 service/handler 读取点行为回归
-- **[P2-9] WS pool ping 并行化**:`openai_ws_pool.go:635-647` 后台 ping sweep 串行执行,改为有限并发并行 ping
-- **[P2-10] WS pool 后台 worker WaitGroup 跟踪**:`openai_ws_pool.go:606-612` `startBackgroundWorkers` 缺少 `sync.WaitGroup`,关闭时无法等待 goroutine 退出
-- **[P2-11] Debug 日志参数延迟求值**:`backend/internal/pkg/tlsfingerprint/dialer.go:270` `fmt.Sprintf("0x%04x", ...)` 在 `slog.Debug` 参数中提前求值,改用 `slog.Enabled` 守卫或直接传整数
-- **[P2-12] RedactText 正则预编译**:`backend/internal/util/logredact/redact.go:88-92` 每次调用编译 3 个正则,对无 extraKeys 的默认路径预编译缓存
-- **[P2-13] enqueueCacheWrite panic-recover 改 atomic.Bool**:`billing_cache_service.go:125-145` 用 panic-recover 检测已关闭 channel,改为 `atomic.Bool` 标记 + 前置检查
-- **[P2-14] 软删除 deleted_at 单列索引优化**:所有软删除表上 `deleted_at` 单列索引对 99% NULL 值无效,改为在业务复合索引上添加 `WHERE deleted_at IS NULL` 部分索引条件
-
-### P3 — 低优先级优化
-
-- **[P3-1] ToHTTP 轻量拷贝优化(保留语义)**:`backend/internal/pkg/errors/http.go:19` 避免 `Clone` 整体对象开销,改为按需构造 `Status` 并仅在 `Metadata != nil` 时做 map 深拷贝,保持对外语义不变
-- **[P3-2] failover_loop log.Printf**:`backend/internal/handler/failover_loop.go:81` 使用标准库 `log.Printf` 而非结构化日志
-- **[P3-3] UpdateSortOrders 逐条 UPDATE**:`group_repo.go:570` N 个分组 N 条 UPDATE,可用 CASE WHEN 批量化
-- **[P3-4] cleanupAccountLocked evicted 切片无容量**:`openai_ws_pool.go:992` `make([]*openAIWSConn, 0)` 无初始容量
-- **[P3-5] conn.touch() 降频**:`openai_ws_pool.go:408` 每条 WS 消息触发 `atomic.Store + time.Now()`,可增加 1 秒内去重
-- **[P3-6] Wire cleanup 并行化**:`wire_gen.go` 24 个清理步骤串行执行(10 秒超时),互不依赖的步骤可并行
-
-### 七轮审核修复记录
-
-- **第 1 轮(结构审核)**:OpenSpec 语法通过,但发现若干条目路径不精确,已统一修正为仓库真实路径(`internal/pkg`、`internal/util`、`internal/handler`)。
-- **第 2 轮(存在性复核)**:对关键条目逐项源码核对,确认问题真实存在(如 `getOpenAIRequestBodyMap` 未回写缓存、Google 中间件同步窗口维护、`GetAccountConcurrencyBatch` 串行查询、`ToHTTP` 多余 Clone、`SecurityHeaders` 全路由 nonce)。
-- **第 3 轮(方案最优性审核)**:将批量并发查询方案收敛为“接口下沉到 repository 层 Pipeline 实现”;避免仅在 service 层做伪批量,确保改动方向可落地且收益稳定。
-- **第 4 轮(兼容性修复)**:补齐 `CONCURRENTLY` 非事务迁移、会话哈希双读双写、context 兼容桥接、requestID 跨实例唯一。
-- **第 5 轮(兼容门禁复审)**:补齐兼容开关命名、默认值、下线顺序与阈值门禁,避免“一次性切换”。
-- **第 6 轮(二次确认)**:再次核对源码,确认事务迁移包裹、SHA-256 会话哈希、旧 `ctxkey.*` 读取点均真实存在。
-- **第 7 轮(最优性再评估)**:将方案收敛为“开关可回退 + 指标门禁 + notx 幂等 + 索引移除观察期”的最优落地路径。
-
-### 兼容性修复补充
-
-- **向前兼容迁移机制**:`CREATE INDEX CONCURRENTLY` 相关迁移需通过“非事务迁移路径”执行,避免被当前 migration runner 的事务包装直接失败。
-- **粘性会话键兼容**:会话哈希算法切换采用双读双写过渡窗口,确保滚动发布期间新旧版本共存可用。
-- **上下文键兼容**:`RequestMetadata` 引入后保留旧 `ctxkey.*` 键的写入与读取兜底,分阶段下线。
-- **并发槽位 ID 兼容**:`requestID` 新格式保持字符串形态与 Redis member 语义兼容,不修改 key 前缀与数据结构。
-- **灰度与回滚开关**:会话哈希双写、旧 key 回退读取、context 兼容桥接均需配置开关控制,支持“关旧写 → 关旧读 → 关桥接”的可回退发布路径(每步满足命中率门禁后再推进)。
-
-## Capabilities
-
-### New Capabilities
-- `hotpath-optimization`: 网关热路径性能优化(WS 消息解析、请求体缓存、会话哈希、DNS 缓存、预分配读取、context 合并)
-- `middleware-optimization`: 中间件性能优化(CSP nonce 路由限制、Google 订阅验证对齐、opsCaptureWriter 池化)
-- `database-optimization`: 数据库查询与索引优化(复合部分索引、多余查询消除、SQL 合并、批量化更新)
-- `cache-optimization`: Redis 缓存与并发控制优化(防击穿保护、批量查询、粘性会话去重、requestID 轻量化)
-- `logging-optimization`: 日志系统性能优化(slog handler、全局日志无锁化、正则预编译、debug 日志延迟求值)
-- `build-optimization`: 构建产物优化(Dockerfile 多阶段构建、编译参数优化)
-
-### Modified Capabilities
-
-
-## Impact
-
-- **代码影响范围**:约 30 个文件,涵盖 handler、service、repository、middleware、pkg、ent/schema、Dockerfile 层
-- **API 行为**:无 API 接口变更,仅内部实现优化
-- **数据库**:新增 3-5 个复合部分索引(需数据库迁移),优化已有索引策略
-- **Redis**:无 key 格式变更,仅优化查询模式和缓存策略
-- **风险评估**:所有优化为内部实现变更,不影响外部接口;索引变更使用 `CREATE INDEX CONCURRENTLY` 不阻塞读写;Dockerfile 变更需验证镜像运行正确性
diff --git a/openspec/changes/backend-performance-optimization/review-rounds.md b/openspec/changes/backend-performance-optimization/review-rounds.md
deleted file mode 100644
index 9c85e671b..000000000
--- a/openspec/changes/backend-performance-optimization/review-rounds.md
+++ /dev/null
@@ -1,69 +0,0 @@
-## backend-performance-optimization 多轮审核记录
-
-### 第 1 轮:提案结构审核
-
-- 结果:`openspec validate backend-performance-optimization --strict` 通过。
-- 发现问题:
- - 多处路径表述不精确(`internal/pkg`、`internal/util`、`internal/handler` 混用)。
- - 部分场景描述有事实偏差(Google 中间件“3 次调用”实际为 4 次)。
- - 构建规范遗漏健康检查依赖约束(移除 `curl` 后可能导致 healthcheck 失效)。
-- 修复动作:已在 `proposal.md`、`tasks.md`、`specs/*` 统一修正。
-
-### 第 2 轮:问题存在性二次确认
-
-抽样复核核心条目,均确认“问题真实存在”:
-
-- `backend/internal/service/openai_gateway_service.go`:`getOpenAIRequestBodyMap` 未回写 `OpenAIParsedRequestBodyKey`。
-- `backend/internal/server/middleware/api_key_auth_google.go`:仍为同步窗口维护路径,未与 `api_key_auth.go` 对齐。
-- `backend/internal/server/middleware/security_headers.go`:全路由执行 CSP nonce 生成。
-- `backend/internal/service/concurrency_service.go`:`GetAccountConcurrencyBatch` 串行调用。
-- `backend/internal/pkg/errors/http.go`:`ToHTTP` 存在多余 `Clone`。
-- `backend/internal/handler/failover_loop.go`:仍使用 `log.Printf`。
-
-### 第 3 轮:修复方案最优性复审
-
-- 调整前:`GetAccountConcurrencyBatch` 仅在 service 层“改批量”容易变成伪批量。
-- 调整后:明确为“接口下沉到 repository 层,Redis Pipeline 实现,service 层委托调用”,减少重复实现并确保真实收益。
-- 调整前:索引策略偏向 Ent 自动迁移,和项目现有 SQL migration runner 不一致。
-- 调整后:改为“SQL migration + CONCURRENTLY 优先,Ent schema 对齐可选”,与现有迁移机制一致、风险更低。
-- 调整前:Docker 运行层精简未覆盖 healthcheck 依赖。
-- 调整后:明确 healthcheck 使用 BusyBox `wget` 或等效方案,避免引入 `curl`。
-
-### 最终结论
-
-- 二次确认结论:本提案核心性能问题存在性成立。
-- 方案最优性结论:已将关键次优点修正为更符合当前仓库实现与发布流程的方案。
-
-### 第 4 轮:向前兼容性专项修复
-
-- 修复 `CREATE INDEX CONCURRENTLY` 与事务迁移冲突:补充“非事务迁移”能力与 `*_notx.sql` 约束。
-- 修复会话哈希切换兼容性:新增“双读双写 + 兼容窗口”策略,避免滚动发布粘性失配。
-- 修复 requestID 方案:改为“进程随机前缀 + 原子递增”,避免多实例碰撞。
-- 修复 context 合并风险:明确兼容期保留旧 `ctxkey.*` 注入与读取回退。
-- 修复 `ToHTTP` 优化语义风险:改为“按需轻量拷贝 + metadata 深拷贝保留”,不改变外部语义。
-
-### 第 5 轮:向前兼容门禁复审(本次追加)
-
-- 新发现 1:提案虽提到“灰度与回滚开关”,但 `design/spec/tasks` 未形成可执行门禁(缺少开关名、默认值、下线顺序、阈值)。
-- 新发现 2:`database-optimization` 对 `*_notx.sql` 仅描述“非事务”,缺少幂等约束与 tx/notx 语义隔离,重复执行与回滚风险仍在。
-- 新发现 3:`deleted_at` 单列索引移除缺少观察期与回滚门禁,存在误删后查询退化风险。
-- 修复动作:已在 `design.md`、`specs/hotpath-optimization/spec.md`、`specs/database-optimization/spec.md`、`tasks.md` 补齐上述约束。
-
-### 第 6 轮:问题存在性二次确认(本次追加)
-
-- 迁移事务冲突确认:`migrations_runner.go:188-210` 仍按文件统一 `BeginTx` 包裹执行,`CONCURRENTLY` 场景确实会冲突。
-- 会话哈希兼容风险确认:`openai_gateway_service.go:858-859` 与 `openai_ws_forwarder.go:544-545` 仍是 SHA-256,会在滚动发布中造成新旧 key 不一致风险。
-- context 兼容风险确认:仓库内仍有大量 `ctxkey.*` 读取点(如 `middleware/logger.go`、`gemini_v1beta_handler.go` 等),一次性移除旧键会产生行为回归风险。
-- 结论:第 5 轮新增问题均真实存在,不是“文档臆测”。
-
-### 第 7 轮:方案最优性再复审(本次追加)
-
-- 评估结果 1:会话哈希与 metadata 兼容采用“开关 + 指标门禁 + 顺序下线(关旧写→关旧读→关桥接)”优于“一次切换”,回滚路径最短。
-- 评估结果 2:`*_notx.sql` 增加 `IF NOT EXISTS/IF EXISTS` 幂等约束,优于仅靠执行流程控制,能覆盖重放/灾备演练场景。
-- 评估结果 3:`deleted_at` 索引移除增加“7 天观测门禁 + 可回滚恢复语句”,比“确认覆盖后立即删除”更稳健。
-- 结论:本轮修复后的方案在滚动升级、回滚可用性、重复执行容错方面为当前仓库约束下的最优解。
-
-### 最新结论
-
-- 二次确认:本次新增兼容性问题均已确认存在。
-- 最优性结论:修复方案已收敛为“可灰度、可回滚、可观测、可重放”的执行路径,优于原始提案描述。
diff --git a/openspec/changes/backend-performance-optimization/specs/build-optimization/spec.md b/openspec/changes/backend-performance-optimization/specs/build-optimization/spec.md
deleted file mode 100644
index ad97fc170..000000000
--- a/openspec/changes/backend-performance-optimization/specs/build-optimization/spec.md
+++ /dev/null
@@ -1,65 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Dockerfile 多阶段构建
-系统 SHALL 将 Dockerfile 改为多阶段构建:第一阶段使用 `golang:1.25.7-alpine` 编译,第二阶段使用 `alpine:3.21` 作为运行时镜像,仅包含编译产物、`ca-certificates` 和 `tzdata`;运行时健康检查 SHALL 使用 BusyBox `wget` 或等效方案,不依赖 `curl`。
-
-#### Scenario: 构建 Docker 镜像
-- **WHEN** 执行 `docker build`
-- **THEN** 最终镜像不包含 Go 工具链、源代码和依赖缓存,体积不超过 30MB
-
----
-
-### Requirement: 编译参数优化
-系统 SHALL 在 Dockerfile 和 Makefile 的构建命令中使用以下编译参数:
-- `CGO_ENABLED=0`:确保纯静态链接
-- `-ldflags="-s -w"`:剥离符号表和 DWARF 调试信息
-- `-trimpath`:移除编译路径信息
-
-#### Scenario: 编译 Go binary
-- **WHEN** 通过 Dockerfile 或 Makefile 编译后端程序
-- **THEN** 编译命令包含 `CGO_ENABLED=0`、`-ldflags="-s -w"` 和 `-trimpath`,生成的 binary 为纯静态链接且不含调试信息
-
----
-
-### Requirement: Wire cleanup 并行化
-系统 SHALL 将 `provideCleanup` 中互不依赖的清理步骤并行执行,仅对有依赖关系的步骤(如 Redis/Ent 须最后关闭)保持顺序。
-
-#### Scenario: 优雅停机
-- **WHEN** 系统收到停机信号
-- **THEN** 互不依赖的业务服务(如各 OAuth 服务、各定时清理服务、各 Token 刷新服务)并行关闭,基础设施服务(Redis、Ent)在所有业务服务关闭后顺序关闭
-
----
-
-### Requirement: WS pool 后台 worker 生命周期管理
-系统 SHALL 为 `openAIWSConnPool` 的后台 goroutine(ping worker、cleanup worker)添加 `sync.WaitGroup` 跟踪,`Close()` 时等待所有 goroutine 实际退出后再返回。
-
-#### Scenario: WS pool 关闭
-- **WHEN** 系统关闭 WS 连接池
-- **THEN** `Close()` 关闭 `workerStopCh` 后通过 `WaitGroup.Wait()` 等待 ping worker 和 cleanup worker 退出,确保不存在 goroutine 泄漏
-
----
-
-### Requirement: WS pool ping 并行化
-系统 SHALL 将后台 ping sweep 从串行改为有限并发(如 `errgroup` 限制并发度为 10),避免 N 个 idle 连接的 ping 耗时线性增长。
-
-#### Scenario: 后台 ping sweep
-- **WHEN** 后台 ping worker 触发 sweep
-- **THEN** 系统并发 ping 所有候选 idle 连接(并发度上限 10),总耗时上界从 `N × 单次 ping 超时` 降为 `ceil(N/10) × 单次 ping 超时`
-
----
-
-### Requirement: ToHTTP 轻量拷贝优化
-系统 SHALL 在 `errors.ToHTTP` 中用“按需轻量拷贝”替代 `Clone` 整体对象;当 `Metadata` 非空时仍 SHALL 做 map 深拷贝,保持返回语义与并发安全。
-
-#### Scenario: HTTP 错误响应
-- **WHEN** 系统将内部错误转换为 HTTP 响应
-- **THEN** `ToHTTP` 返回与旧实现等价的 `Status` 数据(含 metadata 深拷贝语义),但减少不必要对象复制
-
----
-
-### Requirement: failover_loop 结构化日志
-系统 SHALL 将 `backend/internal/handler/failover_loop.go` 中的 `log.Printf` 调用替换为 `zap.Logger` 的结构化日志方法。
-
-#### Scenario: failover 重试日志
-- **WHEN** 网关执行 failover 重试
-- **THEN** 系统使用 `zap.Logger.Warn()` 记录结构化日志(含 account_id、status_code、retry_count 等字段),而非 `log.Printf` 的文本格式
diff --git a/openspec/changes/backend-performance-optimization/specs/cache-optimization/spec.md b/openspec/changes/backend-performance-optimization/specs/cache-optimization/spec.md
deleted file mode 100644
index df6f79678..000000000
--- a/openspec/changes/backend-performance-optimization/specs/cache-optimization/spec.md
+++ /dev/null
@@ -1,60 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 并发查询 Pipeline 批量化
-系统 SHALL 将 `GetAccountConcurrencyBatch` 从串行 N 次 Redis GET 改为 Redis Pipeline 批量查询,单次 RTT 获取所有账号的并发数。
-
-#### Scenario: 批量查询账号并发数
-- **WHEN** 调度器需要查询 N 个账号的当前并发数
-- **THEN** 系统通过单次 Redis Pipeline 发送 N 个 EVAL 命令并批量接收结果,而非 N 次独立 Redis 往返
-
----
-
-### Requirement: 余额缓存防击穿保护
-系统 SHALL 在余额缓存回源查询中使用 `singleflight.Group` 合并并发请求,同一 `userID` 的缓存穿透只执行一次数据库查询。
-
-#### Scenario: 高并发缓存过期
-- **WHEN** 同一用户的余额缓存过期,多个并发请求同时穿透
-- **THEN** 仅第一个请求执行数据库查询,其余请求等待并共享结果
-
-#### Scenario: singleflight 超时保护
-- **WHEN** 数据库查询耗时超过 3 秒
-- **THEN** 等待中的请求超时放弃 singleflight 等待,各自独立回源(防止一个慢查询阻塞所有请求)
-
----
-
-### Requirement: IP 规则预编译缓存
-系统 SHALL 在 API Key 加载/缓存时预编译 IP 白名单/黑名单规则为 `[]*net.IPNet` 和 `[]net.IP`,认证时 SHALL 使用预编译结果进行匹配,不再每次调用 `net.ParseCIDR`/`net.ParseIP`。
-
-#### Scenario: API Key 认证中的 IP 检查
-- **WHEN** 请求通过 API Key 认证且该 Key 配置了 IP 限制规则
-- **THEN** 系统使用预编译的 `*net.IPNet` 执行 `Contains()` 检查,不执行字符串解析
-
-#### Scenario: API Key 规则变更
-- **WHEN** 管理员修改 API Key 的 IP 限制规则
-- **THEN** 系统重新编译该 Key 的 IP 规则并更新缓存
-
----
-
-### Requirement: generateRequestID 轻量化
-系统 SHALL 将并发控制的内部 `generateRequestID` 从每次调用 `crypto/rand` 改为“进程随机前缀 + 原子计数器”,生成格式为 `-`,在降低开销的同时保持跨实例唯一性。
-
-#### Scenario: 获取新 slot 时生成 requestID
-- **WHEN** `AcquireAccountSlot` 或 `AcquireUserSlot` 需要生成 requestID
-- **THEN** 系统使用 `atomic.Uint64.Add(1)` 生成递增序号并拼接进程级前缀,不在每次请求中调用 `crypto/rand`
-
-#### Scenario: 多实例部署
-- **WHEN** 多个网关实例同时生成 requestID
-- **THEN** 不同实例通过各自前缀隔离,避免 Redis 槽位 member 冲突
-
----
-
-### Requirement: enqueueCacheWrite 安全关闭检查
-系统 SHALL 在 `BillingCacheService` 中使用 `atomic.Bool` 标记服务是否已停止,`enqueueCacheWrite` 在发送前检查该标记,替代当前的 `panic-recover` 模式。
-
-#### Scenario: 服务运行中入队
-- **WHEN** 服务运行中调用 `enqueueCacheWrite`
-- **THEN** 系统检查 `stopped.Load() == false` 后正常入队,无 defer/recover 开销
-
-#### Scenario: 服务已停止时入队
-- **WHEN** 服务已停止后调用 `enqueueCacheWrite`
-- **THEN** 系统通过 `stopped.Load() == true` 快速返回 false,不触发 panic
diff --git a/openspec/changes/backend-performance-optimization/specs/database-optimization/spec.md b/openspec/changes/backend-performance-optimization/specs/database-optimization/spec.md
deleted file mode 100644
index 496d10e12..000000000
--- a/openspec/changes/backend-performance-optimization/specs/database-optimization/spec.md
+++ /dev/null
@@ -1,102 +0,0 @@
-## ADDED Requirements
-
-### Requirement: CONCURRENTLY 索引迁移需走非事务执行路径
-系统 SHALL 为 `CREATE INDEX CONCURRENTLY` / `DROP INDEX CONCURRENTLY` 提供非事务执行路径(如 `*_notx.sql` 迁移或等效机制),避免被默认事务迁移包装导致执行失败。
-
-#### Scenario: 执行并发索引迁移
-- **WHEN** 运行包含 `CREATE INDEX CONCURRENTLY` 的索引迁移
-- **THEN** 迁移在非事务模式执行并成功完成,不触发“cannot run inside a transaction block”错误
-
-#### Scenario: 事务迁移向前兼容
-- **WHEN** 运行不包含 `CONCURRENTLY` 的历史事务迁移文件
-- **THEN** 系统继续按原事务模式执行,不改变既有迁移语义
-
----
-
-### Requirement: 非事务迁移文件必须幂等且语义隔离
-系统 SHALL 对 `*_notx.sql` 执行幂等约束与语义隔离:创建索引使用 `CREATE INDEX CONCURRENTLY IF NOT EXISTS`,删除索引使用 `DROP INDEX CONCURRENTLY IF EXISTS`,并且同一迁移文件不得混入需要事务保护的 DDL/DML。
-
-#### Scenario: 重复执行非事务索引迁移
-- **WHEN** 运维重复执行同一 `*_notx.sql`(例如灾备演练或重放)
-- **THEN** 迁移因 `IF NOT EXISTS`/`IF EXISTS` 具备幂等性,不因对象已存在/不存在而失败
-
-#### Scenario: 校验迁移语义隔离
-- **WHEN** 提交新的 `*_notx.sql` 迁移
-- **THEN** 文件仅包含 `CONCURRENTLY` 相关语句,不混入事务语义 SQL,避免执行器行为不确定
-
----
-
-### Requirement: accounts 表调度复合部分索引
-系统 SHALL 在 `accounts` 表上创建复合部分索引以覆盖调度热路径查询:
-- `(platform, priority) WHERE deleted_at IS NULL AND status = 'active' AND schedulable = true`
-- `(priority, status) WHERE deleted_at IS NULL AND schedulable = true`(无平台过滤场景)
-
-#### Scenario: 按平台查询可调度账号
-- **WHEN** 调度器调用 `ListSchedulableByPlatform` 查询特定平台的可调度账号
-- **THEN** PostgreSQL 使用 `idx_accounts_schedulable_hot` 部分索引执行 Index Scan,而非 Seq Scan 或低效 Bitmap Scan
-
-#### Scenario: 全平台查询可调度账号
-- **WHEN** 调度器调用 `ListSchedulable` 查询所有平台的可调度账号
-- **THEN** PostgreSQL 使用 `idx_accounts_active_schedulable` 部分索引
-
----
-
-### Requirement: user_subscriptions 表复合部分索引
-系统 SHALL 在 `user_subscriptions` 表上创建复合部分索引:`(user_id, status, expires_at) WHERE deleted_at IS NULL`
-
-#### Scenario: 查询用户活跃订阅
-- **WHEN** 系统查询 `WHERE user_id = ? AND status = 'active' AND deleted_at IS NULL`
-- **THEN** PostgreSQL 使用复合部分索引执行 Index Scan
-
----
-
-### Requirement: usage_logs 表分组维度复合索引
-系统 SHALL 在 `usage_logs` 表上创建复合索引:`(group_id, created_at) WHERE group_id IS NOT NULL`
-
-#### Scenario: 按分组查询时间范围用量
-- **WHEN** 仪表盘按分组维度查询用量统计
-- **THEN** PostgreSQL 使用 `(group_id, created_at)` 复合索引,而非 `group_id` 单列索引 + Filter
-
----
-
-### Requirement: loadTempUnschedStates 多余查询消除
-系统 SHALL 先完成 Ent schema 与现有数据库列对齐(补齐 `temp_unschedulable_until`、`temp_unschedulable_reason` 字段定义),再在首次 Ent ORM 查询 accounts 时一并 Select 这两个字段,消除 `loadTempUnschedStates` 对 accounts 表的第二次查询。
-
-#### Scenario: 批量加载账号信息
-- **WHEN** `accountsToService` 或 `GetByIDs` 加载账号列表
-- **THEN** 系统通过单次 ORM 查询获取所有需要的字段(含 temp_unschedulable 相关),不再执行 `loadTempUnschedStates` 的额外 SQL 查询
-
----
-
-### Requirement: 仪表盘 SQL 查询合并
-系统 SHALL 将 `fillDashboardUsageStatsFromUsageLogs` 中的 4 次独立 SQL 查询合并为 1-2 个 CTE 查询,减少数据库往返。
-
-#### Scenario: 仪表盘用量统计查询
-- **WHEN** 系统调用 `fillDashboardUsageStatsFromUsageLogs` 获取仪表盘数据
-- **THEN** 系统通过单个 CTE 查询同时获取总体统计、今日统计、今日活跃用户数和小时活跃用户数,而非 4 次独立查询
-
----
-
-### Requirement: deleted_at 单列索引替换为业务部分索引
-系统 SHALL 评估并清理软删除表上无效的 `deleted_at` 单列索引,将其替换为业务查询复合索引中的 `WHERE deleted_at IS NULL` 部分索引条件。
-
-#### Scenario: accounts 表 deleted_at 索引优化
-- **WHEN** accounts 表已有业务复合部分索引(含 `WHERE deleted_at IS NULL` 条件)
-- **THEN** 原 `deleted_at` 单列索引可安全移除,减少写入时的索引维护开销
-
-#### Scenario: 移除前观察门禁
-- **WHEN** 计划删除任意表 `deleted_at` 单列索引
-- **THEN** 系统先完成至少 7 天慢 SQL/查询计划观测,确认无关键查询依赖该单列索引后再执行删除
-
-#### Scenario: 删除后回滚
-- **WHEN** 删除 `deleted_at` 单列索引后出现查询退化
-- **THEN** 系统可通过 `CREATE INDEX CONCURRENTLY IF NOT EXISTS` 在低峰期快速恢复索引
-
----
-
-### Requirement: UpdateSortOrders 批量化
-系统 SHALL 将 `UpdateSortOrders` 从逐条 UPDATE 优化为单条批量 UPDATE SQL(使用 `CASE WHEN` 或 `unnest` 方式)。
-
-#### Scenario: 批量更新分组排序
-- **WHEN** 管理员调整 N 个分组的排序顺序
-- **THEN** 系统通过单条 SQL 完成所有排序更新,而非 N 次独立 UPDATE
diff --git a/openspec/changes/backend-performance-optimization/specs/hotpath-optimization/spec.md b/openspec/changes/backend-performance-optimization/specs/hotpath-optimization/spec.md
deleted file mode 100644
index 131cb98dd..000000000
--- a/openspec/changes/backend-performance-optimization/specs/hotpath-optimization/spec.md
+++ /dev/null
@@ -1,141 +0,0 @@
-## ADDED Requirements
-
-### Requirement: WS 消息按需解析
-系统 SHALL 对 WebSocket 客户端消息使用 `gjson.GetBytes` 按需提取只读字段(`type`、`model`、`prompt_cache_key`、`previous_response_id` 等),而非每条消息都 `json.Unmarshal` 到 `map[string]any`。仅在需要修改 payload 字段时才退回到全量反序列化。
-
-#### Scenario: 只读消息(不修改 payload)
-- **WHEN** WebSocket 收到客户端消息且不需要修改任何字段
-- **THEN** 系统仅通过 `gjson.GetBytes` 提取所需字段,不执行 `json.Unmarshal`,零额外堆分配
-
-#### Scenario: 需要修改 payload 的消息(如 model 字段映射)
-- **WHEN** WebSocket 收到客户端消息且需要修改 `model` 字段或其他字段
-- **THEN** 系统使用 `sjson.SetBytes` 做增量修改,或在多字段修改时退回到 `json.Unmarshal` + `json.Marshal`
-
----
-
-### Requirement: HTTP 请求体解析结果缓存回写
-系统 SHALL 在 `getOpenAIRequestBodyMap` 首次解析请求体后,将结果通过 `c.Set(OpenAIParsedRequestBodyKey, reqBody)` 回写到 gin context 缓存,确保同一请求内的后续调用可命中缓存。
-
-#### Scenario: 首次解析请求体
-- **WHEN** `getOpenAIRequestBodyMap` 被调用且 gin context 中无缓存
-- **THEN** 系统执行 `json.Unmarshal` 解析,并将结果 `c.Set` 写入 context 缓存后返回
-
-#### Scenario: 后续调用命中缓存
-- **WHEN** 同一请求中第二次调用 `getOpenAIRequestBodyMap`
-- **THEN** 系统直接从 gin context 缓存返回,不执行 `json.Unmarshal`
-
----
-
-### Requirement: 双重粘性会话查询消除
-系统 SHALL 在 `SelectAccountWithLoadAwareness` 入口处查询 `GetSessionAccountID` 后,将结果通过 `OpenAIAccountScheduleRequest` 传递给下游调度器,下游 SHALL 优先使用已传入的结果而非再次查询 Redis。
-
-#### Scenario: 入口已获取粘性会话 accountID
-- **WHEN** `SelectAccountWithLoadAwareness` 入口查询到 `stickyAccountID`
-- **THEN** 该 ID 通过请求参数传递到 `selectBySessionHash`/`tryStickySessionHit`,下游不再重复查询 Redis
-
-#### Scenario: 入口未查到粘性会话
-- **WHEN** 入口处 `GetSessionAccountID` 返回空
-- **THEN** 下游调度器按正常负载均衡流程选择账号,不执行额外 Redis 查询
-
----
-
-### Requirement: 会话哈希使用非密码学算法
-系统 SHALL 对用于会话粘性映射的哈希使用 `xxhash.Sum64String` 替代 `sha256.Sum256`。仅保留密码学场景(API Key 哈希、幂等键)使用 SHA-256。为保证滚动发布兼容性,系统 SHALL 实现“新 key 优先读取 + 旧 key 回退读取 + 兼容期双写”。
-
-#### Scenario: 会话哈希生成
-- **WHEN** 需要为 sessionID 生成粘性会话哈希
-- **THEN** 系统使用 `xxhash.Sum64String(sessionID)` 并转为十六进制字符串,不使用 `crypto/sha256`
-
-#### Scenario: 滚动升级兼容读取
-- **WHEN** 新 hash key 未命中粘性会话,但旧版本仍可能写入 SHA-256 key
-- **THEN** 系统回退读取旧 SHA-256 key,避免升级窗口内粘性会话失配
-
-#### Scenario: 兼容期写入
-- **WHEN** 系统绑定或刷新粘性会话
-- **THEN** 系统同时写入新 hash key 与旧 SHA-256 key(旧 key 使用较短 TTL),兼容期结束后下线旧 key 写入
-
----
-
-### Requirement: 会话哈希与 metadata 兼容开关门禁
-系统 SHALL 为兼容路径提供可回退特性开关(`session_hash_read_old_fallback`、`session_hash_dual_write_old`、`metadata_bridge_enabled`),默认值均为 `true`。系统 SHALL 按“先关旧写、后关旧读、最后关 metadata 桥接”的顺序下线;每一步下线前 SHALL 满足门禁条件:旧路径命中率连续 7 天 `< 0.1%` 且无兼容性告警。
-
-#### Scenario: 兼容开关回滚
-- **WHEN** 灰度期间出现粘性会话失配或旧读取点异常
-- **THEN** 运维可立即重新开启对应兼容开关,系统恢复旧路径读取/写入能力
-
-#### Scenario: 关闭旧写门禁
-- **WHEN** `session_hash_dual_write_old` 准备关闭
-- **THEN** 系统先确认旧 key 回退命中率连续 7 天 `< 0.1%`,关闭后继续保留 `session_hash_read_old_fallback=true`
-
-#### Scenario: 关闭 metadata 兼容桥接门禁
-- **WHEN** `metadata_bridge_enabled` 准备关闭
-- **THEN** 系统确认旧 `ctxkey.*` 回退读取命中率连续 7 天 `< 0.1%` 且无错误回归,再执行下线
-
----
-
-### Requirement: 每请求 DNS 查询缓存
-系统 SHALL 在 `validatedTransport.RoundTrip` 中对已通过 `ValidateResolvedIP` 验证的主机名缓存验证结果(TTL 30 秒),避免每请求都执行完整 DNS 查询。
-
-#### Scenario: 已缓存的主机名
-- **WHEN** 请求目标主机在 30 秒内已通过 IP 验证
-- **THEN** 系统直接放行请求,不执行 DNS 查询
-
-#### Scenario: 缓存过期或未知主机
-- **WHEN** 请求目标主机不在缓存中或缓存已过期
-- **THEN** 系统执行 `ValidateResolvedIP` 完整 DNS 查询,验证通过后写入缓存
-
-#### Scenario: 验证失败清除缓存
-- **WHEN** 某主机名此前缓存为通过,但新一次 DNS 查询返回内网 IP
-- **THEN** 系统拒绝请求并清除该主机的缓存条目
-
----
-
-### Requirement: 请求体预分配读取统一化
-系统 SHALL 将 `readRequestBodyWithPrealloc` 提升为公共函数,所有 Handler 的请求体读取 SHALL 使用该函数替代 `io.ReadAll`,根据 `Content-Length` 预分配 buffer 容量。
-
-#### Scenario: 已知 Content-Length 的请求
-- **WHEN** 请求带有 `Content-Length` 头
-- **THEN** 系统以 `Content-Length` 值预分配 buffer(上限 `openAIRequestBodyReadMaxInitCap`),一次性读取无需扩容
-
-#### Scenario: 未知 Content-Length 的请求
-- **WHEN** 请求无 `Content-Length` 头(如 chunked transfer)
-- **THEN** 系统以 `openAIRequestBodyReadInitCap`(512 字节)为初始容量,按需扩容
-
----
-
-### Requirement: context.WithValue 调用合并
-系统 SHALL 在 `Messages()` 方法中将多个请求属性(`IsMaxTokensOneHaikuRequest`、`ThinkingEnabled`、`PrefetchedStickyAccountID`、`PrefetchedStickyGroupID`、`SingleAccountRetry` 等)合并为一个 `RequestMetadata` 结构体,通过单次 `context.WithValue` 注入。兼容期 SHALL 保留旧 `ctxkey.*` 键写入与读取回退。
-
-#### Scenario: Messages 请求属性注入
-- **WHEN** `Messages()` 方法需要向 context 注入多个请求属性
-- **THEN** 系统构建一个 `RequestMetadata` 结构体并通过一次 `context.WithValue` 注入,context 链深度增加 1 而非 5+
-
-#### Scenario: 旧读取点兼容
-- **WHEN** 下游代码仍按旧 `ctxkey.*` 读取请求属性
-- **THEN** 系统可通过兼容注入或读取回退获取正确值,不出现行为回归
-
----
-
-### Requirement: 请求体增量 patch
-系统 SHALL 在 `bodyModified` 场景中优先使用 `sjson.SetBytes`/`sjson.DeleteBytes` 对原始 `[]byte` 做增量修改,而非 `json.Marshal(map[string]any)` 全量重序列化。仅在多字段复杂修改时退回到全量序列化。
-
-#### Scenario: 单字段删除
-- **WHEN** 仅需删除 `max_output_tokens` 字段
-- **THEN** 系统使用 `sjson.DeleteBytes(body, "max_output_tokens")` 直接操作原始 bytes
-
-#### Scenario: 单字段修改
-- **WHEN** 仅需修改 `model` 字段值
-- **THEN** 系统使用 `sjson.SetBytes(body, "model", newModel)` 直接操作原始 bytes
-
-#### Scenario: 多字段复杂修改
-- **WHEN** 需要修改 3 个以上字段或涉及嵌套结构修改
-- **THEN** 系统退回到 `json.Unmarshal` + 修改 + `json.Marshal` 的全量路径
-
----
-
-### Requirement: body 二次 Unmarshal 消除
-系统 SHALL 在 `SetClaudeCodeClientContext` 中复用上游已解析的请求体结果(通过 gin context 传递),而非对同一 `body` 再做一次 `json.Unmarshal`。
-
-#### Scenario: Claude Code 客户端请求验证
-- **WHEN** `SetClaudeCodeClientContext` 需要验证请求体内容
-- **THEN** 系统从 gin context 中读取已解析的结果,或使用 `gjson` 按需提取验证所需字段,不执行完整 `json.Unmarshal`
diff --git a/openspec/changes/backend-performance-optimization/specs/logging-optimization/spec.md b/openspec/changes/backend-performance-optimization/specs/logging-optimization/spec.md
deleted file mode 100644
index 0d169ef3d..000000000
--- a/openspec/changes/backend-performance-optimization/specs/logging-optimization/spec.md
+++ /dev/null
@@ -1,56 +0,0 @@
-## ADDED Requirements
-
-### Requirement: slog Handler 直接传递 fields
-系统 SHALL 在 `slogZapHandler.Handle` 方法中直接调用 `h.logger.Info(msg, fields...)` / `h.logger.Error(msg, fields...)` 等对应级别方法传递字段,而非通过 `h.logger.With(fields...)` 创建临时 logger 实例。
-
-#### Scenario: slog 日志记录
-- **WHEN** 通过 slog API 记录一条日志
-- **THEN** `slogZapHandler.Handle` 直接将 fields 传给 zap logger 的对应级别方法,不创建中间 logger 对象(消除 2 次堆分配)
-
----
-
-### Requirement: 全局日志无锁化
-系统 SHALL 将 `logger.L()` 和 `logger.S()` 的内部存储从 `sync.RWMutex` 保护的全局变量改为 `atomic.Pointer[zap.Logger]` / `atomic.Pointer[zap.SugaredLogger]`,实现无锁读取。`Reconfigure` 时通过 `Store` 原子替换。
-
-#### Scenario: 高并发日志获取
-- **WHEN** 多个 goroutine 并发调用 `logger.L()` 获取 logger 实例
-- **THEN** 每次调用仅执行一次 `atomic.Pointer.Load()`(无锁),不执行 `mu.RLock()/mu.RUnlock()`
-
-#### Scenario: 日志热重载
-- **WHEN** 管理员触发日志配置热重载
-- **THEN** 系统通过 `atomic.Pointer.Store()` 原子替换 logger 实例,后续 `L()` 调用立即获取新 logger
-
----
-
-### Requirement: os.Getenv 初始化缓存
-系统 SHALL 将 `debugModelRoutingEnabled` 和 `debugClaudeMimicEnabled` 的环境变量读取改为在 `GatewayService` 初始化时读取一次,存储为 `atomic.Bool` 字段。
-
-#### Scenario: 网关请求中检查 debug 开关
-- **WHEN** 网关请求处理中需要检查 debug 模式是否启用
-- **THEN** 系统通过 `atomic.Bool.Load()` 读取缓存值,不调用 `os.Getenv` + `strings.ToLower` + `strings.TrimSpace`
-
----
-
-### Requirement: RedactText 正则预编译
-系统 SHALL 对 `RedactText` 中无 `extraKeys` 的默认调用路径预编译 3 个正则表达式(在 `init()` 或包初始化时),对有 `extraKeys` 的调用路径使用 `sync.Map` 按 key 组合缓存已编译正则。
-
-#### Scenario: 默认路径(无 extraKeys)
-- **WHEN** `RedactText(input)` 不传 extraKeys
-- **THEN** 系统使用预编译的全局正则实例,不执行 `regexp.MustCompile`
-
-#### Scenario: 自定义路径(有 extraKeys)
-- **WHEN** `RedactText(input, "custom_key")` 传入 extraKeys
-- **THEN** 系统以 extraKeys 排序后的哈希为 key 查找缓存,命中则复用,未命中则编译后缓存
-
----
-
-### Requirement: Debug 日志参数延迟求值
-系统 SHALL 在 TLS fingerprint dialer 的 `slog.Debug` 调用中,对 `fmt.Sprintf` 格式化操作使用 `slog.Default().Enabled(ctx, slog.LevelDebug)` 前置检查,或直接传递整数值替代格式化字符串。
-
-#### Scenario: 生产环境(Debug 级别关闭)
-- **WHEN** 日志级别高于 Debug(如 Info/Warn)
-- **THEN** 系统不执行 `fmt.Sprintf("0x%04x", ...)` 格式化,零额外分配
-
-#### Scenario: 调试环境(Debug 级别开启)
-- **WHEN** 日志级别为 Debug
-- **THEN** 系统正常执行格式化并输出完整 debug 信息
diff --git a/openspec/changes/backend-performance-optimization/specs/middleware-optimization/spec.md b/openspec/changes/backend-performance-optimization/specs/middleware-optimization/spec.md
deleted file mode 100644
index 36cbf9496..000000000
--- a/openspec/changes/backend-performance-optimization/specs/middleware-optimization/spec.md
+++ /dev/null
@@ -1,47 +0,0 @@
-## ADDED Requirements
-
-### Requirement: CSP nonce 仅对前端路由生效
-系统 SHALL 将 `SecurityHeaders` 中间件的 CSP nonce 生成逻辑限制为仅对前端路由(返回 HTML 的路由)执行。API 路由(`/v1/*`、`/v1beta/*`、`/antigravity/*`、`/sora/*`、`/responses`)SHALL 跳过 CSP nonce 生成,仅设置基础安全头(`X-Content-Type-Options`、`X-Frame-Options`、`Referrer-Policy`)。
-
-#### Scenario: API 路由请求
-- **WHEN** 请求路径以 `/v1/`、`/v1beta/`、`/antigravity/`、`/sora/`、`/responses` 开头
-- **THEN** 系统设置基础安全头但跳过 CSP nonce 生成,不调用 `crypto/rand`
-
-#### Scenario: 前端路由请求
-- **WHEN** 请求路径为前端页面路由(如 `/`、`/admin`、`/settings` 等返回 HTML 的路由)
-- **THEN** 系统正常生成 CSP nonce 并设置 `Content-Security-Policy` 头
-
----
-
-### Requirement: Google 认证中间件订阅验证对齐
-`api_key_auth_google.go` SHALL 使用与 `api_key_auth.go` 相同的合并验证模式:调用 `ValidateAndCheckLimits`(纯内存操作)进行订阅验证和限额检查,将窗口维护操作(`CheckAndActivateWindow`、`CheckAndResetWindows`)改为异步执行。
-
-#### Scenario: Google 格式 API Key 认证
-- **WHEN** 请求通过 Google 格式 API Key 认证且用户有活跃订阅
-- **THEN** 系统调用 `ValidateAndCheckLimits` 进行合并验证(纯内存操作),不再同步执行 4 次独立调用
-
-#### Scenario: 订阅需要窗口维护
-- **WHEN** `ValidateAndCheckLimits` 返回 `needsMaintenance=true`
-- **THEN** 系统异步调用 `DoWindowMaintenance`,不阻塞请求处理
-
----
-
-### Requirement: opsCaptureWriter 对象池化
-系统 SHALL 使用 `sync.Pool` 复用 `opsCaptureWriter` 实例,避免每请求堆分配。
-
-#### Scenario: 请求进入 OpsErrorLoggerMiddleware
-- **WHEN** 新请求进入错误日志中间件
-- **THEN** 系统从 `sync.Pool` 获取 `opsCaptureWriter`(命中时零分配),设置 `ResponseWriter` 和 `limit` 后使用
-
-#### Scenario: 请求结束
-- **WHEN** 请求处理完毕
-- **THEN** 系统 `Reset()` opsCaptureWriter 的 `bytes.Buffer` 并归还到 `sync.Pool`
-
----
-
-### Requirement: ResponseHeaders 预编译
-系统 SHALL 在 service 初始化时预构建 `compiledHeaderFilter`(包含合并后的 `allowed` set 和 `forceRemove` set),`FilterHeaders`/`WriteFilteredHeaders` 在运行时直接使用预编译结果,不每次重建 map。
-
-#### Scenario: 代理响应头过滤
-- **WHEN** 网关将上游响应转发给客户端
-- **THEN** 系统使用预编译的 `compiledHeaderFilter` 过滤响应头,不分配新的 `allowed` / `forceRemove` map
diff --git a/openspec/changes/backend-performance-optimization/tasks.md b/openspec/changes/backend-performance-optimization/tasks.md
deleted file mode 100644
index 8b20b8e73..000000000
--- a/openspec/changes/backend-performance-optimization/tasks.md
+++ /dev/null
@@ -1,76 +0,0 @@
-- [x] 0.1 `backend/internal/repository/migrations_runner.go`: 增加“非事务迁移”执行能力(如按文件后缀 `_notx.sql` 分支执行,不包事务),并补充对应测试,确保 `CREATE INDEX CONCURRENTLY` 可被安全执行
-- [x] 0.2 `backend/migrations/`: 约定并文档化 `*_notx.sql` 命名规范、回滚策略与执行顺序,禁止在同一文件混用 tx/notx 语义
-- [x] 0.3 `backend/internal/repository/migrations_runner_test.go`: 增加 `*_notx.sql` 幂等测试与语义校验测试(重复执行不报错、混用语义时阻断)
-
-## 1. Phase 1 — 日志系统优化(logging-optimization)
-
-- [x] 1.1 `backend/internal/pkg/logger/slog_handler.go`: 将 `Handle` 方法中的 `h.logger.With(fields...)` 改为直接调用 `h.logger.Info(msg, fields...)` / `Error(msg, fields...)` 等对应级别方法,消除每条日志的临时 logger 分配
-- [x] 1.2 `backend/internal/pkg/logger/logger.go`: 将 `global` 和 `sugar` 全局变量从 `sync.RWMutex` 保护改为 `atomic.Pointer[zap.Logger]` / `atomic.Pointer[zap.SugaredLogger]`,同步修改 `L()`、`S()`、`Reconfigure()` 和 `sinkCore.Write` 中的 `currentSink` 为 `atomic.Value`
-- [x] 1.3 `backend/internal/service/gateway_service.go`: 在 `GatewayService` 构造函数中读取 `SUB2API_DEBUG_MODEL_ROUTING` 和 `SUB2API_DEBUG_CLAUDE_MIMIC` 环境变量,存储为 `atomic.Bool` 字段,替换 `debugModelRoutingEnabled()` / `debugClaudeMimicEnabled()` 中的 `os.Getenv` 调用
-- [x] 1.4 `backend/internal/util/logredact/redact.go`: 在包初始化时对默认 key 列表预编译 3 个全局正则;对有 `extraKeys` 的调用路径,以排序后 key 组合为 cache key 使用 `sync.Map` 缓存已编译正则
-- [x] 1.5 `backend/internal/pkg/tlsfingerprint/dialer.go`: 对所有 `slog.Debug` 调用中的 `fmt.Sprintf("0x%04x", ...)` 参数,改为直接传整数值 `spec.TLSVersMax`,或添加 `slog.Default().Enabled(ctx, slog.LevelDebug)` 前置检查
-- [x] 1.6 `backend/internal/handler/failover_loop.go`: 将所有 `log.Printf` 调用替换为 `zap.Logger.Warn()` 结构化日志,接受 context 传入的 logger 实例
-
-## 2. Phase 1 — 中间件优化(middleware-optimization)
-
-- [x] 2.1 `backend/internal/server/middleware/security_headers.go`: 在 `SecurityHeaders` 中间件中添加 API 路由前缀检查(`/v1/`、`/v1beta/`、`/antigravity/`、`/sora/`、`/responses`),命中时跳过 CSP nonce 生成,仅设置基础安全头
-- [x] 2.2 `backend/internal/server/middleware/api_key_auth_google.go`: 将同步 4 次调用(`ValidateSubscription` + `CheckAndActivateWindow` + `CheckAndResetWindows` + `CheckUsageLimits`)替换为 `ValidateAndCheckLimits` 合并调用 + `needsMaintenance` 时异步 `DoWindowMaintenance`,与 `api_key_auth.go` 对齐
-- [x] 2.3 `backend/internal/handler/ops_error_logger.go`: 创建 `var opsCaptureWriterPool = sync.Pool{...}`,在 `OpsErrorLoggerMiddleware` 中从 pool 获取 `opsCaptureWriter`,请求结束后 `Reset()` buffer 并归还 pool
-- [x] 2.4 `backend/internal/util/responseheaders/responseheaders.go`: 新增 `CompileHeaderFilter(cfg)` 函数返回 `*compiledHeaderFilter`(预构建 `allowed` 和 `forceRemove` map),在 service 初始化时调用;修改 `FilterHeaders` / `WriteFilteredHeaders` 接受预编译结果
-
-## 3. Phase 1 — 缓存与并发优化(cache-optimization)
-
-- [x] 3.1 `backend/internal/service/concurrency_service.go` + `backend/internal/repository/concurrency_cache.go`: 在 `ConcurrencyCache` 增加批量查询接口(如 `GetAccountConcurrencyBatch`),由 repository 层使用 Redis Pipeline 实现,service 层改为单次委托调用
-- [x] 3.2 `backend/internal/repository/billing_cache.go` / `backend/internal/service/billing_service.go`: 在余额缓存回源路径中引入 `singleflight.Group`,以 `userID` 为 key 合并并发穿透;为 singleflight 调用设置独立 3 秒 context 超时
-- [x] 3.3 `backend/internal/service/concurrency_service.go`: 将 `generateRequestID` 改为“进程随机前缀 + 原子计数器(base36)”,避免跨实例碰撞且减少 `crypto/rand` 热路径开销
-- [x] 3.4 `backend/internal/service/billing_cache_service.go`: 在 `BillingCacheService` 中添加 `stopped atomic.Bool` 字段,`Stop()` 时设置为 true;`enqueueCacheWrite` 中先检查 `s.stopped.Load()` 再入队,移除 `defer func() { recover() }` 模式
-- [x] 3.5 `backend/internal/pkg/ip/ip.go` + `backend/internal/service/api_key_service.go`: 新增 `CompiledIPRules` 结构体(含 `[]*net.IPNet` 和 `[]net.IP`)和 `CompileIPRules(patterns []string)` 函数;在 API Key 加载/缓存时预编译;修改 `CheckIPRestriction` 使用预编译规则
-
-## 4. Phase 2 — 网关热路径优化(hotpath-optimization)
-
-- [x] 4.1 `openai_ws_forwarder.go`: 将第 1890 行的 `json.Unmarshal(trimmed, &payload)` 改为 `gjson.GetBytes` 按需提取 `type`、`model`、`prompt_cache_key`、`previous_response_id` 等字段;仅在需要修改 payload 时退回 Unmarshal
-- [x] 4.2 `openai_gateway_service.go`: 在 `getOpenAIRequestBodyMap` 中首次 `json.Unmarshal` 成功后添加 `c.Set(OpenAIParsedRequestBodyKey, reqBody)` 回写 gin context 缓存
-- [x] 4.3 `openai_gateway_service.go`: 在 `SelectAccountWithLoadAwareness` 入口处查询 `GetSessionAccountID` 后,将 `stickyAccountID` 写入 `OpenAIAccountScheduleRequest` 结构体新增字段;在 `selectBySessionHash` / `tryStickySessionHit` 中优先使用已传入的值,非空时跳过 Redis 查询
-- [x] 4.4 `openai_gateway_service.go` + `openai_ws_forwarder.go`: 将会话哈希切换到 xxhash,并实现兼容期“双读双写”策略:读新 key 未命中回退读旧 SHA key;绑定时同时刷新新旧 key(旧 key 带短 TTL)
-- [x] 4.5 `httpclient/pool.go`: 在 `validatedTransport` 中新增 `validatedHosts sync.Map`(key: string, value: time.Time),`RoundTrip` 中先检查缓存(30 秒 TTL),命中则跳过 DNS 查询;未命中或过期时执行 `ValidateResolvedIP` 并写入缓存
-- [x] 4.6 `gateway_handler.go` + `gemini_v1beta_handler.go` + `sora_gateway_handler.go`: 将 `readRequestBodyWithPrealloc` 从 `openai_gateway_handler.go` 提取到公共包(如 `pkg/httputil/body.go`),替换这三个文件中的 `io.ReadAll(c.Request.Body)` 调用
-- [x] 4.7 `gateway_handler.go` + `service/*`: 定义 `RequestMetadata` 结构体并单次注入;兼容期保留旧 `ctxkey.*` 注入,读取侧优先读新结构体、回退读旧 key
-- [x] 4.8 `openai_gateway_service.go`: 在 `bodyModified` 路径中,对单字段删除场景使用 `sjson.DeleteBytes`,对单字段修改场景使用 `sjson.SetBytes`;仅在多字段复杂修改时保留 `json.Marshal` 全量路径
-- [x] 4.9 `gateway_helper.go`: 修改 `SetClaudeCodeClientContext` 接受已解析的请求结构或从 gin context 中读取缓存结果,替换内部的 `json.Unmarshal(body, &bodyMap)` 调用
-- [x] 4.10 `backend/internal/config/*` + `config.example.yaml`: 增加兼容开关配置项并设默认值(`session_hash_read_old_fallback=true`、`session_hash_dual_write_old=true`、`metadata_bridge_enabled=true`)
-- [x] 4.11 `openai_gateway_service.go` + `gateway_handler.go`: 增加兼容路径观测指标(旧 key 回退命中率、旧 ctxkey 回退命中率),作为下线门禁依据
-
-## 5. Phase 2 — 数据库查询优化(database-optimization,代码层)
-
-- [x] 5.1 `backend/ent/schema/account.go`: 先补齐 `temp_unschedulable_until`、`temp_unschedulable_reason` 字段定义(与现有 DB 列对齐,不改列类型)
-- [x] 5.2 `backend/internal/repository/account_repo.go`: 在 `accountsToService` 和 `GetByIDs` 中复用首次查询字段,移除 `loadTempUnschedStates` 二次查询
-- [x] 5.3 `usage_log_repo.go`: 将 `fillDashboardUsageStatsFromUsageLogs` 中 4 次独立 SQL 合并为 1-2 个 CTE 查询(total + today 统计合并,today active + hourly active 合并)
-- [x] 5.4 `group_repo.go`: 将 `UpdateSortOrders` 从 `for range` 逐条 `UpdateOneID` 改为单条 SQL `UPDATE groups SET sort_order = CASE id WHEN $1 THEN $2 ... END WHERE id = ANY($N)`
-
-## 6. Phase 3 — 数据库索引(database-optimization,Schema 层)
-
-- [x] 6.1 `backend/migrations/*_notx.sql`:新增索引迁移 SQL,使用 `CREATE INDEX CONCURRENTLY IF NOT EXISTS` 创建 `accounts(platform, priority)` 与 `accounts(priority, status)` 的业务部分索引
-- [x] 6.2 `backend/migrations/*_notx.sql`:新增索引迁移 SQL,使用 `CREATE INDEX CONCURRENTLY IF NOT EXISTS` 创建 `user_subscriptions(user_id, status, expires_at) WHERE deleted_at IS NULL`
-- [x] 6.3 `backend/migrations/*_notx.sql`:新增索引迁移 SQL,使用 `CREATE INDEX CONCURRENTLY IF NOT EXISTS` 创建 `usage_logs(group_id, created_at) WHERE group_id IS NOT NULL`
-- [x] 6.4 `backend/ent/schema/*.go`:按需同步普通索引声明(文档一致性目的),不依赖 Ent 自动迁移落地部分索引
-- [ ] 6.5 评估并移除 `accounts`、`users`、`api_keys`、`groups`、`user_subscriptions`、`proxies` 表上无效的 `deleted_at` 单列索引(确认已被业务复合部分索引覆盖后)
-- [x] 6.6 在 `backend/` 目录运行 `go generate ./ent && go generate ./cmd/server` 重新生成代码,验证编译通过
-- [ ] 6.7 在测试/预发布环境执行 `EXPLAIN ANALYZE` 验证新索引被调度热路径查询使用
-- [ ] 6.8 `backend/migrations/*_notx.sql`:对 planned 删除索引使用 `DROP INDEX CONCURRENTLY IF EXISTS`,并仅在完成 7 天慢 SQL/查询计划观测后执行
-
-## 7. Phase 1 — 构建与基础设施优化(build-optimization)
-
-- [x] 7.1 `Dockerfile`: 改为多阶段构建 — Stage 1 `golang:1.25.7-alpine` 编译(含 `CGO_ENABLED=0 -ldflags="-s -w" -trimpath`),Stage 2 `alpine:3.21` 运行时(`ca-certificates` + `tzdata` + binary + resources),并将 healthcheck 命令改为 BusyBox `wget`(避免引入 `curl`)
-- [x] 7.2 `backend/Makefile`: 更新 `build` 目标添加 `CGO_ENABLED=0 -ldflags="-s -w -X main.Version=$(VERSION)" -trimpath`;新增 `generate` 目标(`go generate ./ent && go generate ./cmd/server`)
-- [x] 7.3 `openai_ws_pool.go`: 在 `openAIWSConnPool` 中添加 `workerWg sync.WaitGroup`,`startBackgroundWorkers` 中 `wg.Add(2)` + goroutine 内 `defer wg.Done()`,`Close()` 中 `close(workerStopCh)` 后 `wg.Wait()`
-- [x] 7.4 `openai_ws_pool.go`: 将 `runBackgroundPingSweep` 改为使用 `errgroup.Group` 并设置并发度上限(`SetLimit(10)`),并发 ping 所有候选 idle 连接
-- [x] 7.5 `backend/internal/pkg/errors/http.go`: 用“按需轻量拷贝”替代 `Clone(appErr)`(仅在 `Metadata != nil` 时拷贝 map),保持 `ToHTTP` 返回语义与线程安全不变
-- [x] 7.6 `backend/cmd/server/wire.go`:将 `provideCleanup` 中互不依赖的清理步骤分组并行执行(使用 `sync.WaitGroup`),基础设施步骤(Redis、Ent)保持最后顺序执行;随后重新生成 `backend/cmd/server/wire_gen.go`
-
-## 8. 验证与收尾
-
-- [x] 8.1 运行完整单元测试套件 `go test ./...` 确保所有修改不破坏现有功能
-- [ ] 8.2 运行集成测试(特别是 `internal/integration/` 下的 E2E 测试)验证网关热路径修改
-- [ ] 8.3 构建 Docker 镜像并验证体积 < 30MB、启动正常、API 功能正常
-- [x] 8.4 使用 `go test -bench=.` 对关键路径(WS 消息解析、会话哈希、日志写入)做基准测试对比
-- [ ] 8.5 在预发布环境执行 `EXPLAIN ANALYZE` 验证所有新索引的查询计划
diff --git a/openspec/specs/frontend-routing/spec.md b/openspec/specs/frontend-routing/spec.md
deleted file mode 100644
index b8db533a5..000000000
--- a/openspec/specs/frontend-routing/spec.md
+++ /dev/null
@@ -1,28 +0,0 @@
-# frontend-routing Specification
-
-## Purpose
-TBD - created by archiving change add-chunk-load-error-recovery. Update Purpose after archive.
-## Requirements
-### Requirement: Chunk Load Error Recovery
-
-The frontend application SHALL automatically recover from chunk loading failures caused by deployment updates. When a dynamically imported module fails to load, the router SHALL detect the error and attempt to reload the page to fetch the latest resources.
-
-#### Scenario: Dynamic import fails due to stale cache
-- **WHEN** a user navigates to a lazily-loaded route
-- **AND** the browser has cached an outdated `index.html` referencing old chunk files
-- **AND** the server returns 404 for the requested chunk
-- **THEN** the router detects the chunk load error
-- **AND** automatically reloads the page to fetch the latest version
-
-#### Scenario: Reload cooldown prevents infinite loop
-- **WHEN** a chunk load error triggers an automatic page reload
-- **AND** the reload occurs within 10 seconds of a previous reload attempt
-- **THEN** the router SHALL NOT trigger another reload
-- **AND** SHALL log an error message suggesting the user clear their browser cache
-
-#### Scenario: Successful recovery after reload
-- **WHEN** the page reloads due to a chunk load error
-- **AND** the browser fetches the latest `index.html` and chunk files
-- **THEN** the user can successfully navigate to the intended route
-- **AND** the application functions normally
-
diff --git a/openspec/specs/openai-oauth-performance/spec.md b/openspec/specs/openai-oauth-performance/spec.md
deleted file mode 100644
index 6b201137e..000000000
--- a/openspec/specs/openai-oauth-performance/spec.md
+++ /dev/null
@@ -1,47 +0,0 @@
-# openai-oauth-performance Specification
-
-## Purpose
-定义 OpenAI OAuth `/v1/responses` 主链路的高性能与高稳定性要求,确保在保持协议兼容的前提下,持续降低网关附加延迟与尾延迟(P95/P99)、控制错误率,并提供可量化、可灰度、可回滚的性能治理能力。
-## Requirements
-### Requirement: OpenAI OAuth 链路性能目标可量化
-系统 MUST 为 OpenAI OAuth `/v1/responses` 链路定义并维护可量化的性能目标,至少覆盖网关附加延迟、TTFT、P95/P99、错误率与资源开销基线,并以统一口径输出对比结果。
-
-#### Scenario: 发布前具备性能基线与目标对比
-- **WHEN** 团队发起 OpenAI OAuth 性能优化发布评审
-- **THEN** 评审材料 MUST 包含优化前后同口径压测结果与目标达成情况
-
-### Requirement: 请求热路径避免重复解析与不必要拷贝
-系统 SHALL 在 OpenAI OAuth 请求处理热路径中避免对同一请求体进行重复解析与不必要数据拷贝,保证常态请求不引入额外的可避免 CPU/内存开销。
-
-#### Scenario: 常态请求路径不发生多次完整解析
-- **WHEN** 网关处理一个合法的 OpenAI OAuth 非异常请求
-- **THEN** 热路径实现 SHALL 不重复执行可避免的全量 JSON 解析与大对象拷贝
-
-### Requirement: 并发控制快速路径最小化额外存储往返
-系统 SHALL 对并发控制采用快速路径策略:在可直接获得并发槽位时,不执行不必要的等待队列写入,并最小化常态请求的额外 Redis 往返。
-
-#### Scenario: 可立即获得槽位时跳过等待队列写入
-- **WHEN** 请求到达且用户与账号并发槽位均可立即获取
-- **THEN** 系统 SHALL 直接进入上游转发路径而不执行等待队列计数写入
-
-### Requirement: 流式转发热路径降低逐行处理成本
-系统 MUST 优化 OpenAI OAuth SSE 流式转发热路径,降低逐行处理中的高频字符串与 JSON 操作成本,并保持与 OpenAI Responses 流式协议兼容。
-
-#### Scenario: 流式协议兼容且处理开销降低
-- **WHEN** 客户端发起 OpenAI OAuth 流式请求并持续接收事件
-- **THEN** 系统 SHALL 保持事件语义兼容,同时逐行处理不应依赖可替代的高开销通用解析手段
-
-### Requirement: Token 竞争路径控制尾延迟放大
-系统 SHALL 在 OpenAI OAuth token 获取的锁竞争场景中采用低抖动等待策略,避免固定大步长等待导致的尾延迟放大。
-
-#### Scenario: 锁竞争下请求不出现固定等待台阶
-- **WHEN** 多个并发请求同时命中同一 OAuth 账号的 token 刷新竞争
-- **THEN** 请求等待策略 SHALL 使用短周期可回退机制,并避免固定长等待造成显著延迟台阶
-
-### Requirement: 优化发布必须具备可灰度与可回滚保障
-系统 MUST 为 OpenAI OAuth 性能优化提供灰度发布、关键指标监控与回滚策略,确保在异常时可快速恢复到稳定状态。
-
-#### Scenario: 灰度阶段触发阈值时可快速回滚
-- **WHEN** 灰度期间关键指标(错误率或 P99)超出预设阈值
-- **THEN** 运行策略 MUST 支持按批次或按开关回滚,并恢复至优化前稳定行为
-
diff --git a/openspec/specs/openai-ws-v2-performance/spec.md b/openspec/specs/openai-ws-v2-performance/spec.md
deleted file mode 100644
index b31bb4346..000000000
--- a/openspec/specs/openai-ws-v2-performance/spec.md
+++ /dev/null
@@ -1,179 +0,0 @@
-# openai-ws-v2-performance Specification
-
-## Purpose
-TBD - created by archiving change ç. Update Purpose after archive.
-## Requirements
-### Requirement: WSv2 转发热路径必须避免重复序列化与重复字段解析
-系统 MUST 在单次 WSv2 请求处理过程中避免可消除的 payload 重复序列化、重复字符串解析与重复大对象拷贝。
-
-#### Scenario: 单请求仅进行必要序列化
-- **WHEN** 网关处理一次合法的 OpenAI WSv2 请求
-- **THEN** payload 编码与字段提取 SHALL 采用单次快照策略
-- **AND** 系统 MUST NOT 在同一请求中重复执行可避免的全量 JSON 编码
-
-### Requirement: WSv2 日志必须受预算与采样控制
-系统 MUST 对 WSv2 热路径日志与 payload 统计执行预算控制,避免日志计算放大主流程开销。
-
-#### Scenario: 大 payload 场景日志成本受控
-- **WHEN** 请求包含大型 `tools` 或 `input` 字段
-- **THEN** 系统 SHALL 使用采样和截断策略记录诊断信息
-- **AND** 系统 MUST NOT 对所有字段每次都执行高开销序列化统计
-
-### Requirement: WS 事件循环必须最小化字节与字符串往返转换
-系统 SHALL 在 WS 事件处理循环中优先使用字节路径,降低 `[]byte <-> string` 的频繁转换成本。
-
-#### Scenario: 高频 token 事件下保持低分配
-- **WHEN** 流式请求持续输出高频 token 事件
-- **THEN** 事件处理路径 MUST 使用字节优先处理与选择性解析
-- **AND** 在不影响协议语义前提下 MUST 减少每事件的临时对象分配
-
-### Requirement: 连接池获取路径必须使用低复杂度连接选择策略
-系统 MUST 为账号连接池提供低复杂度连接选择机制,避免在每次 `Acquire` 上执行全量排序。
-
-#### Scenario: 账号连接数增加时获取开销受控
-- **WHEN** 同一账号连接池中连接数量上升
-- **THEN** `Acquire` 延迟 SHALL 维持稳定并接近 O(1)/O(log n) 复杂度
-- **AND** `preferred_conn_id` 命中时 MUST 走快速路径
-
-### Requirement: 代理建连必须复用 HTTP 传输资源
-系统 MUST 复用代理建连使用的 HTTP client/transport,避免按请求重复创建传输对象。
-
-#### Scenario: 同代理地址连续建连
-- **WHEN** 同一 `proxyURL` 在短时间内多次用于 WS 建连
-- **THEN** 系统 SHALL 复用同一传输资源池
-- **AND** 握手延迟与建连 CPU 开销 MUST 低于未复用基线
-
-### Requirement: WS 重试策略必须具备分类、退避与熔断能力
-系统 MUST 将 WS 失败分为可重试与不可重试两类,并对可重试路径应用退避与抖动策略。
-
-#### Scenario: 策略类失败快速回退
-- **WHEN** 上游返回策略违规类关闭状态(例如 `1008`)
-- **THEN** 系统 MUST 在一次失败后快速回退到 HTTP
-- **AND** 系统 MUST NOT 连续进行多次无效 WS 重试
-
-#### Scenario: 可重试失败执行退避
-- **WHEN** 发生可重试的瞬时错误(如网络抖动、上游 5xx)
-- **THEN** 系统 SHALL 使用指数退避并附加 jitter 控制重试节奏
-- **AND** 重试次数与等待时长 MUST 受配置上限约束
-
-### Requirement: 预热与扩容策略必须防抖并避免建连风暴
-系统 SHALL 对连接预热和扩容触发执行防抖控制,避免瞬时负载波动触发过量后台建连。
-
-#### Scenario: 高频 Acquire 下预热触发受控
-- **WHEN** 同账号在短窗口内出现大量 Acquire 调用
-- **THEN** 系统 MUST 保证同一账号预热线程/任务数量有界
-- **AND** 预热触发 MUST 受 cooldown 与失败率门限控制
-
-### Requirement: WSv2 性能优化不得改变“默认开启”产品策略
-系统 MUST 在性能优化实施后保持 OpenAI Responses WebSocket 的默认开启策略不变,不得通过性能提案将默认行为回退为关闭。
-
-#### Scenario: 配置默认值保持开启
-- **WHEN** 系统加载默认网关配置
-- **THEN** `gateway.openai_ws.enabled` MUST 保持为 `true`
-- **AND** 性能优化开关 MUST 只影响实现细节,不改变 WS 默认启用语义
-
-### Requirement: WSv2 性能优化发布必须满足量化验收与回滚保障
-系统 MUST 在 WSv2 性能优化上线前后提供统一口径基线对比,并具备阈值触发回滚能力。
-
-#### Scenario: 发布验收材料完整
-- **WHEN** 团队评审 WSv2 性能优化发布
-- **THEN** 材料 MUST 包含 `TTFT`、`P95/P99`、`CPU`、`allocs/op`、`retry_attempts`、`fallback_rate` 的前后对比
-
-### Requirement: WSv2 性能优化必须达到明确阈值
-系统 MUST 基于统一压测口径达到本提案定义的性能阈值,未达标不得全量发布。
-
-#### Scenario: 延迟与资源阈值达标
-- **WHEN** 在统一基线环境完成 WSv2 优化回归压测
-- **THEN** 网关附加延迟 `P95` MUST 至少降低 25%
-- **AND** 网关附加延迟 `P99` MUST 至少降低 20%
-- **AND** 热路径 `allocs/op` MUST 至少降低 30%
-- **AND** 热路径 `B/op` MUST 至少降低 25%
-
-#### Scenario: 重试与连接复用阈值达标
-- **WHEN** 在统一基线环境完成失败注入与稳态压测
-- **THEN** 单请求平均 `retry_attempts` MUST 小于等于 1.2
-- **AND** `retry_exhausted` 比例 MUST 小于等于 0.5%
-- **AND** 连接池复用率 MUST 大于等于 75%
-
-#### Scenario: 指标越界可快速回滚
-- **WHEN** 灰度阶段关键指标超出预设阈值
-- **THEN** 系统 MUST 支持按开关快速回滚到稳定路径
-- **AND** 回滚后行为 MUST 与回滚前基线兼容
-
-### Requirement: WSv2 error 事件后的连接必须不可复用
-系统 MUST 在收到上游 `type=error` 事件后将当前连接标记为损坏,避免回池复用。
-
-#### Scenario: error 事件触发统一损坏标记
-- **WHEN** 上游返回 `error` 事件
-- **THEN** 系统 MUST 执行连接损坏标记
-- **AND** 不得因“是否可回退”分支差异而漏标记
-
-### Requirement: WSv2 写上游超时必须继承父 context
-系统 MUST 在写上游 WS 时继承调用方父 `context`,避免客户端已断开时仍长时间阻塞。
-
-#### Scenario: 父 context 已取消
-- **WHEN** 父 `context` 已取消
-- **THEN** 写上游操作 MUST 立即感知取消并返回
-- **AND** MUST NOT 阻塞到默认写超时
-
-### Requirement: 连接池必须具备后台 ping 与后台清理
-系统 MUST 在 `Acquire` 之外提供后台连接维护能力。
-
-#### Scenario: 空闲连接后台心跳
-- **WHEN** 连接处于空闲状态
-- **THEN** 系统 SHALL 按周期对空闲连接执行 ping
-- **AND** ping 失败连接 MUST 被回收
-
-#### Scenario: 长时间无请求账号
-- **WHEN** 某账号长时间无新请求
-- **THEN** 系统 SHALL 仍执行后台清理
-- **AND** 过期/无效连接 MUST 被回收
-
-### Requirement: 连接 I/O 必须支持并发一读一写
-系统 MUST 避免将 WS 读写串行化到同一把锁上。
-
-#### Scenario: 读阻塞期间执行写/Ping
-- **WHEN** 读路径处于阻塞等待
-- **THEN** 写路径 SHOULD 仍可独立推进
-- **AND** 不得因单锁竞争导致心跳/写入长时间饥饿
-
-### Requirement: ingress WS 客户端断连后应继续 drain 上游
-系统 MUST 在 ingress WS 模式下对客户端断连采用“继续 drain 到 terminal”的策略。
-
-#### Scenario: 客户端中途断开
-- **WHEN** 向客户端写事件返回断连错误
-- **THEN** 系统 SHALL 继续读取上游直到 terminal
-- **AND** 连接不得因该断连被立即标记损坏
-
-### Requirement: 状态存储 Redis 操作必须有独立短超时
-系统 MUST 为 WS 状态存储的 Redis 操作设置独立短超时,避免长上下文阻塞。
-
-#### Scenario: Redis 网络异常
-- **WHEN** Redis 操作发生网络抖动/分区
-- **THEN** `set/get/delete` MUST 在短超时内返回
-- **AND** 不得无限依赖上层长连接 context
-
-### Requirement: 协议决策必须对未知认证类型显式回退 HTTP
-系统 MUST 在未知 OpenAI 认证类型下显式回退 HTTP。
-
-#### Scenario: 非 OAuth 且非 API Key 账号
-- **WHEN** 账号认证类型不在已知集合内
-- **THEN** 协议决策 MUST 返回 HTTP
-- **AND** MUST NOT 进入 WS 子开关判定路径
-
-### Requirement: WS 消息读取上限必须受控
-系统 MUST 对 ingress 与上游 WS 客户端统一设置合理读取上限,降低异常大包内存风险。
-
-#### Scenario: 默认读取上限
-- **WHEN** 系统创建 ingress/上游 WS 连接
-- **THEN** 读取上限 MUST 为受控值(16MB)
-- **AND** ingress 与上游配置 MUST 保持一致
-
-### Requirement: 粘连绑定失败必须可观测
-系统 MUST 对 `BindResponseAccount` 失败记录警告日志。
-
-#### Scenario: 粘连绑定异常
-- **WHEN** 状态存储返回绑定错误
-- **THEN** 系统 MUST 记录 `warn` 级别日志
-- **AND** 日志 MUST 包含 group/account/response 标识
-
diff --git a/openspec/specs/schedule-account/spec.md b/openspec/specs/schedule-account/spec.md
deleted file mode 100644
index c70620102..000000000
--- a/openspec/specs/schedule-account/spec.md
+++ /dev/null
@@ -1,11 +0,0 @@
-# schedule-account Specification
-
-## Purpose
-TBD - 由归档变更 refactor-sticky-session-hit-lookup 创建;后续归档后补充本规范的目的说明。
-## Requirements
-### Requirement: Sticky-session 命中复用可调度账号列表
-调度器 SHALL 从请求已加载的“可调度账号列表”中解析 sticky-session 账号选择,并且当 sticky 账号已存在于该列表时,SHALL NOT 额外发起按账号 ID 查询数据库的请求。
-
-#### Scenario: Sticky session 命中且不额外查询数据库
-- **WHEN** 调度请求包含 sticky session,且该 sticky session 指向的账号存在于可调度账号列表中
-- **THEN** 调度器复用该内存中的账号数据,并且不再按账号 ID 查询数据库
diff --git a/openspec/specs/sora-account-apikey/spec.md b/openspec/specs/sora-account-apikey/spec.md
deleted file mode 100644
index e322009b8..000000000
--- a/openspec/specs/sora-account-apikey/spec.md
+++ /dev/null
@@ -1,82 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Sora 平台支持 API Key 账号类型
-系统 SHALL 为 Sora 平台新增 "API Key / 上游透传" 账号类型,取消现有 OAuth 硬编码限制。
-
-#### Scenario: 前端创建 Sora API Key 账号
-- **WHEN** 管理员在账号创建对话框中选择 Sora 平台
-- **THEN** 系统 SHALL 显示两个账号类别选项卡:"OAuth 认证"和"API Key / 上游透传"
-- **AND** 选择"API Key / 上游透传"时 SHALL 显示 `Base URL`(必填)和 `API Key`(必填)表单字段
-- **AND** 提交时 `form.type` SHALL 设置为 `'apikey'`
-
-#### Scenario: Base URL 字段校验
-- **WHEN** 管理员创建或编辑 `platform=sora, type=apikey` 账号
-- **THEN** `base_url` SHALL 为必填
-- **AND** `base_url` SHALL 以 `http://` 或 `https://` 开头
-- **AND** 不满足校验时 SHALL 拒绝保存并提示明确错误
-
-#### Scenario: 取消 Sora OAuth 硬编码
-- **WHEN** 用户选择 Sora 平台
-- **THEN** 系统 SHALL 不再强制设置 `form.type = 'oauth'`
-- **AND** SHALL 允许用户选择 OAuth 或 API Key 类型
-
-### Requirement: Sora API Key 账号编辑
-系统 SHALL 支持编辑 Sora API Key 类型账号的 `base_url` 和 `api_key`。
-
-#### Scenario: 编辑 Sora API Key 账号
-- **WHEN** 管理员编辑一个 `platform=sora, type=apikey` 的账号
-- **THEN** 编辑界面 SHALL 显示 `Base URL` 和 `API Key` 可编辑字段
-- **AND** 保存时 SHALL 更新 `credentials` 中的 `base_url` 和 `api_key`
-
-### Requirement: Sora API Key 账号连通性测试
-系统 SHALL 支持 Sora API Key 账号的连通性测试。
-
-#### Scenario: 测试连通性成功
-- **WHEN** 管理员点击"测试连接"
-- **AND** 上游 `base_url` 可达且 `api_key` 有效
-- **THEN** 系统 SHALL 发送轻量级请求到上游验证连通性
-- **AND** 返回测试成功结果
-
-#### Scenario: 测试连通性失败
-- **WHEN** 上游不可达或认证失败
-- **THEN** 系统 SHALL 返回明确的错误信息(如"连接超时"、"认证失败")
-
-### Requirement: Sora apikey 账号 HTTP 透传
-系统 SHALL 对 `type=apikey` 的 Sora 账号执行 HTTP 透传,而非 SDK 直连。
-
-#### Scenario: apikey 账号走 HTTP 透传
-- **WHEN** `SoraGatewayService.Forward()` 检测到 `account.Type == "apikey"` 且 `account.GetBaseURL() != ""`
-- **THEN** 系统 SHALL 调用 `forwardToUpstream()` 方法
-- **AND** SHALL 不使用 `SoraSDKClient` 直连
-
-#### Scenario: HTTP 透传请求构造
-- **WHEN** 系统执行 `forwardToUpstream()`
-- **THEN** 系统 SHALL 构造 HTTP POST 请求到规范化拼接的 `{base_url}/sora/v1/chat/completions`
-- **AND** Header SHALL 包含 `Authorization: Bearer ` 和 `Content-Type: application/json`
-- **AND** 请求体 SHALL 原样透传客户端请求体
-
-#### Scenario: 流式响应透传
-- **WHEN** 上游返回流式 SSE 响应
-- **THEN** 系统 SHALL 逐字节透传 SSE 流到客户端
-- **AND** SHALL 不缓存完整响应
-
-#### Scenario: 非流式响应透传
-- **WHEN** 上游返回非流式 JSON 响应
-- **THEN** 系统 SHALL 读取完整响应后原样返回客户端
-
-#### Scenario: 上游错误触发失败转移
-- **WHEN** 上游返回 401/403/429/5xx 错误
-- **THEN** 系统 SHALL 复用现有的 `UpstreamFailoverError` 机制触发账号切换
-
-### Requirement: sub2api 二级桥接
-系统 SHALL 通过 API Key 账号类型天然支持 sub2api 级联部署。
-
-#### Scenario: 分站通过 API Key 连接总站
-- **WHEN** 分站创建 Sora API Key 账号,`base_url` 指向总站地址
-- **THEN** 分站的 Sora 请求 SHALL 通过 HTTP 透传到总站的 `/sora/v1/chat/completions`
-- **AND** 总站 SHALL 使用自己的 OAuth 账号连接 OpenAI
-
-#### Scenario: 级联中的存储独立性
-- **WHEN** 分站收到总站返回的生成结果
-- **THEN** 分站 SHALL 根据自己的 S3 配置决定是否存储
-- **AND** 存储行为与总站无关(完全独立)
diff --git a/openspec/specs/sora-client-ui/spec.md b/openspec/specs/sora-client-ui/spec.md
deleted file mode 100644
index bd3466f55..000000000
--- a/openspec/specs/sora-client-ui/spec.md
+++ /dev/null
@@ -1,305 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Sora 客户端路由与菜单
-系统 SHALL 在前端新增 Sora 客户端页面,可通过侧边栏菜单访问。菜单项的显示须与现有侧边栏风格一致,并遵循条件显示、简单模式、双菜单同步等现有模式。
-
-#### Scenario: 路由注册
-- **WHEN** 前端路由初始化
-- **THEN** 系统 SHALL 注册 `/sora` 路由,加载 `SoraView.vue` 页面
-- **AND** 路由 meta SHALL 设置 `requiresAuth: true, requiresAdmin: false`
-
-#### Scenario: 侧边栏菜单项(条件显示)
-- **WHEN** 用户登录后查看侧边栏
-- **AND** 公共设置 `sora_client_enabled` 为 true(后端根据是否存在活跃 Sora 账号自动推断)
-- **THEN** 侧边栏 SHALL 显示"Sora"菜单项
-- **AND** 菜单项 SHALL 使用 Heroicons 线性风格的 Sparkles 图标(与现有侧边栏图标统一为 stroke 风格,`h-5 w-5`)
-- **AND** 点击后 SHALL 跳转到 `/sora` 页面
-
-#### Scenario: 菜单项在管理员未启用 Sora 时隐藏
-- **WHEN** 公共设置 `sora_client_enabled` 为 false(无活跃 Sora 账号)
-- **THEN** 侧边栏 SHALL 不显示"Sora"菜单项
-- **AND** 用户直接访问 `/sora` 时 SHALL 显示功能未启用提示页
-
-#### Scenario: 菜单位置与双菜单同步
-- **WHEN** Sora 菜单项显示
-- **THEN** 对于普通用户(`userNavItems`),Sora SHALL 位于"Dashboard"之后、"API 密钥"之前
-- **AND** 对于管理员"我的账户"区域(`personalNavItems`),Sora SHALL 位于"API 密钥"之后、"使用记录"之前
-- **AND** 两个菜单列表 SHALL 同步添加(确保管理员和普通用户均可访问)
-
-#### Scenario: 简单模式隐藏
-- **WHEN** 系统处于简单模式(`isSimpleMode = true`)
-- **THEN** Sora 菜单项 SHALL 隐藏(`hideInSimpleMode: true`)
-
-### Requirement: Sora 客户端页面内导航
-系统 SHALL 在 Sora 客户端页面顶部显示页面内导航栏,仅包含 Tab 切换和配额信息。Sora 页面嵌入在全局侧边栏布局内,不独立展示 Logo 或用户头像(这些已由全局侧边栏提供)。
-
-#### Scenario: 页面内导航栏显示
-- **WHEN** 用户进入 Sora 客户端页面
-- **THEN** 页面顶部 SHALL 显示页面内导航栏,包含"生成"/"作品库" Tab 切换
-- **AND** 右侧 SHALL 显示配额进度条(如 "2.1GB / 5GB")
-- **AND** 导航栏 SHALL 不包含 Logo 和用户头像(避免与全局侧边栏重复)
-- **AND** Sora 页面 SHALL 保留在全局侧边栏布局内渲染(用户可通过侧边栏随时切换到其他页面)
-
-#### Scenario: Tab 切换
-- **WHEN** 用户点击"生成"或"作品库" Tab
-- **THEN** 页面 SHALL 切换到对应视图,不刷新页面
-
-### Requirement: 生成页面 - 底部创作栏
-系统 SHALL 在生成页底部固定显示创作栏,用于输入提示词和配置生成参数。
-
-#### Scenario: 提示词输入
-- **WHEN** 用户在创作栏输入提示词
-- **THEN** 输入框 SHALL 支持多行文本,自动扩展高度
-- **AND** 支持 Ctrl/Cmd + Enter 快捷键触发生成
-
-#### Scenario: 模型选择
-- **WHEN** 用户点击模型选择器
-- **THEN** 系统 SHALL 从 `GET /api/v1/sora/models` 获取可用模型列表
-- **AND** 下拉菜单 SHALL 按视频模型和图片模型分组显示
-
-#### Scenario: 视频参数配置
-- **WHEN** 用户选择视频模型
-- **THEN** 创作栏 SHALL 显示方向选择(横屏/竖屏/方形)和时长选择(10s/15s/25s)
-
-#### Scenario: 图片模型隐藏视频参数
-- **WHEN** 用户选择图片模型(如 gpt-image-1)
-- **THEN** 创作栏 SHALL 隐藏方向选择和时长选择
-
-#### Scenario: 参考图上传
-- **WHEN** 用户点击图片上传按钮
-- **THEN** 系统 SHALL 允许上传参考图片,作为生成输入的 `image_url`
-
-### Requirement: 生成页面 - 发起生成
-系统 SHALL 通过底部创作栏的"生成"按钮发起 Sora 生成请求。
-
-#### Scenario: 发起视频生成
-- **WHEN** 用户填写提示词并点击"生成"按钮
-- **AND** 当前选择视频模型
-- **THEN** 系统 SHALL 发送 `POST /api/v1/sora/generate`,包含 `prompt`、`model`、`media_type=video`、方向、时长参数
-- **AND** 页面 SHALL 显示新的进度卡片(pending 状态)
-
-#### Scenario: 发起图片生成
-- **WHEN** 用户填写提示词并选择图片模型后点击"生成"
-- **THEN** 系统 SHALL 发送生成请求,`media_type=image`
-- **AND** 页面 SHALL 显示新的进度卡片
-
-#### Scenario: 配额不足预防与提示
-- **WHEN** 用户配额使用率超过 90%
-- **THEN** 配额进度条 SHALL 变为黄色警告色,提示"存储空间即将用完"
-- **AND** 配额使用率达 100% 时,"生成"按钮 SHALL 禁用并显示 tooltip "存储配额已满"
-
-#### Scenario: 配额不足错误引导
-- **WHEN** 生成请求返回 HTTP 429(配额不足)
-- **THEN** 页面 SHALL 弹出配额不足对话框,包含:
- - 当前配额使用详情(已用 / 总配额)
- - 引导文案"您可以在作品库中删除不需要的作品来释放存储空间"
- - "前往作品库"按钮(直接跳转到作品库页面)
-
-### Requirement: 生成页面 - 进度展示
-系统 SHALL 在生成页中间区域实时展示当前生成任务的进度状态。
-
-#### Scenario: 排队中状态
-- **WHEN** 生成记录 `status = 'pending'`
-- **THEN** 进度卡片 SHALL 显示"排队中"状态、灰色状态指示、提示词摘要(前 50 字)
-- **AND** SHALL 显示"取消"按钮
-
-#### Scenario: 生成中状态
-- **WHEN** 生成记录 `status = 'generating'`
-- **THEN** 进度卡片 SHALL 显示"生成中"动画、提示词预览
-- **AND** SHALL 显示已等待时长(如"已等待 3:42")和预估剩余时间(如"预计剩余 8 分钟")
-- **AND** SHALL 显示"取消"按钮
-- **AND** 超过 20 分钟未完成时 SHALL 显示"生成时间异常,建议取消重试"
-
-#### Scenario: 生成完成 - 自动保存成功
-- **WHEN** 生成记录 `status = 'completed'` 且 `storage_type = 's3'`
-- **THEN** 进度卡片 SHALL 显示生成结果预览(视频播放器或图片缩略图)
-- **AND** SHALL 显示 "✓ 已保存到云端" 状态标识
-- **AND** SHALL 提供"📥 本地下载"按钮
-- **AND** 作品自动出现在作品库中
-
-#### Scenario: 生成完成 - 降级本地存储
-- **WHEN** 生成记录 `status = 'completed'` 且 `storage_type = 'local'`
-- **THEN** 进度卡片 SHALL 显示 "✓ 已保存到本地" 状态标识
-- **AND** SHALL 提供"📥 本地下载"按钮
-
-#### Scenario: 生成完成 - 无存储(upstream)
-- **WHEN** 生成记录 `status = 'completed'` 且 `storage_type = 'upstream'`
-- **THEN** 进度卡片 SHALL 显示"📥 本地下载"按钮
-- **AND** SHALL 显示 15 分钟过期倒计时进度条(基于 `completed_at` 计算)
-- **AND** 若 S3 当前可用,SHALL 显示可点击的"☁️ 保存到存储"按钮
-- **AND** 若 S3 不可用,"☁️ 保存到存储"按钮 SHALL 禁用并 tooltip "管理员未开通云存储"
-- **AND** 倒计时结束后 SHALL 禁用所有按钮并显示"链接已过期"
-
-#### Scenario: 生成失败状态
-- **WHEN** 生成记录 `status = 'failed'`
-- **THEN** 进度卡片 SHALL 显示分类错误信息:
- - 上游服务错误 → "服务暂时不可用,建议稍后重试"
- - 内容审核不通过 → "提示词包含不支持的内容,请修改后重试"
- - 超时 → "生成超时,建议降低分辨率或时长后重试"
-- **AND** SHALL 提供"重试"按钮(一键以相同参数重新发起)
-- **AND** SHALL 提供"编辑后重试"按钮(将参数回填到创作栏)
-- **AND** SHALL 提供"删除"按钮
-
-#### Scenario: 任务取消状态
-- **WHEN** 生成记录 `status = 'cancelled'`
-- **THEN** 进度卡片 SHALL 显示"已取消"灰色状态
-- **AND** SHALL 提供"重新生成"和"删除"按钮
-
-### Requirement: 生成页面 - 多任务管理与状态恢复
-系统 SHALL 支持多个并发生成任务的展示和页面刷新后的状态恢复。
-
-#### Scenario: 多任务并发展示
-- **WHEN** 用户有多个进行中或刚完成的生成任务
-- **THEN** 生成页中间区域 SHALL 以时间线方式纵向排列所有任务卡片,最新在最上方
-- **AND** 底部创作栏 SHALL 显示当前活跃任务数(如"正在生成 2/3")
-- **AND** 超过并发上限(3 个)时,"生成"按钮 SHALL 禁用并提示"请等待当前任务完成"
-
-#### Scenario: 页面刷新后恢复任务
-- **WHEN** 用户刷新页面或重新进入 Sora 客户端
-- **THEN** 系统 SHALL 调用 `GET /api/v1/sora/generations?status=pending,generating` 获取进行中任务
-- **AND** SHALL 自动恢复所有进度卡片的显示
-- **AND** SHALL 继续对进行中任务执行轮询
-
-#### Scenario: 前端轮询策略
-- **WHEN** 存在 pending 或 generating 状态的任务
-- **THEN** 前端 SHALL 按递减频率轮询 `GET /api/v1/sora/generations/:id`:
- - 0-2 分钟:每 3 秒
- - 2-10 分钟:每 10 秒
- - 10-30 分钟:每 30 秒
-- **AND** 每次轮询结果 SHALL 更新卡片显示
-- **AND** 卡片上 SHALL 显示"最后更新:N 秒前"以确认数据实时性
-
-#### Scenario: 浏览器通知
-- **WHEN** 生成任务完成或失败
-- **AND** 浏览器标签页不在前台
-- **THEN** 系统 SHALL 通过 Notification API 发送桌面通知
-- **AND** 标签页 title SHALL 闪烁提示(如"(1) ✓ 生成完成 - Sora")
-
-### Requirement: 生成页面 - 无存储提醒
-系统 SHALL 在未配置存储时显示醒目提示。
-
-#### Scenario: 无存储警告
-- **WHEN** 用户进入生成页
-- **AND** S3 和本地存储均未配置
-- **THEN** 创作栏 SHALL 显示警告标签"存储未配置,生成后请立即下载"
-
-#### Scenario: S3 可用时自动保存(正常模式)
-- **WHEN** 管理员已开通 S3 存储
-- **AND** 用户存储配额未超限
-- **THEN** 生成完成后系统 SHALL 自动上传到 S3
-- **AND** 卡片 SHALL 显示"✓ 已保存到云端"
-
-#### Scenario: S3 不可用时的降级提示
-- **WHEN** 管理员未开通 S3 存储(`sora_s3_enabled = false`)
-- **THEN** 生成完成后卡片 SHALL 仅显示"📥 本地下载"按钮
-- **AND** "☁️ 保存到存储"按钮 SHALL 禁用并 tooltip "管理员未开通云存储"
-
-#### Scenario: 手动保存到存储(仅 upstream 记录)
-- **WHEN** 生成记录 `storage_type = 'upstream'` 且 S3 当前可用
-- **THEN** "☁️ 保存到存储"按钮 SHALL 可点击
-- **AND** 点击后 SHALL 调用 `POST /api/v1/sora/generations/:id/save`
-- **AND** 上传过程中按钮 SHALL 显示 loading 状态
-- **AND** 上传成功后按钮 SHALL 变为"✓ 已保存"
-- **AND** 上传失败 SHALL 显示错误信息并允许重试
-
-#### Scenario: 无存储生成完成自动提示下载
-- **WHEN** 生成完成且 `storage_type = 'upstream'`
-- **THEN** 客户端 SHALL 弹出提醒弹窗"文件仅临时保存,请在 15 分钟内下载"
-- **AND** SHALL 显示 15 分钟倒计时
-
-#### Scenario: 离开页面未下载警告
-- **WHEN** 存在 `storage_type = 'upstream'` 且未过期的完成记录
-- **AND** 用户尝试离开或关闭页面
-- **THEN** 系统 SHALL 触发 `beforeunload` 事件警告"您有未下载的生成结果,离开后可能丢失"
-
-### Requirement: 作品库页面 - 网格展示
-系统 SHALL 在作品库页面以网格布局展示用户的历史生成作品。
-
-#### Scenario: 作品网格显示
-- **WHEN** 用户切换到"作品库" Tab
-- **THEN** 系统 SHALL 从 `GET /api/v1/sora/generations?storage_type=s3,local` 获取已保存记录
-- **AND** SHALL 以响应式网格展示作品卡片(桌面 4 列、平板 3 列、移动端 1-2 列)
-- **AND** `storage_type = 'upstream'` 或 `'none'` 的记录 SHALL 不在作品库中显示
-- **AND** S3 作品的 URL SHALL 由后端每次请求时动态生成(避免预签名过期)
-
-#### Scenario: 作品卡片信息
-- **WHEN** 作品卡片渲染
-- **THEN** 每张卡片 SHALL 显示:缩略图/视频预览、类型角标(VIDEO/IMAGE)、模型名称、生成时间
-- **AND** 视频卡片 SHALL 显示播放按钮和时长标签
-
-#### Scenario: 卡片 hover 操作
-- **WHEN** 用户 hover 作品卡片
-- **THEN** SHALL 显示"📥 本地下载"和"🗑 删除"操作按钮
-- **AND** 缩略图 SHALL 轻微放大效果(scale 1.05,transition 0.2s)
-
-### Requirement: 作品库页面 - 筛选
-系统 SHALL 支持按类型筛选作品。
-
-#### Scenario: 全部/视频/图片筛选
-- **WHEN** 用户点击筛选按钮(全部/视频/图片)
-- **THEN** 作品网格 SHALL 只显示对应类型的记录
-- **AND** SHALL 更新显示作品数量
-
-#### Scenario: 空状态
-- **WHEN** 筛选结果为空或用户无任何生成记录
-- **THEN** 页面 SHALL 显示空状态引导(图标 + "暂无作品" + "开始创作"按钮)
-
-### Requirement: 作品详情与操作
-系统 SHALL 支持查看作品详情和执行下载、删除操作。
-
-#### Scenario: 查看作品详情
-- **WHEN** 用户点击作品卡片
-- **THEN** 系统 SHALL 弹出预览弹窗,显示完整的媒体内容、提示词、模型信息、生成时间
-
-#### Scenario: 本地下载作品
-- **WHEN** 用户点击"本地下载"按钮
-- **THEN** 系统 SHALL 触发浏览器下载对应媒体文件
-
-#### Scenario: 保存作品到存储
-- **WHEN** 用户点击"保存到存储"按钮
-- **AND** 管理员已开通 S3 存储
-- **THEN** 系统 SHALL 将媒体文件上传到 S3
-- **AND** 更新生成记录的 `storage_type`、`s3_object_keys`
-- **AND** 累加用户存储配额
-
-#### Scenario: 删除作品
-- **WHEN** 用户点击删除按钮
-- **THEN** 系统 SHALL 弹出确认对话框
-- **AND** 确认后调用 `DELETE /api/v1/sora/generations/:id`
-- **AND** 删除成功后 SHALL 从网格中移除卡片并更新配额显示
-
-### Requirement: 暗色主题设计
-系统 SHALL 采用参考 Sora 官方客户端的暗色主题设计。
-
-#### Scenario: 暗色主题样式
-- **WHEN** 用户访问 Sora 客户端页面
-- **THEN** 页面背景 SHALL 为深黑色(`#0D0D0D`)
-- **AND** 文字 SHALL 为白色/浅灰色
-- **AND** 卡片和输入框 SHALL 使用多层次灰色(`#1A1A1A`、`#242424`、`#2A2A2A`)
-- **AND** 导航栏 SHALL 有毛玻璃效果(`backdrop-filter: blur`)
-
-### Requirement: 响应式布局
-系统 SHALL 支持不同屏幕尺寸下的自适应布局。
-
-#### Scenario: 桌面端布局
-- **WHEN** 屏幕宽度 > 1200px
-- **THEN** 作品网格 SHALL 显示 4 列
-
-#### Scenario: 平板端布局
-- **WHEN** 屏幕宽度 900px - 1200px
-- **THEN** 作品网格 SHALL 调整为 3 列
-
-#### Scenario: 移动端布局
-- **WHEN** 屏幕宽度 < 600px
-- **THEN** 作品网格 SHALL 调整为 1-2 列
-
-### Requirement: 国际化支持
-系统 SHALL 为 Sora 客户端所有文案提供中英文国际化支持。
-
-#### Scenario: 中文环境
-- **WHEN** 系统语言设置为中文
-- **THEN** 所有 Sora 客户端文案 SHALL 显示中文
-
-#### Scenario: 英文环境
-- **WHEN** 系统语言设置为英文
-- **THEN** 所有 Sora 客户端文案 SHALL 显示英文
diff --git a/openspec/specs/sora-generation-gateway/spec.md b/openspec/specs/sora-generation-gateway/spec.md
deleted file mode 100644
index b6574ab28..000000000
--- a/openspec/specs/sora-generation-gateway/spec.md
+++ /dev/null
@@ -1,129 +0,0 @@
-## MODIFIED Requirements
-
-### Requirement: Sora 生成网关入口
-系统 SHALL 提供 `POST /v1/chat/completions` 作为 Sora 生成入口(仅限 `platform=sora` 分组)。
-
-#### Scenario: Sora 分组调用 /v1/chat/completions
-- **WHEN** 请求的 API Key 分组平台为 `sora`
-- **AND** 请求体包含 `model` 与 `messages`
-- **THEN** 网关按 Sora 规则处理并返回流式或非流式结果
-- **AND** 若生成需要流式,网关 SHALL 强制 `stream=true` 或返回明确提示
-
-#### Scenario: Sora 专用路由调用 /sora/v1/chat/completions
-- **WHEN** 客户端请求 `POST /sora/v1/chat/completions`
-- **THEN** 网关 SHALL 强制使用 `platform=sora` 的调度与生成逻辑
-
-#### Scenario: 非流式请求策略
-- **WHEN** 客户端请求 `stream=false`
-- **THEN** 网关 SHALL 选择"强制流式并聚合"或"返回明确错误",并在文档中一致说明
-- **AND** 默认策略 SHALL 为"强制流式并聚合"
-
-#### Scenario: 非 Sora 分组调用 /v1/chat/completions
-- **WHEN** 请求的 API Key 分组平台不为 `sora`
-- **THEN** 网关 SHALL 返回 4xx 并提示不支持该平台
-
-#### Scenario: API Key 直接调用不存储不记录
-- **WHEN** 请求通过 `/sora/v1/chat/completions`(API Key 直接调用路径)
-- **THEN** 网关 SHALL 不将媒体文件上传到 S3
-- **AND** SHALL 不执行本地磁盘媒体落盘
-- **AND** SHALL 不写入 `sora_generations` 表
-- **AND** SHALL 不检查存储配额
-- **AND** SHALL 直接返回上游 URL(保持现有行为)
-
-### Requirement: Sora 调度与失败切换
-系统 SHALL 对 Sora 账号执行调度、并发控制、失败切换,与 OpenAI 调度一致。
-
-#### Scenario: 账号可用时成功调度
-- **WHEN** 至少存在一个可调度的 Sora 账号
-- **THEN** 选择优先级最高且最近未使用的账号,并在完成后刷新 LRU
-
-#### Scenario: 上游失败触发切换
-- **WHEN** 上游返回 401/403/429/5xx
-- **THEN** 网关 SHALL 切换账号并重试,直到达到最大切换次数
-
-#### Scenario: apikey 类型账号调度到 HTTP 透传
-- **WHEN** 调度选中的 Sora 账号 `type = 'apikey'` 且 `base_url` 非空
-- **THEN** 网关 SHALL 调用 `forwardToUpstream()` 执行 HTTP 透传
-- **AND** SHALL 不使用 `SoraSDKClient` 直连
-
-## ADDED Requirements
-
-### Requirement: Sora 客户端生成入口
-系统 SHALL 提供 `POST /api/v1/sora/generate` 作为客户端 UI 专用生成入口。
-
-#### Scenario: 客户端 UI 调用生成
-- **WHEN** 用户通过 Sora 客户端 UI 发起生成请求
-- **THEN** 系统 SHALL 接受请求并内部调用现有 `SoraGatewayService.Forward()` 完成生成
-- **AND** 在上层包装存储/记录/配额逻辑
-
-#### Scenario: 客户端生成流程(异步)
-- **WHEN** `POST /api/v1/sora/generate` 收到请求
-- **THEN** 系统 SHALL 按以下顺序执行:
- 1. 检查存储配额(有效配额 > 0 时)
- 2. 检查用户当前 pending+generating 任务数不超过 3
- 3. 创建 `sora_generations` 记录(status=pending)
- 4. **立即返回** `{ generation_id, status: "pending" }` 给前端
- 5. 后台异步:内部调用 `SoraGatewayService.Forward()` 获取上游媒体 URL(不在该步骤落盘)
- 6. 后台异步:自动上传媒体到 S3(若可用),否则降级到本地/上游 URL
- 7. 后台异步:更新生成记录(status、media_url、storage_type、file_size 等)
- 8. 后台异步:累加存储配额(仅 S3/本地存储时)
-
-#### Scenario: 前端轮询生成状态
-- **WHEN** 前端需要获取生成任务最新状态
-- **THEN** 系统 SHALL 通过 `GET /api/v1/sora/generations/:id` 返回完整记录
-- **AND** 前端 SHALL 按递减频率轮询(3s → 10s → 30s)
-
-#### Scenario: 并发生成上限
-- **WHEN** 用户 pending+generating 状态的任务已达 3 个
-- **THEN** 系统 SHALL 返回 HTTP 429 + "请等待当前任务完成后再发起新任务"
-
-### Requirement: Sora 可用模型列表 API
-系统 SHALL 提供 `GET /api/v1/sora/models` 供客户端 UI 获取可用模型。
-
-#### Scenario: 获取可用 Sora 模型
-- **WHEN** 用户请求 `GET /api/v1/sora/models`
-- **THEN** 系统 SHALL 返回系统内置的 Sora 模型列表
-- **AND** 每个模型 SHALL 包含 `id`、`name`、`media_type`(video/image)、`description`
-
-### Requirement: 手动保存到存储
-系统 SHALL 提供 `POST /api/v1/sora/generations/:id/save` 供用户将未自动保存的作品手动上传到 S3。
-
-#### Scenario: 手动保存 upstream 记录到 S3
-- **WHEN** 用户请求 `POST /api/v1/sora/generations/:id/save`
-- **AND** 该记录 `storage_type = 'upstream'` 且 `media_url` 未过期
-- **AND** S3 存储当前可用
-- **THEN** 系统 SHALL 从 `media_url` 下载媒体并上传到 S3
-- **AND** 更新记录 `storage_type = 's3'`、`s3_object_keys`、`file_size_bytes`
-- **AND** 累加用户存储配额
-
-#### Scenario: 手动保存时 URL 已过期
-- **WHEN** 上游 URL 已过期(下载返回 403/404)
-- **THEN** 系统 SHALL 返回 HTTP 410 + "媒体链接已过期,无法保存"
-
-#### Scenario: 手动保存时 S3 不可用
-- **WHEN** S3 存储未启用或配置不完整
-- **THEN** 系统 SHALL 返回 HTTP 503 + "云存储未配置,请联系管理员"
-
-### Requirement: 取消生成任务
-系统 SHALL 提供 `POST /api/v1/sora/generations/:id/cancel` 供用户取消进行中的生成任务。
-
-#### Scenario: 取消 pending/generating 状态的任务
-- **WHEN** 用户请求 `POST /api/v1/sora/generations/:id/cancel`
-- **AND** 该记录 `status` 为 `pending` 或 `generating`
-- **THEN** 系统 SHALL 将记录状态更新为 `cancelled`
-- **AND** SHALL 不累加任何存储配额
-- **AND** 若上游任务已提交,后续返回的结果 SHALL 被忽略
-
-#### Scenario: 取消非活跃状态的任务
-- **WHEN** 该记录 `status` 为 `completed`、`failed` 或 `cancelled`
-- **THEN** 系统 SHALL 返回 HTTP 409 + "任务已结束,无法取消"
-
-### Requirement: 存储状态查询
-系统 SHALL 提供 `GET /api/v1/sora/storage-status` 供前端查询当前存储可用性。
-
-#### Scenario: 查询存储状态
-- **WHEN** 用户请求 `GET /api/v1/sora/storage-status`
-- **THEN** 系统 SHALL 返回 `{ s3_enabled, s3_healthy, local_enabled }`
-- **AND** `s3_enabled` 表示管理员是否启用 S3
-- **AND** `s3_healthy` 表示 S3 客户端是否初始化成功
-- **AND** `local_enabled` 表示本地存储是否可用
diff --git a/openspec/specs/sora-generation-history/spec.md b/openspec/specs/sora-generation-history/spec.md
deleted file mode 100644
index 5a36554c1..000000000
--- a/openspec/specs/sora-generation-history/spec.md
+++ /dev/null
@@ -1,138 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 生成记录数据模型
-系统 SHALL 新建 `sora_generations` 表存储每次 Sora 客户端 UI 生成的元数据。
-
-#### Scenario: 数据库表创建
-- **WHEN** 数据库迁移执行
-- **THEN** 系统 SHALL 创建 `sora_generations` 表,包含以下字段:
- - `id` (BIGSERIAL PRIMARY KEY)
- - `user_id` (BIGINT NOT NULL, FK → users.id ON DELETE CASCADE)
- - `api_key_id` (BIGINT, 可空)
- - `model` (VARCHAR(64) NOT NULL)
- - `prompt` (TEXT NOT NULL DEFAULT '')
- - `media_type` (VARCHAR(16) NOT NULL DEFAULT 'video')
- - `status` (VARCHAR(16) NOT NULL DEFAULT 'pending')
- - `media_url` (TEXT NOT NULL DEFAULT '')
- - `media_urls` (JSONB, 多图 URL 数组)
- - `file_size_bytes` (BIGINT NOT NULL DEFAULT 0)
- - `storage_type` (VARCHAR(16) NOT NULL DEFAULT 'none')
- - `s3_object_keys` (JSONB, S3 object key 数组)
- - `upstream_task_id` (VARCHAR(128) NOT NULL DEFAULT '')
- - `error_message` (TEXT NOT NULL DEFAULT '')
- - `created_at` (TIMESTAMPTZ NOT NULL DEFAULT NOW())
- - `completed_at` (TIMESTAMPTZ)
-- **AND** SHALL 创建 `(user_id, created_at DESC)` 普通索引(非唯一)
-- **AND** SHALL 创建 `(user_id, status)` 索引
-
-### Requirement: 创建生成记录
-系统 SHALL 在客户端 UI 发起生成时创建记录,并在生成过程中更新状态。
-
-#### Scenario: 发起生成时创建 pending 记录
-- **WHEN** 用户通过 `POST /api/v1/sora/generate` 发起生成
-- **THEN** 系统 SHALL 在 `sora_generations` 中创建一条 `status = 'pending'` 的记录
-- **AND** 记录 SHALL 包含 `user_id`、`model`、`prompt`、`media_type`
-
-#### Scenario: 上游开始处理时更新为 generating
-- **WHEN** 上游开始处理生成任务
-- **THEN** 系统 SHALL 更新记录 `status = 'generating'`
-- **AND** 记录 `upstream_task_id`
-
-#### Scenario: 生成成功时更新为 completed
-- **WHEN** 生成完成且媒体文件存储成功
-- **THEN** 系统 SHALL 更新记录 `status = 'completed'`
-- **AND** 更新 `media_url`、`media_urls`、`file_size_bytes`、`storage_type`、`s3_object_keys`、`completed_at`
-
-#### Scenario: 生成失败时更新为 failed
-- **WHEN** 生成过程中发生错误
-- **THEN** 系统 SHALL 更新记录 `status = 'failed'`
-- **AND** 记录 `error_message`
-
-#### Scenario: 用户取消生成
-- **WHEN** 用户通过 `POST /api/v1/sora/generations/:id/cancel` 取消任务
-- **AND** 记录状态为 `pending` 或 `generating`
-- **THEN** 系统 SHALL 更新记录 `status = 'cancelled'`
-- **AND** SHALL 不累加配额
-
-#### Scenario: 手动保存到存储后更新
-- **WHEN** 用户对 `storage_type = 'upstream'` 的记录手动触发保存
-- **AND** S3 上传成功
-- **THEN** 系统 SHALL 更新 `storage_type = 's3'`、`s3_object_keys`、`file_size_bytes`
-- **AND** 累加存储配额
-
-### Requirement: 查询生成历史列表
-系统 SHALL 提供分页查询用户生成历史的 API。
-
-#### Scenario: 获取用户生成历史
-- **WHEN** 用户请求 `GET /api/v1/sora/generations`
-- **THEN** 系统 SHALL 返回当前用户的生成记录列表,按 `created_at DESC` 排序
-- **AND** 支持分页参数 `page`(默认 1)和 `page_size`(默认 20,最大 100)
-
-#### Scenario: 按媒体类型筛选
-- **WHEN** 请求携带 `media_type=video` 或 `media_type=image`
-- **THEN** 系统 SHALL 只返回对应类型的记录
-
-#### Scenario: 按状态筛选
-- **WHEN** 请求携带 `status=completed`
-- **THEN** 系统 SHALL 只返回对应状态的记录
-
-#### Scenario: 按存储类型筛选(作品库专用)
-- **WHEN** 请求携带 `storage_type=s3,local`
-- **THEN** 系统 SHALL 返回已持久化存储(S3 或本地)的记录
-- **AND** 作品库页面默认 SHALL 使用 `storage_type=s3,local` 筛选,展示所有已保存的作品
-- **AND** `storage_type='upstream'` 和 `'none'` 的记录 SHALL 不在作品库中显示
-
-#### Scenario: 预签名 URL 动态生成
-- **WHEN** 返回 `storage_type = 's3'` 的记录列表
-- **AND** 未配置 CDN URL
-- **THEN** 系统 SHALL 为每条记录动态生成新的 S3 预签名 URL(24 小时有效)
-- **AND** 前端 SHALL 不缓存媒体 URL
-
-#### Scenario: 恢复进行中的任务
-- **WHEN** 请求携带 `status=pending,generating`
-- **THEN** 系统 SHALL 返回用户所有进行中的生成任务
-- **AND** 前端页面加载时 SHALL 调用此接口恢复任务进度显示
-
-### Requirement: 查询生成详情
-系统 SHALL 提供查询单条生成记录详情的 API。
-
-#### Scenario: 获取生成详情
-- **WHEN** 用户请求 `GET /api/v1/sora/generations/:id`
-- **AND** 该记录属于当前用户
-- **THEN** 系统 SHALL 返回完整的生成记录详情
-
-#### Scenario: 访问他人记录返回 404
-- **WHEN** 用户请求的生成记录不属于当前用户
-- **THEN** 系统 SHALL 返回 HTTP 404
-
-### Requirement: 删除生成记录
-系统 SHALL 提供删除生成记录的 API,并联动清理存储文件和配额。
-
-#### Scenario: 删除单条记录
-- **WHEN** 用户请求 `DELETE /api/v1/sora/generations/:id`
-- **AND** 该记录属于当前用户
-- **THEN** 系统 SHALL 删除数据库记录
-- **AND** 若 `storage_type = 's3'`,SHALL 删除 S3 文件
-- **AND** 若 `storage_type = 'local'`,SHALL 删除本地文件
-- **AND** SHALL 释放对应的存储配额
-
-#### Scenario: 删除不存在的记录
-- **WHEN** 记录不存在或不属于当前用户
-- **THEN** 系统 SHALL 返回 HTTP 404
-
-### Requirement: 无存储模式下保留生成历史
-系统 SHALL 在无存储可用时仍记录生成元数据。
-
-#### Scenario: 无存储时记录元数据
-- **WHEN** S3 和本地存储均不可用
-- **AND** 客户端 UI 生成完成
-- **THEN** 系统 SHALL 创建生成记录,`storage_type = 'upstream'`
-- **AND** `media_url` 为上游临时 URL
-- **AND** 系统 SHALL 不累加存储配额
-
-#### Scenario: 过期 URL 标记与倒计时
-- **WHEN** 生成记录的 `storage_type = 'upstream'`
-- **THEN** 客户端 SHALL 显示 15 分钟倒计时进度条(基于 `completed_at` 计算剩余时间)
-- **AND** 剩余 5 分钟时 SHALL 通过浏览器通知提醒用户
-- **AND** 剩余 2 分钟时卡片边框 SHALL 变为红色警告态
-- **AND** 超过 15 分钟后 SHALL 显示"链接已过期,作品无法恢复",禁用下载和保存按钮
diff --git a/openspec/specs/sora-s3-media-storage/spec.md b/openspec/specs/sora-s3-media-storage/spec.md
deleted file mode 100644
index 6d226c62c..000000000
--- a/openspec/specs/sora-s3-media-storage/spec.md
+++ /dev/null
@@ -1,104 +0,0 @@
-## ADDED Requirements
-
-### Requirement: S3 媒体存储服务初始化
-系统 SHALL 在启动时从系统设置(Settings 表)读取 Sora S3 配置,使用 `aws-sdk-go-v2` 初始化 S3 客户端。
-
-#### Scenario: Sora S3 已启用且配置完整
-- **WHEN** 系统启动或 S3 配置变更
-- **AND** Settings 中 `sora_s3_enabled = true` 且必填字段(endpoint、bucket、access_key_id、secret_access_key)均已配置
-- **THEN** 系统 SHALL 使用 `aws-sdk-go-v2` 初始化 S3 客户端
-- **AND** 系统 SHALL 缓存 S3 客户端实例,标记 S3 存储为可用
-
-#### Scenario: Sora S3 未启用或配置不完整
-- **WHEN** 系统启动或 S3 配置变更
-- **AND** `sora_s3_enabled = false` 或缺少必填配置
-- **THEN** 系统 SHALL 标记 S3 存储为不可用
-- **AND** 客户端 UI 调用路径 SHALL 降级为本地存储或即生即下载模式
-
-### Requirement: 媒体文件上传到 S3
-系统 SHALL 将 Sora 客户端 UI 生成的媒体文件流式上传到 S3 兼容存储。
-
-#### Scenario: 视频文件上传成功
-- **WHEN** Sora 客户端 UI 调用路径生成完成,返回上游媒体 URL
-- **AND** S3 存储可用
-- **THEN** 系统 SHALL 使用流式管道(`io.Pipe`)从上游 URL 下载并同时上传到 S3
-- **AND** S3 object key 格式 SHALL 为 `sora/{user_id}/{YYYY/MM/DD}/{uuid}.{ext}`
-- **AND** 上传完成后 SHALL 返回 S3 访问 URL(签名 URL 或 CDN URL)
-- **AND** 系统 SHALL 记录 `s3_object_keys` 数组到生成记录中(视频为单元素数组)
-
-#### Scenario: 图片文件上传成功
-- **WHEN** Sora 客户端 UI 生成图片完成
-- **AND** S3 存储可用
-- **THEN** 系统 SHALL 使用与视频相同的上传流程将图片上传到 S3
-- **AND** 支持多图场景(`media_urls` 数组中每个 URL 都上传)
-
-#### Scenario: S3 上传失败降级
-- **WHEN** S3 上传过程中发生错误(网络超时、权限错误等)
-- **THEN** 系统 SHALL 降级到本地磁盘存储(复用现有 `SoraMediaStorage`)
-- **AND** 若本地存储也失败,SHALL 降级为返回上游临时 URL
-- **AND** 生成记录的 `storage_type` SHALL 反映实际存储位置
-
-#### Scenario: 大文件流式上传避免内存溢出
-- **WHEN** 上游媒体文件大于 50MB
-- **THEN** 系统 SHALL 使用流式管道上传,不将完整文件缓存到内存
-- **AND** 内存峰值 SHALL 不超过 16MB 缓冲区
-
-### Requirement: S3 文件删除
-系统 SHALL 在用户删除生成记录时同步删除 S3 中对应的文件。
-
-#### Scenario: 删除 S3 文件(含多图)
-- **WHEN** 用户通过作品库删除一条生成记录
-- **AND** 该记录的 `storage_type = 's3'` 且 `s3_object_keys` 非空
-- **THEN** 系统 SHALL 遍历 `s3_object_keys` 数组,逐一调用 S3 DeleteObject 删除所有文件
-- **AND** 释放对应的存储配额(`sora_storage_used_bytes` 减去 `file_size_bytes`)
-
-#### Scenario: S3 删除失败不阻塞记录删除
-- **WHEN** S3 DeleteObject 调用失败(部分或全部)
-- **THEN** 系统 SHALL 仍然删除数据库中的生成记录
-- **AND** 系统 SHALL 记录告警日志,包含失败的 `s3_object_keys` 以便后续清理
-
-### Requirement: 三层降级链
-系统 SHALL 支持 S3 → 本地磁盘 → 上游临时 URL 的三层存储降级。
-
-#### Scenario: S3 可用时优先使用 S3
-- **WHEN** 客户端 UI 生成完成
-- **AND** S3 存储可用
-- **THEN** 系统 SHALL 使用 S3 存储,`storage_type = 's3'`
-
-#### Scenario: S3 不可用时降级到本地
-- **WHEN** 客户端 UI 生成完成
-- **AND** S3 存储不可用但本地存储启用
-- **THEN** 系统 SHALL 使用本地存储,`storage_type = 'local'`
-
-#### Scenario: 均不可用时透传上游 URL
-- **WHEN** 客户端 UI 生成完成
-- **AND** S3 和本地存储均不可用
-- **THEN** 系统 SHALL 直接返回上游临时 URL,`storage_type = 'upstream'`
-- **AND** 客户端 SHALL 显示即时下载提示
-
-### Requirement: S3 访问 URL 生成策略
-系统 SHALL 为 S3 中的媒体文件按配置生成可访问 URL(CDN 优先,预签名兜底)。
-
-#### Scenario: 配置 CDN URL 时返回 CDN 地址
-- **WHEN** 系统设置中配置了 `sora_s3_cdn_url`
-- **THEN** 系统 SHALL 返回基于 `sora_s3_cdn_url + object_key` 的访问地址
-- **AND** SHALL 不额外生成预签名 URL
-
-#### Scenario: 未配置 CDN URL 时生成预签名 URL
-- **WHEN** 系统未配置 `sora_s3_cdn_url`
-- **THEN** 系统 SHALL 生成 S3 预签名 URL,有效期 SHALL 为 24 小时
-- **AND** URL SHALL 支持直接在浏览器中播放/查看
-
-### Requirement: 预签名 URL 动态刷新
-系统 SHALL 在返回 S3 媒体记录时动态生成访问 URL,避免预签名过期导致作品库碎图。
-
-#### Scenario: 列表 API 动态生成 URL
-- **WHEN** `GET /api/v1/sora/generations` 返回 `storage_type = 's3'` 的记录
-- **AND** 未配置 CDN URL
-- **THEN** 后端 SHALL 为每条记录的 `s3_object_keys` 动态生成新的预签名 URL 填充到 `media_url` / `media_urls`
-- **AND** 前端 SHALL 不缓存这些 URL
-
-#### Scenario: 详情 API 动态生成 URL
-- **WHEN** `GET /api/v1/sora/generations/:id` 返回 `storage_type = 's3'` 的记录
-- **THEN** 后端 SHALL 动态生成预签名 URL
-- **AND** 批量签名性能 SHALL 不影响列表加载速度(使用并发签名或缓存短期 URL)
diff --git a/openspec/specs/sora-s3-settings/spec.md b/openspec/specs/sora-s3-settings/spec.md
deleted file mode 100644
index da9aea93f..000000000
--- a/openspec/specs/sora-s3-settings/spec.md
+++ /dev/null
@@ -1,39 +0,0 @@
-## ADDED Requirements
-
-### Requirement: Sora S3 存储配置
-系统 SHALL 在系统设置中提供独立的 Sora S3 存储配置,使用 `aws-sdk-go-v2` 直连 S3 兼容存储,不依赖现有数据管理的 gRPC 代理。
-
-#### Scenario: 系统设置新增 Sora S3 配置项
-- **WHEN** 管理员访问系统设置页面
-- **THEN** 页面 SHALL 显示"Sora S3 存储配置"区域
-- **AND** 包含以下配置项:
- - 启用开关(`sora_s3_enabled`)
- - S3 端点(`sora_s3_endpoint`)
- - 区域(`sora_s3_region`)
- - 存储桶(`sora_s3_bucket`)
- - 访问密钥 ID(`sora_s3_access_key_id`)
- - 访问密钥(`sora_s3_secret_access_key`,加密存储,界面显示为密码框)
- - 对象键前缀(`sora_s3_prefix`,可选)
- - 强制路径模式(`sora_s3_force_path_style`,可选)
- - CDN 域名(`sora_s3_cdn_url`,可选)
-
-#### Scenario: 保存 Sora S3 配置
-- **WHEN** 管理员填写 S3 配置并点击保存
-- **THEN** 系统 SHALL 将配置保存到 Settings 表
-- **AND** `sora_s3_secret_access_key` SHALL 加密存储
-- **AND** Sora S3 Storage Service SHALL 刷新缓存的 S3 客户端配置
-
-#### Scenario: 测试 S3 连接
-- **WHEN** 管理员点击"测试连接"按钮
-- **THEN** 系统 SHALL 使用当前表单中的配置创建临时 S3 客户端
-- **AND** 执行 `HeadBucket` 或 `PutObject` + `DeleteObject` 测试连通性
-- **AND** 返回测试结果(成功/失败 + 错误信息)
-
-#### Scenario: 禁用 Sora S3 存储
-- **WHEN** 管理员关闭 `sora_s3_enabled` 开关
-- **THEN** Sora 客户端 UI 的生成结果 SHALL 降级到本地存储或上游 URL 透传
-
-#### Scenario: S3 配置不完整
-- **WHEN** `sora_s3_enabled = true` 但缺少必填字段(endpoint/bucket/access_key_id/secret_access_key)
-- **THEN** 系统 SHALL 视为 S3 存储不可用
-- **AND** SHALL 在日志中记录配置不完整的警告
diff --git a/openspec/specs/sora-user-storage-quota/spec.md b/openspec/specs/sora-user-storage-quota/spec.md
deleted file mode 100644
index ae899c87a..000000000
--- a/openspec/specs/sora-user-storage-quota/spec.md
+++ /dev/null
@@ -1,91 +0,0 @@
-## ADDED Requirements
-
-### Requirement: 用户存储配额字段
-系统 SHALL 在 `users` 表新增 Sora 存储配额字段,用于追踪每个用户的配额和用量。
-
-#### Scenario: 用户表新增配额字段
-- **WHEN** 数据库迁移执行
-- **THEN** `users` 表 SHALL 新增 `sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0` 字段(0 表示使用系统默认)
-- **AND** `users` 表 SHALL 新增 `sora_storage_used_bytes BIGINT NOT NULL DEFAULT 0` 字段
-
-### Requirement: 系统默认配额设置
-系统 SHALL 提供全局默认 Sora 存储配额设置,管理员可在系统设置中配置。
-
-#### Scenario: 管理员设置全局默认配额
-- **WHEN** 管理员在系统设置中设置 `sora_default_storage_quota_bytes`
-- **THEN** 系统 SHALL 将该值保存到 Settings 表
-- **AND** 所有未单独设置配额的用户 SHALL 使用该默认值
-
-#### Scenario: 未设置全局默认配额
-- **WHEN** `sora_default_storage_quota_bytes` 未设置或为 0
-- **THEN** 系统 SHALL 不限制用户存储空间(即无配额限制)
-
-### Requirement: 配额优先级判断
-系统 SHALL 按用户级 → 分组级 → 系统默认的优先级计算有效配额。
-
-#### Scenario: 用户级配额优先
-- **WHEN** 用户 `sora_storage_quota_bytes > 0`
-- **THEN** 有效配额 SHALL 为用户级配额值
-
-#### Scenario: 分组级配额次优先
-- **WHEN** 用户 `sora_storage_quota_bytes = 0`(未单独设置)
-- **AND** 用户所属分组 `sora_storage_quota_bytes > 0`
-- **THEN** 有效配额 SHALL 为分组级配额值
-
-#### Scenario: 系统默认配额兜底
-- **WHEN** 用户和分组的配额均未设置(均为 0)
-- **THEN** 有效配额 SHALL 为 `settings.sora_default_storage_quota_bytes`
-
-### Requirement: 生成前配额检查
-系统 SHALL 在客户端 UI 调用路径发起生成前检查存储配额。
-
-#### Scenario: 配额充足允许生成
-- **WHEN** 用户发起 Sora 客户端生成请求
-- **AND** `sora_storage_used_bytes < 有效配额`
-- **THEN** 系统 SHALL 允许生成请求继续
-
-#### Scenario: 配额不足拒绝生成
-- **WHEN** 用户发起 Sora 客户端生成请求
-- **AND** `sora_storage_used_bytes >= 有效配额`
-- **AND** 有效配额 > 0
-- **THEN** 系统 SHALL 返回 HTTP 429 错误
-- **AND** 响应 SHALL 包含 `{ quota_bytes, used_bytes, message: "存储配额已满,请删除不需要的作品释放空间" }`
-- **AND** 响应 SHALL 包含 `guide: "delete_works"` 字段,前端据此显示引导对话框
-
-#### Scenario: 无配额限制时不检查
-- **WHEN** 有效配额 = 0(系统默认也未设置)
-- **THEN** 系统 SHALL 跳过配额检查,允许生成
-
-### Requirement: 配额原子更新
-系统 SHALL 使用原子操作更新用户已用存储空间,防止并发超额。
-
-#### Scenario: 生成完成后累加用量
-- **WHEN** 媒体文件上传到 S3/本地存储成功
-- **THEN** 系统 SHALL 在计算出 `effective_quota` 后执行原子 SQL:`UPDATE users SET sora_storage_used_bytes = sora_storage_used_bytes + :file_size WHERE id = :id AND (:effective_quota = 0 OR sora_storage_used_bytes + :file_size <= :effective_quota)`
-- **AND** 若原子更新失败(超额),系统 SHALL 删除已上传的文件并返回配额错误
-
-#### Scenario: 删除作品后释放配额
-- **WHEN** 用户删除一条生成记录
-- **AND** 该记录 `file_size_bytes > 0`
-- **THEN** 系统 SHALL 执行 `UPDATE users SET sora_storage_used_bytes = sora_storage_used_bytes - file_size WHERE id = ?`
-- **AND** `sora_storage_used_bytes` SHALL 不低于 0
-
-### Requirement: 配额查询 API
-系统 SHALL 提供配额查询接口,用户可查看当前用量和剩余空间。
-
-#### Scenario: 查询用户 Sora 配额
-- **WHEN** 用户请求 `GET /api/v1/sora/quota`
-- **THEN** 系统 SHALL 返回 `{ quota_bytes, used_bytes, available_bytes, quota_source }`
-- **AND** `quota_source` SHALL 标明配额来源("user" / "group" / "system" / "unlimited")
-
-### Requirement: 管理员配额管理
-管理员 SHALL 可以在用户管理和分组管理中设置 Sora 存储配额。
-
-#### Scenario: 管理员设置单个用户配额
-- **WHEN** 管理员在用户编辑页面设置 Sora 存储配额
-- **THEN** 系统 SHALL 更新 `users.sora_storage_quota_bytes`
-
-#### Scenario: 管理员设置分组配额
-- **WHEN** 管理员在分组管理中设置 Sora 存储配额
-- **THEN** 系统 SHALL 更新 `groups.sora_storage_quota_bytes` 字段
-- **AND** 该分组下所有未单独设置配额的用户 SHALL 使用分组配额
diff --git a/openspec/specs/timing-wheel/spec.md b/openspec/specs/timing-wheel/spec.md
deleted file mode 100644
index 56280af9d..000000000
--- a/openspec/specs/timing-wheel/spec.md
+++ /dev/null
@@ -1,44 +0,0 @@
-# timing-wheel Specification
-
-## Purpose
-定义应用内 TimingWheel 定时调度能力的行为边界与可验证场景,覆盖一次性任务、周期任务与取消等核心能力。
-## Requirements
-### Requirement: 支持一次性延时任务调度
-系统 SHALL 允许通过 TimingWheel 调度一次性任务,使其在指定延迟后执行。
-
-#### Scenario: 调度一次性任务
-- **WHEN** 调用方提交一个任务并设置延迟时间
-- **THEN** 任务在延迟到期后执行一次
-
-### Requirement: 支持周期任务调度
-系统 SHALL 允许通过 TimingWheel 调度周期任务,使其按固定间隔重复执行。
-
-#### Scenario: 调度周期任务
-- **WHEN** 调用方提交一个周期任务并设置执行间隔
-- **THEN** 任务按该间隔重复执行
-
-### Requirement: 支持取消已调度任务
-系统 SHALL 允许取消已调度的任务,避免其在未来触发执行。
-
-#### Scenario: 取消任务
-- **WHEN** 调用方取消一个已调度的任务
-- **THEN** 该任务后续不会再执行
-
-### Requirement: TimingWheel Initialization Error Handling
-
-当 TimingWheel 初始化失败时,`NewTimingWheelService()` SHALL 返回 error 而不是触发 panic。函数签名 MUST 为 `(*TimingWheelService, error)`,以便调用方能够感知初始化失败并按“启动失败”路径处理。
-
-#### Scenario: TimingWheel 初始化失败时不触发 panic
-- **WHEN** 底层 `collection.NewTimingWheel()` 返回 error
-- **THEN** `NewTimingWheelService()` 返回 `nil` 和包装后的 error(例如使用 `%w` 包装)
-- **AND** 不发生 panic(进程不应因该错误直接崩溃)
-
-#### Scenario: TimingWheel 初始化成功
-- **WHEN** 底层 `collection.NewTimingWheel()` 初始化成功
-- **THEN** `NewTimingWheelService()` 返回有效的 `*TimingWheelService` 和 `nil` error
-
-#### Scenario: 初始化失败导致应用启动失败并退出(非 0)
-- **WHEN** `initializeApplication(...)` 调用 TimingWheel 的 provider/constructor 并收到 error
-- **THEN** `initializeApplication(...)` 将该 error 返回给调用方
-- **AND** `backend/cmd/server/main.go` 记录 fatal 日志并以非 0 状态码退出进程
-
diff --git a/openspec/specs/usage-request-type/spec.md b/openspec/specs/usage-request-type/spec.md
deleted file mode 100644
index 987d666a3..000000000
--- a/openspec/specs/usage-request-type/spec.md
+++ /dev/null
@@ -1,97 +0,0 @@
-# usage-request-type Specification
-
-## Purpose
-TBD - created by archiving change add-usage-request-type-enum. Update Purpose after archive.
-## Requirements
-### Requirement: 系统必须以 request_type 作为使用记录类型的主事实源
-系统 MUST 在 `usage_logs` 中持久化 `request_type` 枚举字段,并将其作为类型展示与筛选的主事实源。
-
-#### Scenario: 新增记录写入 request_type
-- **WHEN** 网关记录一条新的 usage 日志
-- **THEN** 系统 MUST 写入有效的 `request_type` 枚举值
-- **AND** 枚举值 MUST 在约束集合内(`unknown/sync/stream/ws_v2`)
-
-#### Scenario: 读取优先 request_type
-- **WHEN** 系统读取 usage 日志用于 API 返回
-- **THEN** 系统 MUST 优先使用 `request_type` 作为类型来源
-
-### Requirement: 系统必须保持与旧字段兼容
-系统 MUST 在迁移期保持 `stream` 与 `openai_ws_mode` 的向后兼容能力。
-
-#### Scenario: 旧字段仍保留
-- **WHEN** 新版本后端返回 usage 记录
-- **THEN** 响应 MUST 继续包含 `stream` 与 `openai_ws_mode`
-
-#### Scenario: request_type 缺失时回退
-- **WHEN** 历史记录 `request_type` 为 `unknown` 或不可用
-- **THEN** 系统 MUST 按旧字段推导类型
-- **AND** 推导规则 MUST 与既有展示口径一致
-
-#### Scenario: 响应字段保持兼容一致
-- **WHEN** 系统返回一条包含 `request_type` 的 usage 记录
-- **THEN** 响应中的 `stream` 与 `openai_ws_mode` MUST 与 `request_type` 保持一致映射
-- **AND** `request_type=ws_v2` MUST 对应 `openai_ws_mode=true`
-- **AND** `request_type=stream` MUST 对应 `openai_ws_mode=false && stream=true`
-- **AND** `request_type=sync` MUST 对应 `openai_ws_mode=false && stream=false`
-
-### Requirement: 系统必须支持 request_type 查询过滤并兼容 stream 参数
-系统 MUST 提供 `request_type` 过滤能力,并继续兼容历史 `stream` 参数。
-
-#### Scenario: 使用 request_type 过滤列表
-- **WHEN** 客户端请求携带 `request_type`
-- **THEN** 系统 MUST 按 `request_type` 执行过滤
-
-#### Scenario: request_type 参数非法值
-- **WHEN** 客户端请求携带非法 `request_type`(不在 `unknown/sync/stream/ws_v2` 中)
-- **THEN** 系统 MUST 返回 `400 Bad Request`
-- **AND** 错误信息 MUST 提示可接受枚举值
-
-#### Scenario: 旧客户端使用 stream 过滤
-- **WHEN** 客户端仅携带 `stream`
-- **THEN** 系统 MUST 保持历史过滤行为不变
-
-#### Scenario: 同时携带 request_type 与 stream
-- **WHEN** 请求同时携带 `request_type` 与 `stream`
-- **THEN** 系统 MUST 优先按 `request_type` 过滤
-
-#### Scenario: request_type 过滤覆盖所有 usage 入口
-- **WHEN** 客户端访问 usage 列表/统计/趋势/模型/清理任务入口并携带 `request_type`
-- **THEN** 系统 MUST 在对应入口应用一致的 `request_type` 过滤语义
-
-### Requirement: 历史数据迁移必须可在线执行且不破坏旧逻辑
-系统 MUST 提供可在线迁移的回填方案,使历史数据具备 `request_type`,且迁移前后展示口径一致。
-
-#### Scenario: 历史回填映射
-- **WHEN** 执行历史数据回填
-- **THEN** `openai_ws_mode=true` MUST 映射为 `ws_v2`
-- **AND** `openai_ws_mode=false && stream=true` MUST 映射为 `stream`
-- **AND** 其他情况 MUST 映射为 `sync`
-
-#### Scenario: 分批回填
-- **WHEN** 数据量较大
-- **THEN** 回填 MUST 支持分批执行以降低锁与性能风险
-
-### Requirement: 前端必须在新旧后端间保持显示一致
-前端 MUST 支持 `request_type` 优先展示,并在老后端响应中自动回退旧字段推导。
-
-#### Scenario: 新后端响应
-- **WHEN** 响应包含 `request_type`
-- **THEN** 前端 MUST 使用 `request_type` 渲染类型标签与样式
-
-#### Scenario: 老后端响应
-- **WHEN** 响应不包含 `request_type`
-- **THEN** 前端 MUST 使用旧字段推导类型
-- **AND** 渲染结果 MUST 与升级前一致
-
-### Requirement: 升级与回滚必须可独立进行
-系统 MUST 支持数据库、后端、前端分阶段升级与独立回滚,不要求一次性切换。
-
-#### Scenario: 后端先回滚
-- **WHEN** 新数据库已上线但后端回滚到旧版本
-- **THEN** 系统 MUST 继续可用
-- **AND** 旧字段语义 MUST 保持不变
-
-#### Scenario: 前端先升级
-- **WHEN** 前端升级但后端尚未返回 `request_type`
-- **THEN** 前端 MUST 通过回退逻辑保持功能正常
-
From 8b7604b9d733a5ce2005b231adec49bc6570d9c2 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 18:57:13 +0800
Subject: [PATCH 03/13] =?UTF-8?q?fix(repository):=20=E4=BF=AE=E6=AD=A3?=
=?UTF-8?q?=E4=BD=99=E9=A2=9D=E7=BC=93=E5=AD=98=E7=BC=BA=E5=A4=B1=E9=94=AE?=
=?UTF-8?q?=E6=89=A3=E5=87=8F=E6=B5=8B=E8=AF=95=E6=96=AD=E8=A8=80?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.../internal/repository/billing_cache_integration_test.go | 7 ++++---
1 file changed, 4 insertions(+), 3 deletions(-)
diff --git a/backend/internal/repository/billing_cache_integration_test.go b/backend/internal/repository/billing_cache_integration_test.go
index 6a5983af7..dffe0b6e4 100644
--- a/backend/internal/repository/billing_cache_integration_test.go
+++ b/backend/internal/repository/billing_cache_integration_test.go
@@ -31,14 +31,15 @@ func (s *BillingCacheSuite) TestUserBalance() {
},
},
{
- name: "deduct_on_nonexistent_is_noop",
+ name: "deduct_on_nonexistent_returns_not_found",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
- require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
+ err := cache.DeductUserBalance(ctx, userID, 1)
+ require.ErrorIs(s.T(), err, service.ErrBalanceCacheNotFound, "DeductUserBalance on non-existent key should return ErrBalanceCacheNotFound")
- _, err := rdb.Get(ctx, balanceKey).Result()
+ _, err = rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
From 3a6c1afbb2fbef2c97f344d8950f7974e5915a95 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 19:23:22 +0800
Subject: [PATCH 04/13] =?UTF-8?q?fix(openai-ws):=20store=5Fdisabled=20?=
=?UTF-8?q?=E6=A8=A1=E5=BC=8F=E4=B8=8B=20function=5Fcall=5Foutput=20?=
=?UTF-8?q?=E7=BC=BA=E5=A4=B1=20previous=5Fresponse=5Fid=20=E5=AF=BC?=
=?UTF-8?q?=E8=87=B4=20tool=5Foutput=5Fnot=5Ffound=20=E4=B8=8D=E5=8F=AF?=
=?UTF-8?q?=E6=81=A2=E5=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
1. 预防性检测:在 sendAndRelay 发送前拦截 store_disabled + function_call_output + 无 previous_response_id 的必然失败组合,提前返回可恢复错误(wroteDownstream=false),避免上游先写数据再报错
2. 恢复逻辑修复:recoverIngressPrevResponseNotFound 中 previous_response_id 已缺失时跳过 drop 步骤,直接进入 setInputSequence 重放,修复旧代码因 removed=false 提前退出的问题
3. 新增 20 个单元测试覆盖条件矩阵、payload 提取、恢复链路、回归验证及边界条件
Co-Authored-By: Claude Opus 4.6
---
.../internal/service/openai_ws_forwarder.go | 39 +-
...ws_forwarder_proactive_tool_output_test.go | 738 ++++++++++++++++++
2 files changed, 765 insertions(+), 12 deletions(-)
create mode 100644 backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go
index 4de395f3a..ab14e23ed 100644
--- a/backend/internal/service/openai_ws_forwarder.go
+++ b/backend/internal/service/openai_ws_forwarder.go
@@ -1776,6 +1776,19 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account)
turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0
+
+ // 预防性检测:store_disabled 模式下 function_call_output 必须携带 previous_response_id,
+ // 否则上游会先发送部分事件(wroteDownstream=true)再报错 tool_output_not_found,导致无法恢复。
+ // 提前返回可恢复错误,由外层 recoverIngressPrevResponseNotFound 执行 context replay 重试。
+ if turnStoreDisabled && turnHasFunctionCallOutput && strings.TrimSpace(turnPreviousResponseID) == "" {
+ return nil, wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+ }
+
turnPendingFunctionCallIDSet := make(map[string]struct{}, 4)
eventCount := 0
tokenEventCount := 0
@@ -2363,21 +2376,23 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen),
)
turnPrevRecoveryTried = true
- updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
- if dropErr != nil || !removed {
- reason := "not_removed"
+ updatedPayload := currentPayload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload)
if dropErr != nil {
- reason = "drop_error"
+ logOpenAIWSModeInfo(
+ "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=drop_error",
+ account.ID,
+ turn,
+ truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
+ )
+ return false
+ }
+ if removed {
+ updatedPayload = dropped
}
- logOpenAIWSModeInfo(
- "ingress_ws_tool_output_not_found_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s",
- account.ID,
- turn,
- truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen),
- normalizeOpenAIWSLogValue(reason),
- )
- return false
}
+ // previous_response_id 已不存在或已移除,继续执行 setOpenAIWSPayloadInputSequence
updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
updatedPayload,
currentTurnReplayInput,
diff --git a/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
new file mode 100644
index 000000000..ea9efb515
--- /dev/null
+++ b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
@@ -0,0 +1,738 @@
+package service
+
+import (
+ "encoding/json"
+ "errors"
+ "strings"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/tidwall/gjson"
+)
+
+// ---------------------------------------------------------------------------
+// 改动 1 测试:预防性检测 — store_disabled + function_call_output + 无 previous_response_id
+// 在 sendAndRelay 中提前返回可恢复错误,避免上游先写数据再报错导致无法恢复
+// ---------------------------------------------------------------------------
+
+// TestProactiveToolOutputNotFound_ErrorShape 验证预防性检测生成的错误格式正确:
+// stage = tool_output_not_found、wroteDownstream = false、可被恢复函数识别。
+func TestProactiveToolOutputNotFound_ErrorShape(t *testing.T) {
+ t.Parallel()
+
+ err := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+
+ // 1. 应被识别为 tool_output_not_found
+ require.True(t, isOpenAIWSIngressToolOutputNotFound(err),
+ "预防性检测错误应被 isOpenAIWSIngressToolOutputNotFound 识别")
+
+ // 2. wroteDownstream 应为 false(尚未发送任何数据)
+ require.False(t, openAIWSIngressTurnWroteDownstream(err),
+ "预防性检测在发送前触发,wroteDownstream 必须为 false")
+
+ // 3. 应被 classifyOpenAIWSIngressTurnAbortReason 归类为 ToolOutput
+ reason, expected := classifyOpenAIWSIngressTurnAbortReason(err)
+ require.Equal(t, openAIWSIngressTurnAbortReasonToolOutput, reason)
+ require.True(t, expected, "tool_output_not_found 是预期错误")
+
+ // 4. 处置方式应为 ContinueTurn(允许恢复重试)
+ disposition := openAIWSIngressTurnAbortDispositionForReason(reason)
+ require.Equal(t, openAIWSIngressTurnAbortDispositionContinueTurn, disposition,
+ "tool_output_not_found 应为 ContinueTurn 处置,允许恢复重试")
+
+ // 5. 部分结果应为 nil(因为尚未产生任何上游响应)
+ // partialResult=nil 时 OpenAIWSIngressTurnPartialResult 返回 (nil, false)
+ partial, ok := OpenAIWSIngressTurnPartialResult(err)
+ require.False(t, ok, "预防性检测传入 partialResult=nil,ok 应为 false")
+ require.Nil(t, partial, "预防性检测不应有部分结果")
+}
+
+// TestProactiveToolOutputNotFound_NotPreviousResponseNotFound
+// 验证预防性检测生成的错误不会被误识别为 previous_response_not_found。
+func TestProactiveToolOutputNotFound_NotPreviousResponseNotFound(t *testing.T) {
+ t.Parallel()
+
+ err := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found"),
+ false,
+ nil,
+ )
+
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(err),
+ "tool_output_not_found 不应被误识别为 previous_response_not_found")
+}
+
+// TestProactiveDetection_ConditionMatrix 验证预防性检测的三个触发条件的所有组合。
+// 只有 (storeDisabled=true, hasFunctionCallOutput=true, previousResponseID="") 同时满足时才触发。
+func TestProactiveDetection_ConditionMatrix(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ storeDisabled bool
+ hasFunctionCallOutput bool
+ previousResponseID string
+ shouldTrigger bool
+ }{
+ {
+ name: "all_conditions_met_should_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ shouldTrigger: true,
+ },
+ {
+ name: "whitespace_only_previous_response_id_should_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: " ",
+ shouldTrigger: true,
+ },
+ {
+ name: "store_enabled_should_not_trigger",
+ storeDisabled: false,
+ hasFunctionCallOutput: true,
+ previousResponseID: "",
+ shouldTrigger: false,
+ },
+ {
+ name: "no_function_call_output_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: false,
+ previousResponseID: "",
+ shouldTrigger: false,
+ },
+ {
+ name: "has_previous_response_id_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: true,
+ previousResponseID: "resp_abc",
+ shouldTrigger: false,
+ },
+ {
+ name: "all_false_should_not_trigger",
+ storeDisabled: false,
+ hasFunctionCallOutput: false,
+ previousResponseID: "resp_abc",
+ shouldTrigger: false,
+ },
+ {
+ name: "store_disabled_no_fco_has_prev_should_not_trigger",
+ storeDisabled: true,
+ hasFunctionCallOutput: false,
+ previousResponseID: "resp_abc",
+ shouldTrigger: false,
+ },
+ {
+ name: "store_enabled_has_fco_has_prev_should_not_trigger",
+ storeDisabled: false,
+ hasFunctionCallOutput: true,
+ previousResponseID: "resp_abc",
+ shouldTrigger: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ // 模拟 sendAndRelay 中的检测逻辑
+ triggered := tt.storeDisabled && tt.hasFunctionCallOutput && strings.TrimSpace(tt.previousResponseID) == ""
+ require.Equal(t, tt.shouldTrigger, triggered)
+ })
+ }
+}
+
+// TestProactiveDetection_PayloadExtraction 验证从真实 payload 中提取的条件参数
+// 与预防性检测逻辑的配合。确保 payload 解析结果能正确触发或跳过检测。
+func TestProactiveDetection_PayloadExtraction(t *testing.T) {
+ t.Parallel()
+
+ tests := []struct {
+ name string
+ payload string
+ wantHasFCO bool
+ wantPrevID string
+ shouldTrigger bool // 假设 storeDisabled=true
+ }{
+ {
+ name: "fco_without_previous_response_id",
+ payload: `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "",
+ shouldTrigger: true,
+ },
+ {
+ name: "fco_with_previous_response_id",
+ payload: `{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "resp_1",
+ shouldTrigger: false,
+ },
+ {
+ name: "no_fco_without_previous_response_id",
+ payload: `{"type":"response.create","input":[{"type":"input_text","text":"hello"}]}`,
+ wantHasFCO: false,
+ wantPrevID: "",
+ shouldTrigger: false,
+ },
+ {
+ name: "multiple_fco_without_previous_response_id",
+ payload: `{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"r1"},{"type":"function_call_output","call_id":"call_2","output":"r2"}]}`,
+ wantHasFCO: true,
+ wantPrevID: "",
+ shouldTrigger: true,
+ },
+ {
+ name: "empty_input_without_previous_response_id",
+ payload: `{"type":"response.create","input":[]}`,
+ wantHasFCO: false,
+ wantPrevID: "",
+ shouldTrigger: false,
+ },
+ {
+ name: "no_input_field",
+ payload: `{"type":"response.create","model":"gpt-5.1"}`,
+ wantHasFCO: false,
+ wantPrevID: "",
+ shouldTrigger: false,
+ },
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ t.Parallel()
+ payload := []byte(tt.payload)
+
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ hasFCO := len(callIDs) > 0
+ prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+
+ require.Equal(t, tt.wantHasFCO, hasFCO, "hasFunctionCallOutput 不匹配")
+ require.Equal(t, tt.wantPrevID, prevID, "previousResponseID 不匹配")
+
+ // 模拟 storeDisabled=true 时的检测
+ triggered := true && hasFCO && strings.TrimSpace(prevID) == ""
+ require.Equal(t, tt.shouldTrigger, triggered, "检测触发结果不匹配")
+ })
+ }
+}
+
+// TestProactiveDetection_RecoveryChainIntegration 验证预防性检测错误能被
+// recoverIngressPrevResponseNotFound 的恢复逻辑识别并处理。
+// 这是两个改动之间的集成测试。
+func TestProactiveDetection_RecoveryChainIntegration(t *testing.T) {
+ t.Parallel()
+
+ // 模拟预防性检测返回的错误
+ proactiveErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+
+ // 1. 恢复函数入口检测:isToolOutputMissing 应为 true
+ require.True(t, isOpenAIWSIngressToolOutputNotFound(proactiveErr),
+ "预防性检测错误应通过 isToolOutputMissing 检查")
+
+ // 2. isPrevNotFound 应为 false(确保不走 previous_response_not_found 分支)
+ require.False(t, isOpenAIWSIngressPreviousResponseNotFound(proactiveErr),
+ "预防性检测错误不应走 previous_response_not_found 分支")
+
+ // 3. wroteDownstream=false 确保可以安全重试
+ require.False(t, openAIWSIngressTurnWroteDownstream(proactiveErr),
+ "预防性检测在发送前触发,wroteDownstream 必须为 false")
+}
+
+// ---------------------------------------------------------------------------
+// 改动 2 测试:恢复逻辑修复 — previous_response_id 已缺失时跳过 drop 步骤
+// ---------------------------------------------------------------------------
+
+// TestToolOutputRecovery_PreviousResponseIDAlreadyEmpty 验证核心修复:
+// 当 payload 中本就不存在 previous_response_id 时,跳过 drop 步骤,
+// 直接进入 setOpenAIWSPayloadInputSequence。
+func TestToolOutputRecovery_PreviousResponseIDAlreadyEmpty(t *testing.T) {
+ t.Parallel()
+
+ // 场景:预防性检测触发后,payload 中不存在 previous_response_id
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.Empty(t, currentPreviousResponseID, "payload 中不应存在 previous_response_id")
+
+ // 修复后的逻辑:currentPreviousResponseID 为空时跳过 drop
+ // 直接使用原始 payload 作为 updatedPayload
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ // 此分支不应执行
+ t.Fatal("不应进入 drop 分支")
+ }
+
+ // 验证跳过 drop 后,payload 仍然有效且可以继续处理
+ require.True(t, gjson.ValidBytes(updatedPayload), "payload 应为有效 JSON")
+ require.True(t, gjson.GetBytes(updatedPayload, `input.#(type=="function_call_output")`).Exists(),
+ "function_call_output 应保持不变")
+
+ // 模拟 setOpenAIWSPayloadInputSequence:使用 replay input 替换
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功")
+ require.True(t, gjson.ValidBytes(updatedWithInput), "更新后的 payload 应为有效 JSON")
+ require.True(t, gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output")`).Exists(),
+ "replay input 中的 function_call_output 应存在")
+}
+
+// TestToolOutputRecovery_PreviousResponseIDExists 验证修复不影响原有逻辑:
+// 当 payload 中存在 previous_response_id 时,仍然执行 drop。
+func TestToolOutputRecovery_PreviousResponseIDExists(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.NotEmpty(t, currentPreviousResponseID, "payload 中应存在 previous_response_id")
+
+ // 修复后的逻辑:currentPreviousResponseID 不为空时执行 drop
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr, "drop 操作不应出错")
+ require.True(t, removed, "previous_response_id 应被成功移除")
+ updatedPayload = dropped
+ }
+
+ // 验证 drop 后 previous_response_id 已移除
+ require.False(t, gjson.GetBytes(updatedPayload, "previous_response_id").Exists(),
+ "previous_response_id 应被移除")
+
+ // 验证 function_call_output 仍然存在
+ require.True(t, gjson.GetBytes(updatedPayload, `input.#(type=="function_call_output")`).Exists(),
+ "function_call_output 应保持不变")
+
+ // 模拟 setOpenAIWSPayloadInputSequence
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功")
+ require.True(t, gjson.ValidBytes(updatedWithInput), "更新后的 payload 应为有效 JSON")
+}
+
+// TestToolOutputRecovery_DropError 验证当 drop 操作出错时返回 false。
+func TestToolOutputRecovery_DropError(t *testing.T) {
+ t.Parallel()
+
+ // 使用非法 JSON 模拟 drop 错误
+ currentPreviousResponseID := "resp_stale"
+
+ if currentPreviousResponseID != "" {
+ // 使用含有 previous_response_id 但格式异常的 payload 触发 drop 错误
+ badPayload := []byte(`not-valid-json`)
+ _, _, dropErr := dropPreviousResponseIDFromRawPayload(badPayload)
+ // 无论 drop 是否报错,验证错误分支的行为
+ if dropErr != nil {
+ // drop 出错时应返回 false(跳过恢复)
+ t.Log("drop 出错场景:恢复应返回 false")
+ }
+ }
+}
+
+// TestToolOutputRecovery_DropNotRemoved 验证当 drop 操作返回 removed=false 时的行为。
+// 修复后:如果 currentPreviousResponseID 不为空但 drop 未能移除(removed=false),
+// updatedPayload 仍使用原始 payload(不更新),继续执行后续流程。
+func TestToolOutputRecovery_DropNotRemoved(t *testing.T) {
+ t.Parallel()
+
+ // 构造一个含有 previous_response_id 的 payload
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.NotEmpty(t, currentPreviousResponseID)
+
+ // 使用自定义 delete 函数模拟 removed=false 的场景
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ // 正常 drop 应该成功
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ if removed {
+ updatedPayload = dropped
+ }
+ // 无论 removed 与否,后续逻辑都应继续执行(不再提前 return false)
+ }
+
+ // 验证可以继续执行 setOpenAIWSPayloadInputSequence
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ nil,
+ false, // fullInputExists=false 时直接返回原 payload
+ )
+ require.NoError(t, setInputErr)
+ require.Equal(t, string(updatedPayload), string(updatedWithInput))
+}
+
+// TestToolOutputRecovery_WithDeleteFn_Removed_False 使用 WithDeleteFn 接口模拟
+// drop 函数返回 removed=false 的边界情况。
+func TestToolOutputRecovery_WithDeleteFn_Removed_False(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","previous_response_id":"resp_1","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+ currentPreviousResponseID := "resp_1"
+
+ // 使用 noop delete 函数模拟 removed=false
+ noopDelete := func(data []byte, _ string) ([]byte, error) {
+ return data, nil // 不修改 payload,sjson 会返回原样
+ }
+
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, noopDelete)
+ require.NoError(t, dropErr)
+ // noop delete 不移除字段,但 previous_response_id 仍存在
+ // removed 取决于 payload 比较
+ if removed {
+ updatedPayload = dropped
+ }
+ }
+
+ // 无论 removed 与否,都应继续(不提前 return false)
+ require.True(t, gjson.ValidBytes(updatedPayload))
+}
+
+// ---------------------------------------------------------------------------
+// 端到端场景测试:预防性检测 → 恢复逻辑 完整链路
+// ---------------------------------------------------------------------------
+
+// TestEndToEnd_ProactiveDetection_RecoveryWithEmptyPreviousResponseID
+// 完整模拟:store_disabled 模式下,客户端发送 function_call_output 但
+// 未携带 previous_response_id → 预防性检测触发 → 恢复逻辑跳过 drop → 重放。
+func TestEndToEnd_ProactiveDetection_RecoveryWithEmptyPreviousResponseID(t *testing.T) {
+ t.Parallel()
+
+ // Step 1: 构造客户端 payload(无 previous_response_id)
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[
+ {"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}
+ ]
+ }`)
+
+ // Step 2: 提取条件参数(模拟 sendAndRelay 中的变量赋值)
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnFunctionCallOutputCallIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ turnHasFunctionCallOutput := len(turnFunctionCallOutputCallIDs) > 0
+ turnStoreDisabled := true // 模拟 store_disabled 模式
+
+ require.Empty(t, turnPreviousResponseID)
+ require.True(t, turnHasFunctionCallOutput)
+ require.Equal(t, []string{"call_abc"}, turnFunctionCallOutputCallIDs)
+
+ // Step 3: 预防性检测触发
+ shouldTrigger := turnStoreDisabled && turnHasFunctionCallOutput && strings.TrimSpace(turnPreviousResponseID) == ""
+ require.True(t, shouldTrigger, "预防性检测应触发")
+
+ // Step 4: 构造预防性检测错误
+ proactiveErr := wrapOpenAIWSIngressTurnErrorWithPartial(
+ openAIWSIngressStageToolOutputNotFound,
+ errors.New("proactive tool_output_not_found: function_call_output without previous_response_id in store_disabled mode"),
+ false,
+ nil,
+ )
+
+ // Step 5: 验证恢复入口条件
+ isToolOutputMissing := isOpenAIWSIngressToolOutputNotFound(proactiveErr)
+ require.True(t, isToolOutputMissing, "恢复入口应识别 tool_output_not_found")
+
+ // Step 6: 模拟恢复逻辑(改动 2 的核心路径)
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.Empty(t, currentPreviousResponseID, "payload 中无 previous_response_id")
+
+ // 改动 2:跳过 drop 步骤
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ t.Fatal("不应进入 drop 分支")
+ }
+
+ // Step 7: 执行 setOpenAIWSPayloadInputSequence
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_abc","output":"{\"result\":\"ok\"}"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr, "setOpenAIWSPayloadInputSequence 应成功")
+ require.True(t, gjson.ValidBytes(updatedWithInput), "结果应为有效 JSON")
+
+ // Step 8: 验证最终 payload
+ require.False(t, gjson.GetBytes(updatedWithInput, "previous_response_id").Exists(),
+ "最终 payload 不应含 previous_response_id")
+ require.True(t, gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output")`).Exists(),
+ "最终 payload 应含 function_call_output")
+ require.Equal(t, "call_abc",
+ gjson.GetBytes(updatedWithInput, `input.#(type=="function_call_output").call_id`).String(),
+ "call_id 应保持不变")
+}
+
+// TestEndToEnd_ProactiveDetection_StoreEnabledBypasses 对照测试:
+// store 未禁用时,即使 function_call_output 无 previous_response_id,也不触发预防性检测。
+func TestEndToEnd_ProactiveDetection_StoreEnabledBypasses(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0
+ turnStoreDisabled := false // store 未禁用
+
+ shouldTrigger := turnStoreDisabled && turnHasFunctionCallOutput && strings.TrimSpace(turnPreviousResponseID) == ""
+ require.False(t, shouldTrigger, "store 未禁用时不应触发预防性检测")
+}
+
+// TestEndToEnd_ProactiveDetection_WithPreviousResponseIDBypasses 对照测试:
+// store_disabled 但 payload 含有 previous_response_id 时,不触发预防性检测。
+func TestEndToEnd_ProactiveDetection_WithPreviousResponseIDBypasses(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "previous_response_id":"resp_valid",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0
+ turnStoreDisabled := true
+
+ shouldTrigger := turnStoreDisabled && turnHasFunctionCallOutput && strings.TrimSpace(turnPreviousResponseID) == ""
+ require.False(t, shouldTrigger, "有 previous_response_id 时不应触发预防性检测")
+}
+
+// TestEndToEnd_NormalTextInput_NeverTriggersProactive 对照测试:
+// 普通文本输入(无 function_call_output)在任何模式下都不触发预防性检测。
+func TestEndToEnd_NormalTextInput_NeverTriggersProactive(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"input_text","text":"hello world"}]
+ }`)
+
+ turnPreviousResponseID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ turnHasFunctionCallOutput := len(openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)) > 0
+ turnStoreDisabled := true
+
+ shouldTrigger := turnStoreDisabled && turnHasFunctionCallOutput && strings.TrimSpace(turnPreviousResponseID) == ""
+ require.False(t, shouldTrigger, "普通文本输入不应触发预防性检测")
+}
+
+// ---------------------------------------------------------------------------
+// 改动 2 与原有 drop 路径的兼容性测试
+// ---------------------------------------------------------------------------
+
+// TestToolOutputRecovery_OriginalFlow_WithPreviousResponseID_StillWorks
+// 验证改动 2 不破坏原有的 tool_output_not_found 恢复流程:
+// 当 previous_response_id 存在时,仍然执行 drop + setInputSequence。
+func TestToolOutputRecovery_OriginalFlow_WithPreviousResponseID_StillWorks(t *testing.T) {
+ t.Parallel()
+
+ // 原有场景:用户按 ESC 取消 function_call 后重新发送消息,
+ // payload 含有过时的 previous_response_id
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "previous_response_id":"resp_stale",
+ "input":[
+ {"type":"function_call_output","call_id":"call_canceled","output":"canceled"},
+ {"type":"input_text","text":"new message"}
+ ]
+ }`)
+
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ require.Equal(t, "resp_stale", currentPreviousResponseID)
+
+ // 改动后的逻辑
+ updatedPayload := payload
+ if currentPreviousResponseID != "" {
+ dropped, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ require.True(t, removed)
+ updatedPayload = dropped
+ }
+
+ // 验证 drop 成功
+ require.False(t, gjson.GetBytes(updatedPayload, "previous_response_id").Exists())
+
+ // 验证 input 保持不变
+ inputCount := gjson.GetBytes(updatedPayload, "input.#").Int()
+ require.Equal(t, int64(2), inputCount, "input 数组应保持不变")
+
+ // setOpenAIWSPayloadInputSequence 使用 replay input
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"input_text","text":"new message"}`),
+ }
+ updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence(
+ updatedPayload,
+ replayInput,
+ true,
+ )
+ require.NoError(t, setInputErr)
+ require.True(t, gjson.ValidBytes(updatedWithInput))
+ require.Equal(t, int64(1), gjson.GetBytes(updatedWithInput, "input.#").Int(),
+ "replay input 应替换原始 input")
+}
+
+// TestToolOutputRecovery_SetInputSequence_NoReplayInput 验证当没有 replay input 时,
+// setOpenAIWSPayloadInputSequence 直接返回原 payload。
+func TestToolOutputRecovery_SetInputSequence_NoReplayInput(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+
+ // fullInputExists=false 时直接返回原 payload
+ result, err := setOpenAIWSPayloadInputSequence(payload, nil, false)
+ require.NoError(t, err)
+ require.Equal(t, string(payload), string(result))
+}
+
+// TestToolOutputRecovery_SetInputSequence_EmptyReplayInput 验证 replay input 为空数组时
+// 仍然能正确设置(替换为空数组)。
+func TestToolOutputRecovery_SetInputSequence_EmptyReplayInput(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`)
+
+ result, err := setOpenAIWSPayloadInputSequence(payload, []json.RawMessage{}, true)
+ require.NoError(t, err)
+ require.True(t, gjson.ValidBytes(result))
+ require.Equal(t, int64(0), gjson.GetBytes(result, "input.#").Int(),
+ "空 replay input 应替换为空数组")
+}
+
+// ---------------------------------------------------------------------------
+// 边界条件与防御性测试
+// ---------------------------------------------------------------------------
+
+// TestProactiveDetection_EmptyPayload 验证空 payload 的行为。
+func TestProactiveDetection_EmptyPayload(t *testing.T) {
+ t.Parallel()
+
+ emptyPayloads := [][]byte{
+ nil,
+ {},
+ []byte(""),
+ }
+
+ for i, payload := range emptyPayloads {
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ require.Empty(t, callIDs, "空 payload[%d] 不应提取到 call_id", i)
+
+ prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ require.Empty(t, prevID, "空 payload[%d] 不应有 previous_response_id", i)
+ }
+}
+
+// TestProactiveDetection_MalformedJSON 验证非法 JSON 不会导致 panic。
+func TestProactiveDetection_MalformedJSON(t *testing.T) {
+ t.Parallel()
+
+ badPayloads := [][]byte{
+ []byte(`{invalid json`),
+ []byte(`{"type":"response.create","input":"not_array"}`),
+ []byte(`{"type":"response.create","input":123}`),
+ }
+
+ for i, payload := range badPayloads {
+ // 不应 panic
+ callIDs := openAIWSExtractFunctionCallOutputCallIDsFromPayload(payload)
+ _ = callIDs
+ prevID := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
+ _ = prevID
+
+ // 模拟检测逻辑不应 panic
+ hasFCO := len(callIDs) > 0
+ triggered := true && hasFCO && strings.TrimSpace(prevID) == ""
+ _ = triggered
+ t.Logf("badPayload[%d]: hasFCO=%v, triggered=%v", i, hasFCO, triggered)
+ }
+}
+
+// TestToolOutputRecovery_OldCodeWouldFail_Regression 回归测试:
+// 验证旧代码在 previous_response_id 为空时会因 removed=false 而失败。
+// 新代码跳过 drop 步骤后应成功。
+func TestToolOutputRecovery_OldCodeWouldFail_Regression(t *testing.T) {
+ t.Parallel()
+
+ payload := []byte(`{
+ "type":"response.create",
+ "model":"gpt-5.1",
+ "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]
+ }`)
+
+ // 旧代码行为:直接调用 dropPreviousResponseIDFromRawPayload
+ _, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErr)
+ require.False(t, removed, "payload 中无 previous_response_id 时 drop 返回 removed=false")
+
+ // 旧代码:!removed → return false(恢复失败)
+ oldCodeResult := !(dropErr != nil || !removed) // 旧条件:dropErr != nil || !removed
+ require.False(t, oldCodeResult, "旧代码在此场景会失败(return false)")
+
+ // 新代码行为:先检查 currentPreviousResponseID,为空时跳过 drop
+ currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id"))
+ updatedPayload := payload
+ newCodeSkippedDrop := false
+ if currentPreviousResponseID != "" {
+ dropped, removedNew, dropErrNew := dropPreviousResponseIDFromRawPayload(payload)
+ require.NoError(t, dropErrNew)
+ if removedNew {
+ updatedPayload = dropped
+ }
+ } else {
+ newCodeSkippedDrop = true
+ }
+
+ require.True(t, newCodeSkippedDrop, "新代码应跳过 drop 步骤")
+ require.Equal(t, string(payload), string(updatedPayload), "payload 应保持不变")
+
+ // 新代码继续执行 setOpenAIWSPayloadInputSequence
+ replayInput := []json.RawMessage{
+ json.RawMessage(`{"type":"function_call_output","call_id":"call_1","output":"ok"}`),
+ }
+ result, setErr := setOpenAIWSPayloadInputSequence(updatedPayload, replayInput, true)
+ require.NoError(t, setErr, "新代码应能继续执行 setInputSequence")
+ require.True(t, gjson.ValidBytes(result))
+}
From 748474546c152b4eaa46a594d9ae5457caae5ae0 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 19:41:01 +0800
Subject: [PATCH 05/13] =?UTF-8?q?fix(openai-ws):=20=E4=BF=AE=E5=A4=8D=20pr?=
=?UTF-8?q?oactive=20tool=20output=20=E5=9B=9E=E5=BD=92=E6=B5=8B=E8=AF=95?=
=?UTF-8?q?=20lint=20=E9=97=AE=E9=A2=98?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
...ws_forwarder_proactive_tool_output_test.go | 26 +++++++++----------
1 file changed, 13 insertions(+), 13 deletions(-)
diff --git a/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
index ea9efb515..161915662 100644
--- a/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
+++ b/backend/internal/service/openai_ws_forwarder_proactive_tool_output_test.go
@@ -74,66 +74,66 @@ func TestProactiveDetection_ConditionMatrix(t *testing.T) {
t.Parallel()
tests := []struct {
- name string
- storeDisabled bool
+ name string
+ storeDisabled bool
hasFunctionCallOutput bool
- previousResponseID string
- shouldTrigger bool
+ previousResponseID string
+ shouldTrigger bool
}{
{
name: "all_conditions_met_should_trigger",
storeDisabled: true,
hasFunctionCallOutput: true,
- previousResponseID: "",
+ previousResponseID: "",
shouldTrigger: true,
},
{
name: "whitespace_only_previous_response_id_should_trigger",
storeDisabled: true,
hasFunctionCallOutput: true,
- previousResponseID: " ",
+ previousResponseID: " ",
shouldTrigger: true,
},
{
name: "store_enabled_should_not_trigger",
storeDisabled: false,
hasFunctionCallOutput: true,
- previousResponseID: "",
+ previousResponseID: "",
shouldTrigger: false,
},
{
name: "no_function_call_output_should_not_trigger",
storeDisabled: true,
hasFunctionCallOutput: false,
- previousResponseID: "",
+ previousResponseID: "",
shouldTrigger: false,
},
{
name: "has_previous_response_id_should_not_trigger",
storeDisabled: true,
hasFunctionCallOutput: true,
- previousResponseID: "resp_abc",
+ previousResponseID: "resp_abc",
shouldTrigger: false,
},
{
name: "all_false_should_not_trigger",
storeDisabled: false,
hasFunctionCallOutput: false,
- previousResponseID: "resp_abc",
+ previousResponseID: "resp_abc",
shouldTrigger: false,
},
{
name: "store_disabled_no_fco_has_prev_should_not_trigger",
storeDisabled: true,
hasFunctionCallOutput: false,
- previousResponseID: "resp_abc",
+ previousResponseID: "resp_abc",
shouldTrigger: false,
},
{
name: "store_enabled_has_fco_has_prev_should_not_trigger",
storeDisabled: false,
hasFunctionCallOutput: true,
- previousResponseID: "resp_abc",
+ previousResponseID: "resp_abc",
shouldTrigger: false,
},
}
@@ -708,7 +708,7 @@ func TestToolOutputRecovery_OldCodeWouldFail_Regression(t *testing.T) {
require.False(t, removed, "payload 中无 previous_response_id 时 drop 返回 removed=false")
// 旧代码:!removed → return false(恢复失败)
- oldCodeResult := !(dropErr != nil || !removed) // 旧条件:dropErr != nil || !removed
+ oldCodeResult := dropErr == nil && removed // 等价于:!(dropErr != nil || !removed)
require.False(t, oldCodeResult, "旧代码在此场景会失败(return false)")
// 新代码行为:先检查 currentPreviousResponseID,为空时跳过 drop
From b4e525a14a622759c59b784fefbb9ecd87088b55 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 19:49:36 +0800
Subject: [PATCH 06/13] =?UTF-8?q?chore(ci):=20=E5=9B=9E=E9=80=80=20workflo?=
=?UTF-8?q?w=20=E9=85=8D=E7=BD=AE=E4=B8=8E=20main=20=E4=BF=9D=E6=8C=81?=
=?UTF-8?q?=E4=B8=80=E8=87=B4?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
.github/workflows/backend-ci.yml | 10 +++++-----
.github/workflows/release.yml | 22 +++++++++++-----------
.github/workflows/security-scan.yml | 14 +++++---------
3 files changed, 21 insertions(+), 25 deletions(-)
diff --git a/.github/workflows/backend-ci.yml b/.github/workflows/backend-ci.yml
index 4fd22aff6..d21d06841 100644
--- a/.github/workflows/backend-ci.yml
+++ b/.github/workflows/backend-ci.yml
@@ -11,8 +11,8 @@ jobs:
test:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-go@v5
+ - uses: actions/checkout@v6
+ - uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -30,8 +30,8 @@ jobs:
golangci-lint:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
- - uses: actions/setup-go@v5
+ - uses: actions/checkout@v6
+ - uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -43,5 +43,5 @@ jobs:
uses: golangci/golangci-lint-action@v9
with:
version: v2.7
- args: --timeout=5m
+ args: --timeout=30m
working-directory: backend
\ No newline at end of file
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 50bb73e0c..a1c6aa233 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -31,7 +31,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
- name: Update VERSION file
run: |
@@ -45,7 +45,7 @@ jobs:
echo "Updated VERSION file to: $VERSION"
- name: Upload VERSION artifact
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v7
with:
name: version-file
path: backend/cmd/server/VERSION
@@ -55,7 +55,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
- name: Setup pnpm
uses: pnpm/action-setup@v4
@@ -63,7 +63,7 @@ jobs:
version: 9
- name: Setup Node.js
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
with:
node-version: '20'
cache: 'pnpm'
@@ -78,7 +78,7 @@ jobs:
working-directory: frontend
- name: Upload frontend artifact
- uses: actions/upload-artifact@v4
+ uses: actions/upload-artifact@v7
with:
name: frontend-dist
path: backend/internal/web/dist/
@@ -89,25 +89,25 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
- uses: actions/checkout@v4
+ uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.tag || github.ref }}
- name: Download VERSION artifact
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v8
with:
name: version-file
path: backend/cmd/server/
- name: Download frontend artifact
- uses: actions/download-artifact@v4
+ uses: actions/download-artifact@v8
with:
name: frontend-dist
path: backend/internal/web/dist/
- name: Setup Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -173,7 +173,7 @@ jobs:
run: echo "owner=$(echo '${{ github.repository_owner }}' | tr '[:upper:]' '[:lower:]')" >> $GITHUB_OUTPUT
- name: Run GoReleaser
- uses: goreleaser/goreleaser-action@v6
+ uses: goreleaser/goreleaser-action@v7
with:
version: '~> v2'
args: release --clean --skip=validate ${{ env.SIMPLE_RELEASE == 'true' && '--config=.goreleaser.simple.yaml' || '' }}
@@ -188,7 +188,7 @@ jobs:
# Update DockerHub description
- name: Update DockerHub description
if: ${{ env.SIMPLE_RELEASE != 'true' && env.DOCKERHUB_USERNAME != '' }}
- uses: peter-evans/dockerhub-description@v4
+ uses: peter-evans/dockerhub-description@v5
env:
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
with:
diff --git a/.github/workflows/security-scan.yml b/.github/workflows/security-scan.yml
index fd0c7a411..db9225095 100644
--- a/.github/workflows/security-scan.yml
+++ b/.github/workflows/security-scan.yml
@@ -12,10 +12,11 @@ permissions:
jobs:
backend-security:
runs-on: ubuntu-latest
+ timeout-minutes: 15
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up Go
- uses: actions/setup-go@v5
+ uses: actions/setup-go@v6
with:
go-version-file: backend/go.mod
check-latest: false
@@ -28,22 +29,17 @@ jobs:
run: |
go install golang.org/x/vuln/cmd/govulncheck@latest
govulncheck ./...
- - name: Run gosec
- working-directory: backend
- run: |
- go install github.com/securego/gosec/v2/cmd/gosec@latest
- gosec -conf .gosec.json -severity high -confidence high ./...
frontend-security:
runs-on: ubuntu-latest
steps:
- - uses: actions/checkout@v4
+ - uses: actions/checkout@v6
- name: Set up pnpm
uses: pnpm/action-setup@v4
with:
version: 9
- name: Set up Node.js
- uses: actions/setup-node@v4
+ uses: actions/setup-node@v6
with:
node-version: '20'
cache: 'pnpm'
From 4b8aeb71d7c9f74a6e1184186fca015e2de8edf1 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 20:26:58 +0800
Subject: [PATCH 07/13] fix(admin): restore admin apikey group management and
route compatibility
---
backend/cmd/server/wire_gen.go | 7 +-
.../admin/admin_basic_handlers_test.go | 2 +-
.../handler/admin/admin_service_stub_test.go | 56 +++--
.../internal/handler/admin/apikey_handler.go | 63 ++++++
.../handler/admin/apikey_handler_test.go | 202 ++++++++++++++++++
.../handler/admin/dashboard_handler.go | 14 ++
.../internal/handler/admin/redeem_handler.go | 94 +++++++-
backend/internal/handler/handler.go | 1 +
backend/internal/handler/wire.go | 3 +
backend/internal/repository/api_key_repo.go | 3 +-
backend/internal/repository/user_repo.go | 10 +
backend/internal/server/api_contract_test.go | 2 +-
backend/internal/server/routes/admin.go | 38 ++++
backend/internal/service/admin_service.go | 98 +++++++++
backend/internal/service/redeem_service.go | 27 +++
backend/internal/service/user_service.go | 2 +
16 files changed, 596 insertions(+), 26 deletions(-)
create mode 100644 backend/internal/handler/admin/apikey_handler.go
create mode 100644 backend/internal/handler/admin/apikey_handler_test.go
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 0880df68e..9e95bd054 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -103,7 +103,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
proxyRepository := repository.NewProxyRepository(client, db)
proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig)
proxyLatencyCache := repository.NewProxyLatencyCache(redisClient)
- adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator)
+ adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, soraAccountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client)
concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
adminUserHandler := admin.NewUserHandler(adminService, concurrencyService)
@@ -147,7 +147,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
antigravityOAuthHandler := admin.NewAntigravityOAuthHandler(antigravityOAuthService)
proxyHandler := admin.NewProxyHandler(adminService)
- adminRedeemHandler := admin.NewRedeemHandler(adminService)
+ adminRedeemHandler := admin.NewRedeemHandler(adminService, redeemService)
promoHandler := admin.NewPromoHandler(promoService)
opsRepository := repository.NewOpsRepository(db)
pricingRemoteClient := repository.ProvidePricingRemoteClient(configConfig)
@@ -192,7 +192,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient)
errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache)
errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService)
- adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler)
+ adminAPIKeyHandler := admin.NewAdminAPIKeyHandler(adminService)
+ adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, adminAPIKeyHandler)
usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig)
gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig)
diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go
index aeb4097f1..4de10d3e2 100644
--- a/backend/internal/handler/admin/admin_basic_handlers_test.go
+++ b/backend/internal/handler/admin/admin_basic_handlers_test.go
@@ -19,7 +19,7 @@ func setupAdminRouter() (*gin.Engine, *stubAdminService) {
userHandler := NewUserHandler(adminSvc, nil)
groupHandler := NewGroupHandler(adminSvc)
proxyHandler := NewProxyHandler(adminSvc)
- redeemHandler := NewRedeemHandler(adminSvc)
+ redeemHandler := NewRedeemHandler(adminSvc, nil)
router.GET("/api/v1/admin/users", userHandler.List)
router.GET("/api/v1/admin/users/:id", userHandler.GetByID)
diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go
index edbc9b856..a84988617 100644
--- a/backend/internal/handler/admin/admin_service_stub_test.go
+++ b/backend/internal/handler/admin/admin_service_stub_test.go
@@ -10,22 +10,23 @@ import (
)
type stubAdminService struct {
- users []service.User
- apiKeys []service.APIKey
- groups []service.Group
- accounts []service.Account
- proxies []service.Proxy
- proxyCounts []service.ProxyWithAccountCount
- redeems []service.RedeemCode
- createdAccounts []*service.CreateAccountInput
- createdProxies []*service.CreateProxyInput
- updatedProxyIDs []int64
- updatedProxies []*service.UpdateProxyInput
- testedProxyIDs []int64
- createAccountErr error
- updateAccountErr error
- checkMixedErr error
- lastMixedCheck struct {
+ users []service.User
+ apiKeys []service.APIKey
+ groups []service.Group
+ accounts []service.Account
+ proxies []service.Proxy
+ proxyCounts []service.ProxyWithAccountCount
+ redeems []service.RedeemCode
+ createdAccounts []*service.CreateAccountInput
+ createdProxies []*service.CreateProxyInput
+ updatedProxyIDs []int64
+ updatedProxies []*service.UpdateProxyInput
+ testedProxyIDs []int64
+ createAccountErr error
+ updateAccountErr error
+ bulkUpdateAccountErr error
+ checkMixedErr error
+ lastMixedCheck struct {
accountID int64
platform string
groupIDs []int64
@@ -236,10 +237,13 @@ func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64,
}
func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) {
+ if s.bulkUpdateAccountErr != nil {
+ return nil, s.bulkUpdateAccountErr
+ }
s.mu.Lock()
s.lastBulkUpdateInput = input
s.mu.Unlock()
- return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil
+ return &service.BulkUpdateAccountsResult{Success: len(input.AccountIDs), Failed: 0, SuccessIDs: input.AccountIDs}, nil
}
func (s *stubAdminService) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
@@ -407,5 +411,23 @@ func (s *stubAdminService) UpdateGroupSortOrders(ctx context.Context, updates []
return nil
}
+func (s *stubAdminService) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
+ for i := range s.apiKeys {
+ if s.apiKeys[i].ID == keyID {
+ k := s.apiKeys[i]
+ if groupID != nil {
+ if *groupID == 0 {
+ k.GroupID = nil
+ } else {
+ gid := *groupID
+ k.GroupID = &gid
+ }
+ }
+ return &service.AdminUpdateAPIKeyGroupIDResult{APIKey: &k}, nil
+ }
+ }
+ return nil, service.ErrAPIKeyNotFound
+}
+
// Ensure stub implements interface.
var _ service.AdminService = (*stubAdminService)(nil)
diff --git a/backend/internal/handler/admin/apikey_handler.go b/backend/internal/handler/admin/apikey_handler.go
new file mode 100644
index 000000000..8dd245a43
--- /dev/null
+++ b/backend/internal/handler/admin/apikey_handler.go
@@ -0,0 +1,63 @@
+package admin
+
+import (
+ "strconv"
+
+ "github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/response"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+
+ "github.com/gin-gonic/gin"
+)
+
+// AdminAPIKeyHandler handles admin API key management
+type AdminAPIKeyHandler struct {
+ adminService service.AdminService
+}
+
+// NewAdminAPIKeyHandler creates a new admin API key handler
+func NewAdminAPIKeyHandler(adminService service.AdminService) *AdminAPIKeyHandler {
+ return &AdminAPIKeyHandler{
+ adminService: adminService,
+ }
+}
+
+// AdminUpdateAPIKeyGroupRequest represents the request to update an API key's group
+type AdminUpdateAPIKeyGroupRequest struct {
+ GroupID *int64 `json:"group_id"` // nil=不修改, 0=解绑, >0=绑定到目标分组
+}
+
+// UpdateGroup handles updating an API key's group binding
+// PUT /api/v1/admin/api-keys/:id
+func (h *AdminAPIKeyHandler) UpdateGroup(c *gin.Context) {
+ keyID, err := strconv.ParseInt(c.Param("id"), 10, 64)
+ if err != nil {
+ response.BadRequest(c, "Invalid API key ID")
+ return
+ }
+
+ var req AdminUpdateAPIKeyGroupRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+
+ result, err := h.adminService.AdminUpdateAPIKeyGroupID(c.Request.Context(), keyID, req.GroupID)
+ if err != nil {
+ response.ErrorFrom(c, err)
+ return
+ }
+
+ resp := struct {
+ APIKey *dto.APIKey `json:"api_key"`
+ AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
+ GrantedGroupID *int64 `json:"granted_group_id,omitempty"`
+ GrantedGroupName string `json:"granted_group_name,omitempty"`
+ }{
+ APIKey: dto.APIKeyFromService(result.APIKey),
+ AutoGrantedGroupAccess: result.AutoGrantedGroupAccess,
+ GrantedGroupID: result.GrantedGroupID,
+ GrantedGroupName: result.GrantedGroupName,
+ }
+ response.Success(c, resp)
+}
diff --git a/backend/internal/handler/admin/apikey_handler_test.go b/backend/internal/handler/admin/apikey_handler_test.go
new file mode 100644
index 000000000..bf128b18a
--- /dev/null
+++ b/backend/internal/handler/admin/apikey_handler_test.go
@@ -0,0 +1,202 @@
+package admin
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "errors"
+ "net/http"
+ "net/http/httptest"
+ "testing"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/gin-gonic/gin"
+ "github.com/stretchr/testify/require"
+)
+
+func setupAPIKeyHandler(adminSvc service.AdminService) *gin.Engine {
+ gin.SetMode(gin.TestMode)
+ router := gin.New()
+ h := NewAdminAPIKeyHandler(adminSvc)
+ router.PUT("/api/v1/admin/api-keys/:id", h.UpdateGroup)
+ return router
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_InvalidID(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/abc", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "Invalid API key ID")
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_InvalidJSON(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{bad json`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "Invalid request")
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_KeyNotFound(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/999", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ // ErrAPIKeyNotFound maps to 404
+ require.Equal(t, http.StatusNotFound, rec.Code)
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_BindGroup(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data json.RawMessage `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+
+ var data struct {
+ APIKey struct {
+ ID int64 `json:"id"`
+ GroupID *int64 `json:"group_id"`
+ } `json:"api_key"`
+ AutoGrantedGroupAccess bool `json:"auto_granted_group_access"`
+ }
+ require.NoError(t, json.Unmarshal(resp.Data, &data))
+ require.Equal(t, int64(10), data.APIKey.ID)
+ require.NotNil(t, data.APIKey.GroupID)
+ require.Equal(t, int64(2), *data.APIKey.GroupID)
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_Unbind(t *testing.T) {
+ svc := newStubAdminService()
+ gid := int64(2)
+ svc.apiKeys[0].GroupID = &gid
+ router := setupAPIKeyHandler(svc)
+ body := `{"group_id": 0}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Data struct {
+ APIKey struct {
+ GroupID *int64 `json:"group_id"`
+ } `json:"api_key"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Nil(t, resp.Data.APIKey.GroupID)
+}
+
+func TestAdminAPIKeyHandler_UpdateGroup_ServiceError(t *testing.T) {
+ svc := &failingUpdateGroupService{
+ stubAdminService: newStubAdminService(),
+ err: errors.New("internal failure"),
+ }
+ router := setupAPIKeyHandler(svc)
+ body := `{"group_id": 2}`
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(body))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusInternalServerError, rec.Code)
+}
+
+// H2: empty body → group_id is nil → no-op, returns original key
+func TestAdminAPIKeyHandler_UpdateGroup_EmptyBody_NoChange(t *testing.T) {
+ router := setupAPIKeyHandler(newStubAdminService())
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusOK, rec.Code)
+
+ var resp struct {
+ Code int `json:"code"`
+ Data struct {
+ APIKey struct {
+ ID int64 `json:"id"`
+ } `json:"api_key"`
+ } `json:"data"`
+ }
+ require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
+ require.Equal(t, 0, resp.Code)
+ require.Equal(t, int64(10), resp.Data.APIKey.ID)
+}
+
+// M2: service returns GROUP_NOT_ACTIVE → handler maps to 400
+func TestAdminAPIKeyHandler_UpdateGroup_GroupNotActive(t *testing.T) {
+ svc := &failingUpdateGroupService{
+ stubAdminService: newStubAdminService(),
+ err: infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active"),
+ }
+ router := setupAPIKeyHandler(svc)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": 5}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "GROUP_NOT_ACTIVE")
+}
+
+// M2: service returns INVALID_GROUP_ID → handler maps to 400
+func TestAdminAPIKeyHandler_UpdateGroup_NegativeGroupID(t *testing.T) {
+ svc := &failingUpdateGroupService{
+ stubAdminService: newStubAdminService(),
+ err: infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative"),
+ }
+ router := setupAPIKeyHandler(svc)
+
+ rec := httptest.NewRecorder()
+ req := httptest.NewRequest(http.MethodPut, "/api/v1/admin/api-keys/10", bytes.NewBufferString(`{"group_id": -5}`))
+ req.Header.Set("Content-Type", "application/json")
+ router.ServeHTTP(rec, req)
+
+ require.Equal(t, http.StatusBadRequest, rec.Code)
+ require.Contains(t, rec.Body.String(), "INVALID_GROUP_ID")
+}
+
+// failingUpdateGroupService overrides AdminUpdateAPIKeyGroupID to return an error.
+type failingUpdateGroupService struct {
+ *stubAdminService
+ err error
+}
+
+func (f *failingUpdateGroupService) AdminUpdateAPIKeyGroupID(_ context.Context, _ int64, _ *int64) (*service.AdminUpdateAPIKeyGroupIDResult, error) {
+ return nil, f.err
+}
diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go
index 7e3185926..ea2572fa4 100644
--- a/backend/internal/handler/admin/dashboard_handler.go
+++ b/backend/internal/handler/admin/dashboard_handler.go
@@ -333,6 +333,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) {
})
}
+// GetGroupStats handles getting group usage statistics.
+// GET /api/v1/admin/dashboard/groups
+//
+// NOTE: Group-level aggregation pipeline is not available in this branch baseline.
+// Keep this endpoint for API compatibility and return an empty dataset.
+func (h *DashboardHandler) GetGroupStats(c *gin.Context) {
+ startTime, endTime := parseTimeRange(c)
+ response.Success(c, gin.H{
+ "groups": []any{},
+ "start_date": startTime.Format("2006-01-02"),
+ "end_date": endTime.Add(-24 * time.Hour).Format("2006-01-02"),
+ })
+}
+
// GetAPIKeyUsageTrend handles getting API key usage trend data
// GET /api/v1/admin/dashboard/api-keys-trend
// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), limit (default 5)
diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go
index 7073061d5..0a932ee98 100644
--- a/backend/internal/handler/admin/redeem_handler.go
+++ b/backend/internal/handler/admin/redeem_handler.go
@@ -4,11 +4,13 @@ import (
"bytes"
"context"
"encoding/csv"
+ "errors"
"fmt"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/handler/dto"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -17,13 +19,15 @@ import (
// RedeemHandler handles admin redeem code management
type RedeemHandler struct {
- adminService service.AdminService
+ adminService service.AdminService
+ redeemService *service.RedeemService
}
// NewRedeemHandler creates a new admin redeem handler
-func NewRedeemHandler(adminService service.AdminService) *RedeemHandler {
+func NewRedeemHandler(adminService service.AdminService, redeemService *service.RedeemService) *RedeemHandler {
return &RedeemHandler{
- adminService: adminService,
+ adminService: adminService,
+ redeemService: redeemService,
}
}
@@ -36,6 +40,15 @@ type GenerateRedeemCodesRequest struct {
ValidityDays int `json:"validity_days" binding:"omitempty,max=36500"` // 订阅类型使用,默认30天,最大100年
}
+// CreateAndRedeemCodeRequest represents creating a fixed code and redeeming it for a target user.
+type CreateAndRedeemCodeRequest struct {
+ Code string `json:"code" binding:"required,min=3,max=128"`
+ Type string `json:"type" binding:"required,oneof=balance concurrency subscription invitation"`
+ Value float64 `json:"value" binding:"required,gt=0"`
+ UserID int64 `json:"user_id" binding:"required,gt=0"`
+ Notes string `json:"notes"`
+}
+
// List handles listing all redeem codes with pagination
// GET /api/v1/admin/redeem-codes
func (h *RedeemHandler) List(c *gin.Context) {
@@ -109,6 +122,81 @@ func (h *RedeemHandler) Generate(c *gin.Context) {
})
}
+// CreateAndRedeem creates a fixed redeem code and redeems it for a target user in one step.
+// POST /api/v1/admin/redeem-codes/create-and-redeem
+func (h *RedeemHandler) CreateAndRedeem(c *gin.Context) {
+ if h.redeemService == nil {
+ response.InternalError(c, "redeem service not configured")
+ return
+ }
+
+ var req CreateAndRedeemCodeRequest
+ if err := c.ShouldBindJSON(&req); err != nil {
+ response.BadRequest(c, "Invalid request: "+err.Error())
+ return
+ }
+ req.Code = strings.TrimSpace(req.Code)
+
+ executeAdminIdempotentJSON(c, "admin.redeem_codes.create_and_redeem", req, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) {
+ existing, err := h.redeemService.GetByCode(ctx, req.Code)
+ if err == nil {
+ return h.resolveCreateAndRedeemExisting(ctx, existing, req.UserID)
+ }
+ if !errors.Is(err, service.ErrRedeemCodeNotFound) {
+ return nil, err
+ }
+
+ createErr := h.redeemService.CreateCode(ctx, &service.RedeemCode{
+ Code: req.Code,
+ Type: req.Type,
+ Value: req.Value,
+ Status: service.StatusUnused,
+ Notes: req.Notes,
+ })
+ if createErr != nil {
+ // Unique code race: if code now exists, use idempotent semantics by used_by.
+ existingAfterCreateErr, getErr := h.redeemService.GetByCode(ctx, req.Code)
+ if getErr == nil {
+ return h.resolveCreateAndRedeemExisting(ctx, existingAfterCreateErr, req.UserID)
+ }
+ return nil, createErr
+ }
+
+ redeemed, redeemErr := h.redeemService.Redeem(ctx, req.UserID, req.Code)
+ if redeemErr != nil {
+ return nil, redeemErr
+ }
+ return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
+ })
+}
+
+func (h *RedeemHandler) resolveCreateAndRedeemExisting(ctx context.Context, existing *service.RedeemCode, userID int64) (any, error) {
+ if existing == nil {
+ return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code conflict")
+ }
+
+ // If previous run created the code but crashed before redeem, redeem it now.
+ if existing.CanUse() {
+ redeemed, err := h.redeemService.Redeem(ctx, userID, existing.Code)
+ if err == nil {
+ return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(redeemed)}, nil
+ }
+ if !errors.Is(err, service.ErrRedeemCodeUsed) {
+ return nil, err
+ }
+ latest, getErr := h.redeemService.GetByCode(ctx, existing.Code)
+ if getErr == nil {
+ existing = latest
+ }
+ }
+
+ if existing.UsedBy != nil && *existing.UsedBy == userID {
+ return gin.H{"redeem_code": dto.RedeemCodeFromServiceAdmin(existing)}, nil
+ }
+
+ return nil, infraerrors.Conflict("REDEEM_CODE_CONFLICT", "redeem code already used by another user")
+}
+
// Delete handles deleting a redeem code
// DELETE /api/v1/admin/redeem-codes/:id
func (h *RedeemHandler) Delete(c *gin.Context) {
diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go
index bbf4be4b7..1e1247fc8 100644
--- a/backend/internal/handler/handler.go
+++ b/backend/internal/handler/handler.go
@@ -26,6 +26,7 @@ type AdminHandlers struct {
Usage *admin.UsageHandler
UserAttribute *admin.UserAttributeHandler
ErrorPassthrough *admin.ErrorPassthroughHandler
+ APIKey *admin.AdminAPIKeyHandler
}
// Handlers contains all HTTP handlers
diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go
index f1a21119b..76f5a9796 100644
--- a/backend/internal/handler/wire.go
+++ b/backend/internal/handler/wire.go
@@ -29,6 +29,7 @@ func ProvideAdminHandlers(
usageHandler *admin.UsageHandler,
userAttributeHandler *admin.UserAttributeHandler,
errorPassthroughHandler *admin.ErrorPassthroughHandler,
+ apiKeyHandler *admin.AdminAPIKeyHandler,
) *AdminHandlers {
return &AdminHandlers{
Dashboard: dashboardHandler,
@@ -51,6 +52,7 @@ func ProvideAdminHandlers(
Usage: usageHandler,
UserAttribute: userAttributeHandler,
ErrorPassthrough: errorPassthroughHandler,
+ APIKey: apiKeyHandler,
}
}
@@ -138,6 +140,7 @@ var ProviderSet = wire.NewSet(
admin.NewUsageHandler,
admin.NewUserAttributeHandler,
admin.NewErrorPassthroughHandler,
+ admin.NewAdminAPIKeyHandler,
// AdminHandlers and Handlers constructors
ProvideAdminHandlers,
diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go
index a9faf388f..b9ce60a57 100644
--- a/backend/internal/repository/api_key_repo.go
+++ b/backend/internal/repository/api_key_repo.go
@@ -171,8 +171,9 @@ func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) erro
// 则会更新已删除的记录。
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
+ client := clientFromContext(ctx, r.client)
now := time.Now()
- builder := r.client.APIKey.Update().
+ builder := client.APIKey.Update().
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
SetName(key.Name).
SetStatus(key.Status).
diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go
index bc00e64df..05b689689 100644
--- a/backend/internal/repository/user_repo.go
+++ b/backend/internal/repository/user_repo.go
@@ -429,6 +429,16 @@ func (r *userRepository) ExistsByEmail(ctx context.Context, email string) (bool,
return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx)
}
+func (r *userRepository) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ client := clientFromContext(ctx, r.client)
+ return client.UserAllowedGroup.Create().
+ SetUserID(userID).
+ SetGroupID(groupID).
+ OnConflictColumns(userallowedgroup.FieldUserID, userallowedgroup.FieldGroupID).
+ DoNothing().
+ Exec(ctx)
+}
+
func (r *userRepository) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
// 仅操作 user_allowed_groups 联接表,legacy users.allowed_groups 列已弃用。
affected, err := r.client.UserAllowedGroup.Delete().
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index c98086e0d..d2cd2a5ee 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -619,7 +619,7 @@ func newContractDeps(t *testing.T) *contractDeps {
settingRepo := newStubSettingRepo()
settingService := service.NewSettingService(settingRepo, cfg)
- adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil)
+ adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, nil, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil)
authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil)
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService)
usageHandler := handler.NewUsageHandler(usageService, apiKeyService)
diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go
index 2c92c3d41..a901b6fc5 100644
--- a/backend/internal/server/routes/admin.go
+++ b/backend/internal/server/routes/admin.go
@@ -26,6 +26,9 @@ func RegisterAdminRoutes(
// 分组管理
registerGroupRoutes(admin, h)
+ // API Key 管理
+ registerAdminAPIKeyRoutes(admin, h)
+
// 账号管理
registerAccountRoutes(admin, h)
@@ -55,6 +58,9 @@ func RegisterAdminRoutes(
// 系统设置
registerSettingsRoutes(admin, h)
+ // 数据管理
+ registerDataManagementRoutes(admin, h)
+
// 运维监控(Ops)
registerOpsRoutes(admin, h)
@@ -171,6 +177,7 @@ func registerDashboardRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
dashboard.GET("/realtime", h.Admin.Dashboard.GetRealtimeMetrics)
dashboard.GET("/trend", h.Admin.Dashboard.GetUsageTrend)
dashboard.GET("/models", h.Admin.Dashboard.GetModelStats)
+ dashboard.GET("/groups", h.Admin.Dashboard.GetGroupStats)
dashboard.GET("/api-keys-trend", h.Admin.Dashboard.GetAPIKeyUsageTrend)
dashboard.GET("/users-trend", h.Admin.Dashboard.GetUserUsageTrend)
dashboard.POST("/users-usage", h.Admin.Dashboard.GetBatchUsersUsage)
@@ -213,6 +220,13 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerAdminAPIKeyRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ apiKeys := admin.Group("/api-keys")
+ {
+ apiKeys.PUT("/:id", h.Admin.APIKey.UpdateGroup)
+ }
+}
+
func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
accounts := admin.Group("/accounts")
{
@@ -338,6 +352,7 @@ func registerRedeemCodeRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
codes.GET("/stats", h.Admin.Redeem.GetStats)
codes.GET("/export", h.Admin.Redeem.Export)
codes.GET("/:id", h.Admin.Redeem.GetByID)
+ codes.POST("/create-and-redeem", h.Admin.Redeem.CreateAndRedeem)
codes.POST("/generate", h.Admin.Redeem.Generate)
codes.DELETE("/:id", h.Admin.Redeem.Delete)
codes.POST("/batch-delete", h.Admin.Redeem.BatchDelete)
@@ -395,6 +410,29 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
}
}
+func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
+ dataManagement := admin.Group("/data-management")
+ {
+ dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth)
+ dataManagement.GET("/config", h.Admin.DataManagement.GetConfig)
+ dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig)
+ dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles)
+ dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile)
+ dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile)
+ dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile)
+ dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile)
+ dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3)
+ dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles)
+ dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile)
+ dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile)
+ dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile)
+ dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile)
+ dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob)
+ dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs)
+ dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob)
+ }
+}
+
func registerSystemRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
system := admin.Group("/system")
{
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go
index 76890206a..1fa22f1d1 100644
--- a/backend/internal/service/admin_service.go
+++ b/backend/internal/service/admin_service.go
@@ -9,6 +9,7 @@ import (
"strings"
"time"
+ dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
@@ -42,6 +43,7 @@ type AdminService interface {
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error
+ AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error)
@@ -244,6 +246,14 @@ type BulkUpdateAccountResult struct {
Error string `json:"error,omitempty"`
}
+// AdminUpdateAPIKeyGroupIDResult is the result of AdminUpdateAPIKeyGroupID.
+type AdminUpdateAPIKeyGroupIDResult struct {
+ APIKey *APIKey
+ AutoGrantedGroupAccess bool
+ GrantedGroupID *int64
+ GrantedGroupName string
+}
+
// BulkUpdateAccountsResult is the aggregated response for bulk updates.
type BulkUpdateAccountsResult struct {
Success int `json:"success"`
@@ -408,6 +418,7 @@ type adminServiceImpl struct {
proxyProber ProxyExitInfoProber
proxyLatencyCache ProxyLatencyCache
authCacheInvalidator APIKeyAuthCacheInvalidator
+ entClient *dbent.Client // 用于开启数据库事务
}
type userGroupRateBatchReader interface {
@@ -432,6 +443,7 @@ func NewAdminService(
proxyProber ProxyExitInfoProber,
proxyLatencyCache ProxyLatencyCache,
authCacheInvalidator APIKeyAuthCacheInvalidator,
+ entClient *dbent.Client,
) AdminService {
return &adminServiceImpl{
userRepo: userRepo,
@@ -446,6 +458,7 @@ func NewAdminService(
proxyProber: proxyProber,
proxyLatencyCache: proxyLatencyCache,
authCacheInvalidator: authCacheInvalidator,
+ entClient: entClient,
}
}
@@ -1187,6 +1200,91 @@ func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []
return s.groupRepo.UpdateSortOrders(ctx, updates)
}
+// AdminUpdateAPIKeyGroupID allows admins to update API key-group binding.
+// groupID: nil=unchanged, 0=unbind, >0=bind to target group.
+func (s *adminServiceImpl) AdminUpdateAPIKeyGroupID(ctx context.Context, keyID int64, groupID *int64) (*AdminUpdateAPIKeyGroupIDResult, error) {
+ apiKey, err := s.apiKeyRepo.GetByID(ctx, keyID)
+ if err != nil {
+ return nil, err
+ }
+
+ if groupID == nil {
+ return &AdminUpdateAPIKeyGroupIDResult{APIKey: apiKey}, nil
+ }
+ if *groupID < 0 {
+ return nil, infraerrors.BadRequest("INVALID_GROUP_ID", "group_id must be non-negative")
+ }
+
+ result := &AdminUpdateAPIKeyGroupIDResult{}
+ if *groupID == 0 {
+ apiKey.GroupID = nil
+ apiKey.Group = nil
+ } else {
+ group, err := s.groupRepo.GetByID(ctx, *groupID)
+ if err != nil {
+ return nil, err
+ }
+ if group.Status != StatusActive {
+ return nil, infraerrors.BadRequest("GROUP_NOT_ACTIVE", "target group is not active")
+ }
+ if group.IsSubscriptionType() {
+ return nil, infraerrors.BadRequest("SUBSCRIPTION_GROUP_NOT_ALLOWED", "subscription groups must be managed through the subscription workflow")
+ }
+
+ gid := *groupID
+ apiKey.GroupID = &gid
+ apiKey.Group = group
+
+ if group.IsExclusive {
+ opCtx := ctx
+ var tx *dbent.Tx
+ if s.entClient == nil {
+ logger.LegacyPrintf("service.admin", "Warning: entClient is nil, skipping transaction protection for exclusive group binding")
+ } else {
+ var txErr error
+ tx, txErr = s.entClient.Tx(ctx)
+ if txErr != nil {
+ return nil, fmt.Errorf("begin transaction: %w", txErr)
+ }
+ defer func() { _ = tx.Rollback() }()
+ opCtx = dbent.NewTxContext(ctx, tx)
+ }
+
+ if addErr := s.userRepo.AddGroupToAllowedGroups(opCtx, apiKey.UserID, gid); addErr != nil {
+ return nil, fmt.Errorf("add group to user allowed groups: %w", addErr)
+ }
+ if err := s.apiKeyRepo.Update(opCtx, apiKey); err != nil {
+ return nil, fmt.Errorf("update api key: %w", err)
+ }
+ if tx != nil {
+ if err := tx.Commit(); err != nil {
+ return nil, fmt.Errorf("commit transaction: %w", err)
+ }
+ }
+
+ result.AutoGrantedGroupAccess = true
+ result.GrantedGroupID = &gid
+ result.GrantedGroupName = group.Name
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ }
+
+ result.APIKey = apiKey
+ return result, nil
+ }
+ }
+
+ if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
+ return nil, fmt.Errorf("update api key: %w", err)
+ }
+ if s.authCacheInvalidator != nil {
+ s.authCacheInvalidator.InvalidateAuthCacheByKey(ctx, apiKey.Key)
+ }
+
+ result.APIKey = apiKey
+ return result, nil
+}
+
// Account management implementations
func (s *adminServiceImpl) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64) ([]Account, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
diff --git a/backend/internal/service/redeem_service.go b/backend/internal/service/redeem_service.go
index 49792438d..225477d52 100644
--- a/backend/internal/service/redeem_service.go
+++ b/backend/internal/service/redeem_service.go
@@ -175,6 +175,33 @@ func (s *RedeemService) GenerateCodes(ctx context.Context, req GenerateCodesRequ
return codes, nil
}
+// CreateCode creates a redeem code with caller-provided code value.
+// It is primarily used by admin integrations that require an external order ID
+// to be mapped to a deterministic redeem code.
+func (s *RedeemService) CreateCode(ctx context.Context, code *RedeemCode) error {
+ if code == nil {
+ return errors.New("redeem code is required")
+ }
+ code.Code = strings.TrimSpace(code.Code)
+ if code.Code == "" {
+ return errors.New("code is required")
+ }
+ if code.Type == "" {
+ code.Type = RedeemTypeBalance
+ }
+ if code.Type != RedeemTypeInvitation && code.Value <= 0 {
+ return errors.New("value must be greater than 0")
+ }
+ if code.Status == "" {
+ code.Status = StatusUnused
+ }
+
+ if err := s.redeemRepo.Create(ctx, code); err != nil {
+ return fmt.Errorf("create redeem code: %w", err)
+ }
+ return nil
+}
+
// checkRedeemRateLimit 检查用户兑换错误次数是否超限
func (s *RedeemService) checkRedeemRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go
index 510e734e6..b5553935e 100644
--- a/backend/internal/service/user_service.go
+++ b/backend/internal/service/user_service.go
@@ -40,6 +40,8 @@ type UserRepository interface {
UpdateConcurrency(ctx context.Context, id int64, amount int) error
ExistsByEmail(ctx context.Context, email string) (bool, error)
RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error)
+ // AddGroupToAllowedGroups 将指定分组增量添加到用户的 allowed_groups(幂等,冲突忽略)
+ AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error
// TOTP 双因素认证
UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error
From 10fe2375e04d39251f92b757949f96cee3c1982a Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 20:31:34 +0800
Subject: [PATCH 08/13] chore(openai-gateway): tag passthrough hit log as debug
---
backend/internal/service/openai_gateway_service.go | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index 11d363510..c32e1581a 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -2223,7 +2223,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
}
logger.LegacyPrintf("service.openai_gateway",
- "[OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
+ "[DEBUG] [OpenAI 自动透传] 命中自动透传分支: account=%d name=%s type=%s model=%s stream=%v",
account.ID,
account.Name,
account.Type,
From c876a3ed897d0e619304b46d02a9e96a7e529318 Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 20:44:49 +0800
Subject: [PATCH 09/13] fix(test): complete UserRepository stubs for
AddGroupToAllowedGroups
---
backend/internal/handler/sora_client_handler_test.go | 3 +++
backend/internal/server/api_contract_test.go | 4 ++++
backend/internal/server/middleware/admin_auth_test.go | 4 ++++
backend/internal/service/admin_service_delete_test.go | 4 ++++
backend/internal/service/sora_generation_service_test.go | 3 +++
backend/internal/service/user_service_test.go | 1 +
6 files changed, 19 insertions(+)
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
index 523b016c7..dbb3667a7 100644
--- a/backend/internal/handler/sora_client_handler_test.go
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -942,6 +942,9 @@ func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, e
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
+func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go
index d2cd2a5ee..a9a9bbdd1 100644
--- a/backend/internal/server/api_contract_test.go
+++ b/backend/internal/server/api_contract_test.go
@@ -779,6 +779,10 @@ func (r *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
return 0, errors.New("not implemented")
}
+func (r *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ return errors.New("not implemented")
+}
+
func (r *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
return errors.New("not implemented")
}
diff --git a/backend/internal/server/middleware/admin_auth_test.go b/backend/internal/server/middleware/admin_auth_test.go
index 7b6d4ce8d..7640ab2ae 100644
--- a/backend/internal/server/middleware/admin_auth_test.go
+++ b/backend/internal/server/middleware/admin_auth_test.go
@@ -181,6 +181,10 @@ func (s *stubUserRepo) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
+func (s *stubUserRepo) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
func (s *stubUserRepo) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/admin_service_delete_test.go b/backend/internal/service/admin_service_delete_test.go
index 60fa3d774..bb906df53 100644
--- a/backend/internal/service/admin_service_delete_test.go
+++ b/backend/internal/service/admin_service_delete_test.go
@@ -93,6 +93,10 @@ func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID
panic("unexpected RemoveGroupFromAllowedGroups call")
}
+func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
+ panic("unexpected AddGroupToAllowedGroups call")
+}
+
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go
index 820945f02..a5c0c890d 100644
--- a/backend/internal/service/sora_generation_service_test.go
+++ b/backend/internal/service/sora_generation_service_test.go
@@ -162,6 +162,9 @@ func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, err
func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
+func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go
index 7c3c984f9..57eb3274a 100644
--- a/backend/internal/service/user_service_test.go
+++ b/backend/internal/service/user_service_test.go
@@ -45,6 +45,7 @@ func (m *mockUserRepo) ExistsByEmail(context.Context, string) (bool, error) { re
func (m *mockUserRepo) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
return 0, nil
}
+func (m *mockUserRepo) AddGroupToAllowedGroups(context.Context, int64, int64) error { return nil }
func (m *mockUserRepo) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (m *mockUserRepo) EnableTotp(context.Context, int64) error { return nil }
func (m *mockUserRepo) DisableTotp(context.Context, int64) error { return nil }
From 0ee164a6b8c1bf0d5dd425edf44eb32f56b10dab Mon Sep 17 00:00:00 2001
From: QTom
Date: Sat, 28 Feb 2026 20:55:31 +0800
Subject: [PATCH 10/13] =?UTF-8?q?test(sora):=20=E8=A1=A5=E5=85=85=E6=B5=8B?=
=?UTF-8?q?=E8=AF=95=20stub=20=E4=B8=AD=E7=BC=BA=E5=A4=B1=E7=9A=84=20AddGr?=
=?UTF-8?q?oupToAllowedGroups=20=E6=96=B9=E6=B3=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
feat/admin-apikey-group-update 分支给 UserRepository 接口新增了
AddGroupToAllowedGroups 方法,需要在测试 stub 中补充实现以通过编译。
- sora_client_handler_test.go: stubUserRepoForHandler
- sora_generation_service_test.go: stubUserRepoForQuota
(cherry picked from commit ddd6331c2f4285a69aaf6da8d7ebdfe82365ba3b)
---
backend/internal/handler/sora_client_handler_test.go | 3 +++
backend/internal/service/sora_generation_service_test.go | 3 +++
2 files changed, 6 insertions(+)
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
index dbb3667a7..4ebbf627e 100644
--- a/backend/internal/handler/sora_client_handler_test.go
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -948,6 +948,9 @@ func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64,
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
// ==================== NewSoraClientHandler ====================
diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go
index a5c0c890d..3b4b49e37 100644
--- a/backend/internal/service/sora_generation_service_test.go
+++ b/backend/internal/service/sora_generation_service_test.go
@@ -168,6 +168,9 @@ func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, i
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
+func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
+ return nil
+}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
From 229287983dca53446b9966a939078b53c8a3d76f Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Mon, 2 Mar 2026 20:45:45 +0800
Subject: [PATCH 11/13] fix(test): remove duplicated AddGroupToAllowedGroups
stubs
---
backend/internal/handler/sora_client_handler_test.go | 3 ---
backend/internal/service/sora_generation_service_test.go | 3 ---
2 files changed, 6 deletions(-)
diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go
index 4ebbf627e..dbb3667a7 100644
--- a/backend/internal/handler/sora_client_handler_test.go
+++ b/backend/internal/handler/sora_client_handler_test.go
@@ -948,9 +948,6 @@ func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64,
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) DisableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForHandler) AddGroupToAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
// ==================== NewSoraClientHandler ====================
diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go
index 3b4b49e37..a5c0c890d 100644
--- a/backend/internal/service/sora_generation_service_test.go
+++ b/backend/internal/service/sora_generation_service_test.go
@@ -168,9 +168,6 @@ func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, i
func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil }
-func (r *stubUserRepoForQuota) AddGroupToAllowedGroups(context.Context, int64, int64) error {
- return nil
-}
// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ====================
From 41ea2a1463ed6dc7f452fd303631a7d51fe4894e Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Tue, 3 Mar 2026 09:05:22 +0800
Subject: [PATCH 12/13] fix(pr716): restore files removed from main baseline
sync
---
backend/internal/repository/rpm_cache.go | 141 ++++++
backend/internal/service/account_rpm_test.go | 120 +++++
.../service/admin_service_apikey_test.go | 420 ++++++++++++++++++
backend/internal/service/rpm_cache.go | 17 +
.../service/setting_service_update_test.go | 182 ++++++++
frontend/src/api/admin/apiKeys.ts | 33 ++
.../charts/GroupDistributionChart.vue | 152 +++++++
7 files changed, 1065 insertions(+)
create mode 100644 backend/internal/repository/rpm_cache.go
create mode 100644 backend/internal/service/account_rpm_test.go
create mode 100644 backend/internal/service/admin_service_apikey_test.go
create mode 100644 backend/internal/service/rpm_cache.go
create mode 100644 backend/internal/service/setting_service_update_test.go
create mode 100644 frontend/src/api/admin/apiKeys.ts
create mode 100644 frontend/src/components/charts/GroupDistributionChart.vue
diff --git a/backend/internal/repository/rpm_cache.go b/backend/internal/repository/rpm_cache.go
new file mode 100644
index 000000000..4d73ec4b8
--- /dev/null
+++ b/backend/internal/repository/rpm_cache.go
@@ -0,0 +1,141 @@
+package repository
+
+import (
+ "context"
+ "errors"
+ "fmt"
+ "strconv"
+ "time"
+
+ "github.com/Wei-Shaw/sub2api/internal/service"
+ "github.com/redis/go-redis/v9"
+)
+
+// RPM 计数器缓存常量定义
+//
+// 设计说明:
+// 使用 Redis 简单计数器跟踪每个账号每分钟的请求数:
+// - Key: rpm:{accountID}:{minuteTimestamp}
+// - Value: 当前分钟内的请求计数
+// - TTL: 120 秒(覆盖当前分钟 + 一定冗余)
+//
+// 使用 TxPipeline(MULTI/EXEC)执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster。
+// 通过 rdb.Time() 获取服务端时间,避免多实例时钟不同步。
+//
+// 设计决策:
+// - TxPipeline vs Pipeline:Pipeline 仅合并发送但不保证原子,TxPipeline 使用 MULTI/EXEC 事务保证原子执行。
+// - rdb.Time() 单独调用:Pipeline/TxPipeline 中无法引用前一命令的结果,因此 TIME 必须单独调用(2 RTT)。
+// Lua 脚本可以做到 1 RTT,但在 Redis Cluster 中动态拼接 key 存在 CROSSSLOT 风险,选择安全性优先。
+const (
+ // RPM 计数器键前缀
+ // 格式: rpm:{accountID}:{minuteTimestamp}
+ rpmKeyPrefix = "rpm:"
+
+ // RPM 计数器 TTL(120 秒,覆盖当前分钟窗口 + 冗余)
+ rpmKeyTTL = 120 * time.Second
+)
+
+// RPMCacheImpl RPM 计数器缓存 Redis 实现
+type RPMCacheImpl struct {
+ rdb *redis.Client
+}
+
+// NewRPMCache 创建 RPM 计数器缓存
+func NewRPMCache(rdb *redis.Client) service.RPMCache {
+ return &RPMCacheImpl{rdb: rdb}
+}
+
+// currentMinuteKey 获取当前分钟的完整 Redis key
+// 使用 rdb.Time() 获取 Redis 服务端时间,避免多实例时钟偏差
+func (c *RPMCacheImpl) currentMinuteKey(ctx context.Context, accountID int64) (string, error) {
+ serverTime, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return "", fmt.Errorf("redis TIME: %w", err)
+ }
+ minuteTS := serverTime.Unix() / 60
+ return fmt.Sprintf("%s%d:%d", rpmKeyPrefix, accountID, minuteTS), nil
+}
+
+// currentMinuteSuffix 获取当前分钟时间戳后缀(供批量操作使用)
+// 使用 rdb.Time() 获取 Redis 服务端时间
+func (c *RPMCacheImpl) currentMinuteSuffix(ctx context.Context) (string, error) {
+ serverTime, err := c.rdb.Time(ctx).Result()
+ if err != nil {
+ return "", fmt.Errorf("redis TIME: %w", err)
+ }
+ minuteTS := serverTime.Unix() / 60
+ return strconv.FormatInt(minuteTS, 10), nil
+}
+
+// IncrementRPM 原子递增并返回当前分钟的计数
+// 使用 TxPipeline (MULTI/EXEC) 执行 INCR + EXPIRE,保证原子性且兼容 Redis Cluster
+func (c *RPMCacheImpl) IncrementRPM(ctx context.Context, accountID int64) (int, error) {
+ key, err := c.currentMinuteKey(ctx, accountID)
+ if err != nil {
+ return 0, fmt.Errorf("rpm increment: %w", err)
+ }
+
+ // 使用 TxPipeline (MULTI/EXEC) 保证 INCR + EXPIRE 原子执行
+ // EXPIRE 幂等,每次都设置不影响正确性
+ pipe := c.rdb.TxPipeline()
+ incrCmd := pipe.Incr(ctx, key)
+ pipe.Expire(ctx, key, rpmKeyTTL)
+
+ if _, err := pipe.Exec(ctx); err != nil {
+ return 0, fmt.Errorf("rpm increment: %w", err)
+ }
+
+ return int(incrCmd.Val()), nil
+}
+
+// GetRPM 获取当前分钟的 RPM 计数
+func (c *RPMCacheImpl) GetRPM(ctx context.Context, accountID int64) (int, error) {
+ key, err := c.currentMinuteKey(ctx, accountID)
+ if err != nil {
+ return 0, fmt.Errorf("rpm get: %w", err)
+ }
+
+ val, err := c.rdb.Get(ctx, key).Int()
+ if errors.Is(err, redis.Nil) {
+ return 0, nil // 当前分钟无记录
+ }
+ if err != nil {
+ return 0, fmt.Errorf("rpm get: %w", err)
+ }
+ return val, nil
+}
+
+// GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline)
+func (c *RPMCacheImpl) GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+ if len(accountIDs) == 0 {
+ return map[int64]int{}, nil
+ }
+
+ // 获取当前分钟后缀
+ minuteSuffix, err := c.currentMinuteSuffix(ctx)
+ if err != nil {
+ return nil, fmt.Errorf("rpm batch get: %w", err)
+ }
+
+ // 使用 Pipeline 批量 GET
+ pipe := c.rdb.Pipeline()
+ cmds := make(map[int64]*redis.StringCmd, len(accountIDs))
+ for _, id := range accountIDs {
+ key := fmt.Sprintf("%s%d:%s", rpmKeyPrefix, id, minuteSuffix)
+ cmds[id] = pipe.Get(ctx, key)
+ }
+
+ if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
+ return nil, fmt.Errorf("rpm batch get: %w", err)
+ }
+
+ result := make(map[int64]int, len(accountIDs))
+ for id, cmd := range cmds {
+ if val, err := cmd.Int(); err == nil {
+ result[id] = val
+ } else {
+ result[id] = 0
+ }
+ }
+ return result, nil
+}
diff --git a/backend/internal/service/account_rpm_test.go b/backend/internal/service/account_rpm_test.go
new file mode 100644
index 000000000..9d91f3e0c
--- /dev/null
+++ b/backend/internal/service/account_rpm_test.go
@@ -0,0 +1,120 @@
+package service
+
+import (
+ "encoding/json"
+ "testing"
+)
+
+func TestGetBaseRPM(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ expected int
+ }{
+ {"nil extra", nil, 0},
+ {"no key", map[string]any{}, 0},
+ {"zero", map[string]any{"base_rpm": 0}, 0},
+ {"int value", map[string]any{"base_rpm": 15}, 15},
+ {"float value", map[string]any{"base_rpm": 15.0}, 15},
+ {"string value", map[string]any{"base_rpm": "15"}, 15},
+ {"negative value", map[string]any{"base_rpm": -5}, 0},
+ {"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
+ {"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.GetBaseRPM(); got != tt.expected {
+ t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestGetRPMStrategy(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ expected string
+ }{
+ {"nil extra", nil, "tiered"},
+ {"no key", map[string]any{}, "tiered"},
+ {"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
+ {"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
+ {"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
+ {"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
+ {"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.GetRPMStrategy(); got != tt.expected {
+ t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestCheckRPMSchedulability(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ currentRPM int
+ expected WindowCostSchedulability
+ }{
+ {"disabled", map[string]any{}, 100, WindowCostSchedulable},
+ {"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
+ {"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
+ {"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
+ {"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
+ {"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
+ {"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
+ {"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
+ {"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
+ {"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
+ {"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
+ {"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
+ {"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
+ {"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
+ {"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
+ t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
+ }
+ })
+ }
+}
+
+func TestGetRPMStickyBuffer(t *testing.T) {
+ tests := []struct {
+ name string
+ extra map[string]any
+ expected int
+ }{
+ {"nil extra", nil, 0},
+ {"no keys", map[string]any{}, 0},
+ {"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
+ {"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
+ {"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
+ {"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
+ {"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
+ {"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
+ {"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
+ {"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
+ {"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
+ {"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
+ {"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
+ {"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
+ }
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ a := &Account{Extra: tt.extra}
+ if got := a.GetRPMStickyBuffer(); got != tt.expected {
+ t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
+ }
+ })
+ }
+}
diff --git a/backend/internal/service/admin_service_apikey_test.go b/backend/internal/service/admin_service_apikey_test.go
new file mode 100644
index 000000000..9210a7862
--- /dev/null
+++ b/backend/internal/service/admin_service_apikey_test.go
@@ -0,0 +1,420 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "errors"
+ "testing"
+ "time"
+
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
+ "github.com/stretchr/testify/require"
+)
+
+// ---------------------------------------------------------------------------
+// Stubs
+// ---------------------------------------------------------------------------
+
+// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
+type userRepoStubForGroupUpdate struct {
+ addGroupErr error
+ addGroupCalled bool
+ addedUserID int64
+ addedGroupID int64
+}
+
+func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
+ s.addGroupCalled = true
+ s.addedUserID = userID
+ s.addedGroupID = groupID
+ return s.addGroupErr
+}
+
+func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
+func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
+
+// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
+type apiKeyRepoStubForGroupUpdate struct {
+ key *APIKey
+ getErr error
+ updateErr error
+ updated *APIKey // captures what was passed to Update
+}
+
+func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
+ if s.getErr != nil {
+ return nil, s.getErr
+ }
+ clone := *s.key
+ return &clone, nil
+}
+func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
+ if s.updateErr != nil {
+ return s.updateErr
+ }
+ clone := *key
+ s.updated = &clone
+ return nil
+}
+
+// Unused methods – panic on unexpected call.
+func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
+func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
+ panic("unexpected")
+}
+func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
+ panic("unexpected")
+}
+
+// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
+type groupRepoStubForGroupUpdate struct {
+ group *Group
+ getErr error
+ lastGetByIDArg int64
+}
+
+func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
+ s.lastGetByIDArg = id
+ if s.getErr != nil {
+ return nil, s.getErr
+ }
+ clone := *s.group
+ return &clone, nil
+}
+
+// Unused methods – panic on unexpected call.
+func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
+func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
+func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
+func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
+ panic("unexpected")
+}
+func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
+ panic("unexpected")
+}
+
+// ---------------------------------------------------------------------------
+// Tests
+// ---------------------------------------------------------------------------
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
+ repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
+ svc := &adminServiceImpl{apiKeyRepo: repo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
+ require.ErrorIs(t, err, ErrAPIKeyNotFound)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
+ repo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ svc := &adminServiceImpl{apiKeyRepo: repo}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
+ require.NoError(t, err)
+ require.Equal(t, int64(1), got.APIKey.ID)
+ // Update should NOT have been called (updated stays nil)
+ require.Nil(t, repo.updated)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
+ repo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
+ require.NoError(t, err)
+ require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
+ require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
+ require.NotNil(t, repo.updated, "Update should have been called")
+ require.Nil(t, repo.updated.GroupID)
+ require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
+ require.Equal(t, []string{"sk-test"}, cache.keys)
+ // M3: verify correct group ID was passed to repo
+ require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
+ // C1 fix: verify Group object is populated
+ require.NotNil(t, got.APIKey.Group)
+ require.Equal(t, "Pro", got.APIKey.Group.Name)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ // Update is still called (current impl doesn't short-circuit on same group)
+ require.NotNil(t, apiKeyRepo.updated)
+ require.Equal(t, []string{"sk-test"}, cache.keys)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
+ require.ErrorIs(t, err, ErrGroupNotFound)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
+ require.Error(t, err)
+ require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
+ repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
+ svc := &adminServiceImpl{apiKeyRepo: repo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "update api key")
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
+
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
+ require.Error(t, err)
+ require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
+
+ inputGID := int64(10)
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ // Mutating the input pointer must NOT affect the stored value
+ inputGID = 999
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
+ existing := &APIKey{ID: 1, Key: "sk-test"}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
+ // authCacheInvalidator is nil – should not panic
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(7), *got.APIKey.GroupID)
+}
+
+// ---------------------------------------------------------------------------
+// Tests: AllowedGroup auto-sync
+// ---------------------------------------------------------------------------
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
+ userRepo := &userRepoStubForGroupUpdate{}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ require.Equal(t, int64(10), *got.APIKey.GroupID)
+ // 验证 AddGroupToAllowedGroups 被调用,且参数正确
+ require.True(t, userRepo.addGroupCalled)
+ require.Equal(t, int64(42), userRepo.addedUserID)
+ require.Equal(t, int64(10), userRepo.addedGroupID)
+ // 验证 result 标记了自动授权
+ require.True(t, got.AutoGrantedGroupAccess)
+ require.NotNil(t, got.GrantedGroupID)
+ require.Equal(t, int64(10), *got.GrantedGroupID)
+ require.Equal(t, "Exclusive", got.GrantedGroupName)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
+ userRepo := &userRepoStubForGroupUpdate{}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.NoError(t, err)
+ require.NotNil(t, got.APIKey.GroupID)
+ // 非专属分组不触发 AddGroupToAllowedGroups
+ require.False(t, userRepo.addGroupCalled)
+ require.False(t, got.AutoGrantedGroupAccess)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
+ userRepo := &userRepoStubForGroupUpdate{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
+
+ // 订阅类型分组应被阻止绑定
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.Error(t, err)
+ require.Equal(t, "SUBSCRIPTION_GROUP_NOT_ALLOWED", infraerrors.Reason(err))
+ require.False(t, userRepo.addGroupCalled)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
+ userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
+
+ // 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错
+ _, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
+ require.Error(t, err)
+ require.Contains(t, err.Error(), "add group to user allowed groups")
+ require.True(t, userRepo.addGroupCalled)
+ // apiKey 不应被更新
+ require.Nil(t, apiKeyRepo.updated)
+}
+
+func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
+ existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
+ apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
+ userRepo := &userRepoStubForGroupUpdate{}
+ cache := &authCacheInvalidatorStub{}
+ svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
+
+ got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
+ require.NoError(t, err)
+ require.Nil(t, got.APIKey.GroupID)
+ // 解绑时不修改 allowed_groups
+ require.False(t, userRepo.addGroupCalled)
+ require.False(t, got.AutoGrantedGroupAccess)
+}
diff --git a/backend/internal/service/rpm_cache.go b/backend/internal/service/rpm_cache.go
new file mode 100644
index 000000000..070362190
--- /dev/null
+++ b/backend/internal/service/rpm_cache.go
@@ -0,0 +1,17 @@
+package service
+
+import "context"
+
+// RPMCache RPM 计数器缓存接口
+// 用于 Anthropic OAuth/SetupToken 账号的每分钟请求数限制
+type RPMCache interface {
+ // IncrementRPM 原子递增并返回当前分钟的计数
+ // 使用 Redis 服务器时间确定 minute key,避免多实例时钟偏差
+ IncrementRPM(ctx context.Context, accountID int64) (count int, err error)
+
+ // GetRPM 获取当前分钟的 RPM 计数
+ GetRPM(ctx context.Context, accountID int64) (count int, err error)
+
+ // GetRPMBatch 批量获取多个账号的 RPM 计数(使用 Pipeline)
+ GetRPMBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
+}
diff --git a/backend/internal/service/setting_service_update_test.go b/backend/internal/service/setting_service_update_test.go
new file mode 100644
index 000000000..ec64511f2
--- /dev/null
+++ b/backend/internal/service/setting_service_update_test.go
@@ -0,0 +1,182 @@
+//go:build unit
+
+package service
+
+import (
+ "context"
+ "encoding/json"
+ "testing"
+
+ "github.com/Wei-Shaw/sub2api/internal/config"
+ infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
+ "github.com/stretchr/testify/require"
+)
+
+type settingUpdateRepoStub struct {
+ updates map[string]string
+}
+
+func (s *settingUpdateRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
+ panic("unexpected Get call")
+}
+
+func (s *settingUpdateRepoStub) GetValue(ctx context.Context, key string) (string, error) {
+ panic("unexpected GetValue call")
+}
+
+func (s *settingUpdateRepoStub) Set(ctx context.Context, key, value string) error {
+ panic("unexpected Set call")
+}
+
+func (s *settingUpdateRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
+ panic("unexpected GetMultiple call")
+}
+
+func (s *settingUpdateRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
+ s.updates = make(map[string]string, len(settings))
+ for k, v := range settings {
+ s.updates[k] = v
+ }
+ return nil
+}
+
+func (s *settingUpdateRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
+ panic("unexpected GetAll call")
+}
+
+func (s *settingUpdateRepoStub) Delete(ctx context.Context, key string) error {
+ panic("unexpected Delete call")
+}
+
+type defaultSubGroupReaderStub struct {
+ byID map[int64]*Group
+ errBy map[int64]error
+ calls []int64
+}
+
+func (s *defaultSubGroupReaderStub) GetByID(ctx context.Context, id int64) (*Group, error) {
+ s.calls = append(s.calls, id)
+ if err, ok := s.errBy[id]; ok {
+ return nil, err
+ }
+ if g, ok := s.byID[id]; ok {
+ return g, nil
+ }
+ return nil, ErrGroupNotFound
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_ValidGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ byID: map[int64]*Group{
+ 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ },
+ })
+ require.NoError(t, err)
+ require.Equal(t, []int64{11}, groupReader.calls)
+
+ raw, ok := repo.updates[SettingKeyDefaultSubscriptions]
+ require.True(t, ok)
+
+ var got []DefaultSubscriptionSetting
+ require.NoError(t, json.Unmarshal([]byte(raw), &got))
+ require.Equal(t, []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ }, got)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNonSubscriptionGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ byID: map[int64]*Group{
+ 12: {ID: 12, SubscriptionType: SubscriptionTypeStandard},
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 12, ValidityDays: 7},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
+ require.Nil(t, repo.updates)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsNotFoundGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ errBy: map[int64]error{
+ 13: ErrGroupNotFound,
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 13, ValidityDays: 7},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_INVALID", infraerrors.Reason(err))
+ require.Equal(t, "13", infraerrors.FromError(err).Metadata["group_id"])
+ require.Nil(t, repo.updates)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroup(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ groupReader := &defaultSubGroupReaderStub{
+ byID: map[int64]*Group{
+ 11: {ID: 11, SubscriptionType: SubscriptionTypeSubscription},
+ },
+ }
+ svc := NewSettingService(repo, &config.Config{})
+ svc.SetDefaultSubscriptionGroupReader(groupReader)
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ {GroupID: 11, ValidityDays: 60},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
+ require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
+ require.Nil(t, repo.updates)
+}
+
+func TestSettingService_UpdateSettings_DefaultSubscriptions_RejectsDuplicateGroupWithoutGroupReader(t *testing.T) {
+ repo := &settingUpdateRepoStub{}
+ svc := NewSettingService(repo, &config.Config{})
+
+ err := svc.UpdateSettings(context.Background(), &SystemSettings{
+ DefaultSubscriptions: []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ {GroupID: 11, ValidityDays: 60},
+ },
+ })
+ require.Error(t, err)
+ require.Equal(t, "DEFAULT_SUBSCRIPTION_GROUP_DUPLICATE", infraerrors.Reason(err))
+ require.Equal(t, "11", infraerrors.FromError(err).Metadata["group_id"])
+ require.Nil(t, repo.updates)
+}
+
+func TestParseDefaultSubscriptions_NormalizesValues(t *testing.T) {
+ got := parseDefaultSubscriptions(`[{"group_id":11,"validity_days":30},{"group_id":11,"validity_days":60},{"group_id":0,"validity_days":10},{"group_id":12,"validity_days":99999}]`)
+ require.Equal(t, []DefaultSubscriptionSetting{
+ {GroupID: 11, ValidityDays: 30},
+ {GroupID: 11, ValidityDays: 60},
+ {GroupID: 12, ValidityDays: MaxValidityDays},
+ }, got)
+}
diff --git a/frontend/src/api/admin/apiKeys.ts b/frontend/src/api/admin/apiKeys.ts
new file mode 100644
index 000000000..79f6e1748
--- /dev/null
+++ b/frontend/src/api/admin/apiKeys.ts
@@ -0,0 +1,33 @@
+/**
+ * Admin API Keys API endpoints
+ * Handles API key management for administrators
+ */
+
+import { apiClient } from '../client'
+import type { ApiKey } from '@/types'
+
+export interface UpdateApiKeyGroupResult {
+ api_key: ApiKey
+ auto_granted_group_access: boolean
+ granted_group_id?: number
+ granted_group_name?: string
+}
+
+/**
+ * Update an API key's group binding
+ * @param id - API Key ID
+ * @param groupId - Group ID (0 to unbind, positive to bind, null/undefined to skip)
+ * @returns Updated API key with auto-grant info
+ */
+export async function updateApiKeyGroup(id: number, groupId: number | null): Promise {
+ const { data } = await apiClient.put(`/admin/api-keys/${id}`, {
+ group_id: groupId === null ? 0 : groupId
+ })
+ return data
+}
+
+export const apiKeysAPI = {
+ updateApiKeyGroup
+}
+
+export default apiKeysAPI
diff --git a/frontend/src/components/charts/GroupDistributionChart.vue b/frontend/src/components/charts/GroupDistributionChart.vue
new file mode 100644
index 000000000..d9231a630
--- /dev/null
+++ b/frontend/src/components/charts/GroupDistributionChart.vue
@@ -0,0 +1,152 @@
+
+
+
+ {{ t('admin.dashboard.groupDistribution') }}
+
+
+
+
+
+
+
+
+
+
+
+
+ | {{ t('admin.dashboard.group') }} |
+ {{ t('admin.dashboard.requests') }} |
+ {{ t('admin.dashboard.tokens') }} |
+ {{ t('admin.dashboard.actual') }} |
+ {{ t('admin.dashboard.standard') }} |
+
+
+
+
+ |
+ {{ group.group_name || t('admin.dashboard.noGroup') }}
+ |
+
+ {{ formatNumber(group.requests) }}
+ |
+
+ {{ formatTokens(group.total_tokens) }}
+ |
+
+ ${{ formatCost(group.actual_cost) }}
+ |
+
+ ${{ formatCost(group.cost) }}
+ |
+
+
+
+
+
+
+ {{ t('admin.dashboard.noDataAvailable') }}
+
+
+
+
+
From 337acd85935322f557d055d3dbf3f1ca8968b6df Mon Sep 17 00:00:00 2001
From: yangjianbo
Date: Tue, 3 Mar 2026 09:12:58 +0800
Subject: [PATCH 13/13] fix(service): restore account RPM helpers from main
---
backend/internal/service/account.go | 74 +++++++++++++++++++++++++++++
1 file changed, 74 insertions(+)
diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go
index 90c5026dc..b159a972e 100644
--- a/backend/internal/service/account.go
+++ b/backend/internal/service/account.go
@@ -1140,6 +1140,80 @@ func (a *Account) GetSessionIdleTimeoutMinutes() int {
return 5
}
+// GetBaseRPM 获取基础 RPM 限制
+// 返回 0 表示未启用(负数视为无效配置,按 0 处理)
+func (a *Account) GetBaseRPM() int {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["base_rpm"]; ok {
+ val := parseExtraInt(v)
+ if val > 0 {
+ return val
+ }
+ }
+ return 0
+}
+
+// GetRPMStrategy 获取 RPM 策略
+// "tiered" = 三区模型(默认), "sticky_exempt" = 粘性豁免
+func (a *Account) GetRPMStrategy() string {
+ if a.Extra == nil {
+ return "tiered"
+ }
+ if v, ok := a.Extra["rpm_strategy"]; ok {
+ if s, ok := v.(string); ok && s == "sticky_exempt" {
+ return "sticky_exempt"
+ }
+ }
+ return "tiered"
+}
+
+// GetRPMStickyBuffer 获取 RPM 粘性缓冲数量
+// tiered 模式下的黄区大小,默认为 base_rpm 的 20%(至少 1)
+func (a *Account) GetRPMStickyBuffer() int {
+ if a.Extra == nil {
+ return 0
+ }
+ if v, ok := a.Extra["rpm_sticky_buffer"]; ok {
+ val := parseExtraInt(v)
+ if val > 0 {
+ return val
+ }
+ }
+ base := a.GetBaseRPM()
+ buffer := base / 5
+ if buffer < 1 && base > 0 {
+ buffer = 1
+ }
+ return buffer
+}
+
+// CheckRPMSchedulability 根据当前 RPM 计数检查调度状态
+// 复用 WindowCostSchedulability 三态:Schedulable / StickyOnly / NotSchedulable
+func (a *Account) CheckRPMSchedulability(currentRPM int) WindowCostSchedulability {
+ baseRPM := a.GetBaseRPM()
+ if baseRPM <= 0 {
+ return WindowCostSchedulable
+ }
+
+ if currentRPM < baseRPM {
+ return WindowCostSchedulable
+ }
+
+ strategy := a.GetRPMStrategy()
+ if strategy == "sticky_exempt" {
+ return WindowCostStickyOnly // 粘性豁免无红区
+ }
+
+ // tiered: 黄区 + 红区
+ buffer := a.GetRPMStickyBuffer()
+ if currentRPM < baseRPM+buffer {
+ return WindowCostStickyOnly
+ }
+ return WindowCostNotSchedulable
+}
+
// CheckWindowCostSchedulability 根据当前窗口费用检查调度状态
// - 费用 < 阈值: WindowCostSchedulable(可正常调度)
// - 费用 >= 阈值 且 < 阈值+预留: WindowCostStickyOnly(仅粘性会话)