diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go new file mode 100644 index 000000000..30e2c6b9c --- /dev/null +++ b/backend/internal/handler/openai_chat_completions.go @@ -0,0 +1,530 @@ +package handler + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "encoding/json" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ChatCompletions handles OpenAI Chat Completions API compatibility. +// POST /v1/chat/completions +func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { + body, err := io.ReadAll(c.Request.Body) + if err != nil { + if maxErr, ok := extractMaxBytesError(err); ok { + h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) + return + } + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body") + return + } + if len(body) == 0 { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty") + return + } + + // Preserve original chat-completions request for upstream passthrough when needed. 
+ c.Set(service.OpenAIChatCompletionsBodyKey, body) + + var chatReq map[string]any + if err := json.Unmarshal(body, &chatReq); err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") + return + } + + includeUsage := false + if streamOptions, ok := chatReq["stream_options"].(map[string]any); ok { + if v, ok := streamOptions["include_usage"].(bool); ok { + includeUsage = v + } + } + c.Set(service.OpenAIChatCompletionsIncludeUsageKey, includeUsage) + + converted, err := service.ConvertChatCompletionsToResponses(chatReq) + if err != nil { + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", err.Error()) + return + } + + convertedBody, err := json.Marshal(converted) + if err != nil { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request") + return + } + + stream, _ := converted["stream"].(bool) + model, _ := converted["model"].(string) + writer := newChatCompletionsResponseWriter(c.Writer, stream, includeUsage, model) + c.Writer = writer + c.Request.Body = io.NopCloser(bytes.NewReader(convertedBody)) + c.Request.ContentLength = int64(len(convertedBody)) + + h.Responses(c) + writer.Finalize() +} + +type chatCompletionsResponseWriter struct { + gin.ResponseWriter + stream bool + includeUsage bool + buffer bytes.Buffer + streamBuf bytes.Buffer + state *chatCompletionStreamState + corrector *service.CodexToolCorrector + finalized bool + passthrough bool +} + +type chatCompletionStreamState struct { + id string + model string + created int64 + sentRole bool + sawToolCall bool + sawText bool + toolCallIndex map[string]int + usage map[string]any +} + +func newChatCompletionsResponseWriter(w gin.ResponseWriter, stream bool, includeUsage bool, model string) *chatCompletionsResponseWriter { + return &chatCompletionsResponseWriter{ + ResponseWriter: w, + stream: stream, + includeUsage: includeUsage, + state: &chatCompletionStreamState{ + model: strings.TrimSpace(model), 
+ toolCallIndex: make(map[string]int), + }, + corrector: service.NewCodexToolCorrector(), + } +} + +func (w *chatCompletionsResponseWriter) Write(data []byte) (int, error) { + if w.passthrough { + return w.ResponseWriter.Write(data) + } + if w.stream { + n, err := w.streamBuf.Write(data) + if err != nil { + return n, err + } + w.flushStreamBuffer() + return n, nil + } + + if w.finalized { + return len(data), nil + } + return w.buffer.Write(data) +} + +func (w *chatCompletionsResponseWriter) WriteString(s string) (int, error) { + return w.Write([]byte(s)) +} + +func (w *chatCompletionsResponseWriter) Finalize() { + if w.finalized { + return + } + w.finalized = true + if w.passthrough { + return + } + if w.stream { + return + } + + body := w.buffer.Bytes() + if len(body) == 0 { + return + } + + w.ResponseWriter.Header().Del("Content-Length") + + converted, err := service.ConvertResponsesToChatCompletion(body) + if err != nil { + _, _ = w.ResponseWriter.Write(body) + return + } + + corrected := converted + if correctedStr, ok := w.corrector.CorrectToolCallsInSSEData(string(converted)); ok { + corrected = []byte(correctedStr) + } + + _, _ = w.ResponseWriter.Write(corrected) +} + +func (w *chatCompletionsResponseWriter) SetPassthrough() { + w.passthrough = true +} + +func (w *chatCompletionsResponseWriter) flushStreamBuffer() { + for { + buf := w.streamBuf.Bytes() + idx := bytes.IndexByte(buf, '\n') + if idx == -1 { + return + } + lineBytes := w.streamBuf.Next(idx + 1) + line := strings.TrimRight(string(lineBytes), "\r\n") + w.handleStreamLine(line) + } +} + +func (w *chatCompletionsResponseWriter) handleStreamLine(line string) { + if line == "" { + return + } + if strings.HasPrefix(line, ":") { + _, _ = w.ResponseWriter.Write([]byte(line + "\n\n")) + return + } + if !strings.HasPrefix(line, "data:") { + return + } + + data := strings.TrimSpace(strings.TrimPrefix(line, "data:")) + for _, chunk := range w.convertResponseDataToChatChunks(data) { + if chunk == "" { + 
continue + } + if chunk == "[DONE]" { + _, _ = w.ResponseWriter.Write([]byte("data: [DONE]\n\n")) + continue + } + _, _ = w.ResponseWriter.Write([]byte("data: " + chunk + "\n\n")) + } +} + +func (w *chatCompletionsResponseWriter) convertResponseDataToChatChunks(data string) []string { + if data == "" { + return nil + } + if data == "[DONE]" { + return []string{"[DONE]"} + } + + var payload map[string]any + if err := json.Unmarshal([]byte(data), &payload); err != nil { + return []string{data} + } + + if _, ok := payload["error"]; ok { + return []string{data} + } + + eventType := strings.TrimSpace(getString(payload["type"])) + if eventType == "" { + return []string{data} + } + + w.state.applyMetadata(payload) + + switch eventType { + case "response.created": + return nil + case "response.output_text.delta": + delta := getString(payload["delta"]) + if delta == "" { + return nil + } + w.state.sawText = true + return []string{w.buildTextDeltaChunk(delta)} + case "response.output_text.done": + if w.state.sawText { + return nil + } + text := getString(payload["text"]) + if text == "" { + return nil + } + w.state.sawText = true + return []string{w.buildTextDeltaChunk(text)} + case "response.output_item.added", "response.output_item.delta": + if item, ok := payload["item"].(map[string]any); ok { + if callID, name, args, ok := extractToolCallFromItem(item); ok { + w.state.sawToolCall = true + return []string{w.buildToolCallChunk(callID, name, args)} + } + } + case "response.completed", "response.done": + if responseObj, ok := payload["response"].(map[string]any); ok { + w.state.applyResponseUsage(responseObj) + } + return []string{w.buildFinalChunk()} + } + + if strings.Contains(eventType, "tool_call") || strings.Contains(eventType, "function_call") { + callID := strings.TrimSpace(getString(payload["call_id"])) + if callID == "" { + callID = strings.TrimSpace(getString(payload["tool_call_id"])) + } + if callID == "" { + callID = strings.TrimSpace(getString(payload["id"])) + 
} + args := getString(payload["delta"]) + name := strings.TrimSpace(getString(payload["name"])) + if callID != "" && (args != "" || name != "") { + w.state.sawToolCall = true + return []string{w.buildToolCallChunk(callID, name, args)} + } + } + + return nil +} + +func (w *chatCompletionsResponseWriter) buildTextDeltaChunk(delta string) string { + w.state.ensureDefaults() + payload := map[string]any{ + "content": delta, + } + if !w.state.sentRole { + payload["role"] = "assistant" + w.state.sentRole = true + } + return w.buildChunk(payload, nil, nil) +} + +func (w *chatCompletionsResponseWriter) buildToolCallChunk(callID, name, args string) string { + w.state.ensureDefaults() + index := w.state.toolCallIndexFor(callID) + function := map[string]any{} + if name != "" { + function["name"] = name + } + if args != "" { + function["arguments"] = args + } + toolCall := map[string]any{ + "index": index, + "id": callID, + "type": "function", + "function": function, + } + + delta := map[string]any{ + "tool_calls": []any{toolCall}, + } + if !w.state.sentRole { + delta["role"] = "assistant" + w.state.sentRole = true + } + + return w.buildChunk(delta, nil, nil) +} + +func (w *chatCompletionsResponseWriter) buildFinalChunk() string { + w.state.ensureDefaults() + finishReason := "stop" + if w.state.sawToolCall { + finishReason = "tool_calls" + } + usage := map[string]any(nil) + if w.includeUsage && w.state.usage != nil { + usage = w.state.usage + } + return w.buildChunk(map[string]any{}, finishReason, usage) +} + +func (w *chatCompletionsResponseWriter) buildChunk(delta map[string]any, finishReason any, usage map[string]any) string { + w.state.ensureDefaults() + chunk := map[string]any{ + "id": w.state.id, + "object": "chat.completion.chunk", + "created": w.state.created, + "model": w.state.model, + "choices": []any{ + map[string]any{ + "index": 0, + "delta": delta, + "finish_reason": finishReason, + }, + }, + } + if usage != nil { + chunk["usage"] = usage + } + + data, _ := 
json.Marshal(chunk) + if corrected, ok := w.corrector.CorrectToolCallsInSSEData(string(data)); ok { + return corrected + } + return string(data) +} + +func (s *chatCompletionStreamState) ensureDefaults() { + if s.id == "" { + s.id = "chatcmpl-" + randomHexUnsafe(12) + } + if s.model == "" { + s.model = "unknown" + } + if s.created == 0 { + s.created = time.Now().Unix() + } +} + +func (s *chatCompletionStreamState) toolCallIndexFor(callID string) int { + if idx, ok := s.toolCallIndex[callID]; ok { + return idx + } + idx := len(s.toolCallIndex) + s.toolCallIndex[callID] = idx + return idx +} + +func (s *chatCompletionStreamState) applyMetadata(payload map[string]any) { + if responseObj, ok := payload["response"].(map[string]any); ok { + s.applyResponseMetadata(responseObj) + } + + if s.id == "" { + if id := strings.TrimSpace(getString(payload["response_id"])); id != "" { + s.id = id + } else if id := strings.TrimSpace(getString(payload["id"])); id != "" { + s.id = id + } + } + if s.model == "" { + if model := strings.TrimSpace(getString(payload["model"])); model != "" { + s.model = model + } + } + if s.created == 0 { + if created := getInt64(payload["created_at"]); created != 0 { + s.created = created + } else if created := getInt64(payload["created"]); created != 0 { + s.created = created + } + } +} + +func (s *chatCompletionStreamState) applyResponseMetadata(responseObj map[string]any) { + if s.id == "" { + if id := strings.TrimSpace(getString(responseObj["id"])); id != "" { + s.id = id + } + } + if s.model == "" { + if model := strings.TrimSpace(getString(responseObj["model"])); model != "" { + s.model = model + } + } + if s.created == 0 { + if created := getInt64(responseObj["created_at"]); created != 0 { + s.created = created + } + } +} + +func (s *chatCompletionStreamState) applyResponseUsage(responseObj map[string]any) { + usage, ok := responseObj["usage"].(map[string]any) + if !ok { + return + } + promptTokens := int(getNumber(usage["input_tokens"])) + 
completionTokens := int(getNumber(usage["output_tokens"])) + if promptTokens == 0 && completionTokens == 0 { + return + } + s.usage = map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens, + } +} + +func extractToolCallFromItem(item map[string]any) (string, string, string, bool) { + itemType := strings.TrimSpace(getString(item["type"])) + if itemType != "tool_call" && itemType != "function_call" { + return "", "", "", false + } + callID := strings.TrimSpace(getString(item["call_id"])) + if callID == "" { + callID = strings.TrimSpace(getString(item["id"])) + } + name := strings.TrimSpace(getString(item["name"])) + args := getString(item["arguments"]) + if fn, ok := item["function"].(map[string]any); ok { + if name == "" { + name = strings.TrimSpace(getString(fn["name"])) + } + if args == "" { + args = getString(fn["arguments"]) + } + } + if callID == "" && name == "" && args == "" { + return "", "", "", false + } + if callID == "" { + callID = "call_" + randomHexUnsafe(6) + } + return callID, name, args, true +} + +func getString(value any) string { + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + case json.Number: + return v.String() + default: + return "" + } +} + +func getNumber(value any) float64 { + switch v := value.(type) { + case float64: + return v + case float32: + return float64(v) + case int: + return float64(v) + case int64: + return float64(v) + case json.Number: + f, _ := v.Float64() + return f + default: + return 0 + } +} + +func getInt64(value any) int64 { + switch v := value.(type) { + case int64: + return v + case int: + return int64(v) + case float64: + return int64(v) + case json.Number: + i, _ := v.Int64() + return i + default: + return 0 + } +} + +func randomHexUnsafe(byteLength int) string { + if byteLength <= 0 { + byteLength = 8 + } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { + 
return "000000" + } + return hex.EncodeToString(buf) +} diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 930c8b9ee..613943287 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -1,8 +1,6 @@ package routes import ( - "net/http" - "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/handler" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -43,15 +41,8 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) - // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 - gateway.POST("/chat/completions", func(c *gin.Context) { - c.JSON(http.StatusBadRequest, gin.H{ - "error": gin.H{ - "type": "invalid_request_error", - "message": "Unsupported legacy protocol: /v1/chat/completions is not supported. Please use /v1/responses.", - }, - }) - }) + // OpenAI Chat Completions API + gateway.POST("/chat/completions", h.OpenAIGateway.ChatCompletions) } // Gemini 原生 API 兼容层(Gemini SDK/CLI 直连) @@ -69,6 +60,8 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + // OpenAI Chat Completions API(不带v1前缀的别名) + r.POST("/chat/completions", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ChatCompletions) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) diff --git a/backend/internal/service/openai_chat_completions.go b/backend/internal/service/openai_chat_completions.go new file mode 100644 index 000000000..c4c95ff22 --- /dev/null +++ b/backend/internal/service/openai_chat_completions.go @@ -0,0 +1,513 @@ +package service + +import ( + "encoding/json" + "errors" + "strings" + "time" +) + +// 
ConvertChatCompletionsToResponses converts an OpenAI Chat Completions request to a Responses request. +func ConvertChatCompletionsToResponses(req map[string]any) (map[string]any, error) { + if req == nil { + return nil, errors.New("request is nil") + } + + model := strings.TrimSpace(getString(req["model"])) + if model == "" { + return nil, errors.New("model is required") + } + + messagesRaw, ok := req["messages"] + if !ok { + return nil, errors.New("messages is required") + } + messages, ok := messagesRaw.([]any) + if !ok { + return nil, errors.New("messages must be an array") + } + + input, err := convertChatMessagesToResponsesInput(messages) + if err != nil { + return nil, err + } + + out := make(map[string]any, len(req)+1) + for key, value := range req { + switch key { + case "messages", "max_tokens", "max_completion_tokens", "stream_options", "functions", "function_call": + continue + default: + out[key] = value + } + } + + out["model"] = model + out["input"] = input + + if _, ok := out["max_output_tokens"]; !ok { + if v, ok := req["max_tokens"]; ok { + out["max_output_tokens"] = v + } else if v, ok := req["max_completion_tokens"]; ok { + out["max_output_tokens"] = v + } + } + + if _, ok := out["tools"]; !ok { + if functions, ok := req["functions"].([]any); ok && len(functions) > 0 { + tools := make([]any, 0, len(functions)) + for _, fn := range functions { + if fnMap, ok := fn.(map[string]any); ok { + tools = append(tools, map[string]any{ + "type": "function", + "function": fnMap, + }) + } + } + if len(tools) > 0 { + out["tools"] = tools + } + } + } + + if _, ok := out["tool_choice"]; !ok { + if functionCall, ok := req["function_call"]; ok { + out["tool_choice"] = functionCall + } + } + + return out, nil +} + +// ConvertResponsesToChatCompletion converts an OpenAI Responses response body to Chat Completions format. 
+func ConvertResponsesToChatCompletion(body []byte) ([]byte, error) { + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + return nil, err + } + + id := strings.TrimSpace(getString(resp["id"])) + if id == "" { + id = "chatcmpl-" + safeRandomHex(12) + } + model := strings.TrimSpace(getString(resp["model"])) + + created := getInt64(resp["created_at"]) + if created == 0 { + created = getInt64(resp["created"]) + } + if created == 0 { + created = time.Now().Unix() + } + + text, toolCalls := extractResponseTextAndToolCalls(resp) + finishReason := "stop" + if len(toolCalls) > 0 { + finishReason = "tool_calls" + } + + message := map[string]any{ + "role": "assistant", + "content": text, + } + if len(toolCalls) > 0 { + message["tool_calls"] = toolCalls + } + + chatResp := map[string]any{ + "id": id, + "object": "chat.completion", + "created": created, + "model": model, + "choices": []any{ + map[string]any{ + "index": 0, + "message": message, + "finish_reason": finishReason, + }, + }, + } + + if usage := extractResponseUsage(resp); usage != nil { + chatResp["usage"] = usage + } + if fingerprint := strings.TrimSpace(getString(resp["system_fingerprint"])); fingerprint != "" { + chatResp["system_fingerprint"] = fingerprint + } + + return json.Marshal(chatResp) +} + +func convertChatMessagesToResponsesInput(messages []any) ([]any, error) { + input := make([]any, 0, len(messages)) + for _, msg := range messages { + msgMap, ok := msg.(map[string]any) + if !ok { + return nil, errors.New("message must be an object") + } + role := strings.TrimSpace(getString(msgMap["role"])) + if role == "" { + return nil, errors.New("message role is required") + } + + switch role { + case "tool": + callID := strings.TrimSpace(getString(msgMap["tool_call_id"])) + if callID == "" { + callID = strings.TrimSpace(getString(msgMap["id"])) + } + output := extractMessageContentText(msgMap["content"]) + input = append(input, map[string]any{ + "type": "function_call_output", + 
"call_id": callID, + "output": output, + }) + case "function": + callID := strings.TrimSpace(getString(msgMap["name"])) + output := extractMessageContentText(msgMap["content"]) + input = append(input, map[string]any{ + "type": "function_call_output", + "call_id": callID, + "output": output, + }) + default: + convertedContent := convertChatContent(msgMap["content"]) + toolCalls := []any(nil) + if role == "assistant" { + toolCalls = extractToolCallsFromMessage(msgMap) + } + skipAssistantMessage := role == "assistant" && len(toolCalls) > 0 && isEmptyContent(convertedContent) + if !skipAssistantMessage { + msgItem := map[string]any{ + "role": role, + "content": convertedContent, + } + if name := strings.TrimSpace(getString(msgMap["name"])); name != "" { + msgItem["name"] = name + } + input = append(input, msgItem) + } + if role == "assistant" && len(toolCalls) > 0 { + input = append(input, toolCalls...) + } + } + } + return input, nil +} + +func convertChatContent(content any) any { + switch v := content.(type) { + case nil: + return "" + case string: + return v + case []any: + converted := make([]any, 0, len(v)) + for _, part := range v { + partMap, ok := part.(map[string]any) + if !ok { + converted = append(converted, part) + continue + } + partType := strings.TrimSpace(getString(partMap["type"])) + switch partType { + case "text": + text := getString(partMap["text"]) + if text != "" { + converted = append(converted, map[string]any{ + "type": "input_text", + "text": text, + }) + continue + } + case "image_url": + imageURL := "" + if imageObj, ok := partMap["image_url"].(map[string]any); ok { + imageURL = getString(imageObj["url"]) + } else { + imageURL = getString(partMap["image_url"]) + } + if imageURL != "" { + converted = append(converted, map[string]any{ + "type": "input_image", + "image_url": imageURL, + }) + continue + } + case "input_text", "input_image": + converted = append(converted, partMap) + continue + } + converted = append(converted, partMap) + } + 
return converted + default: + return v + } +} + +func extractToolCallsFromMessage(msg map[string]any) []any { + var out []any + if toolCalls, ok := msg["tool_calls"].([]any); ok { + for _, call := range toolCalls { + callMap, ok := call.(map[string]any) + if !ok { + continue + } + callID := strings.TrimSpace(getString(callMap["id"])) + if callID == "" { + callID = strings.TrimSpace(getString(callMap["call_id"])) + } + name := "" + args := "" + if fn, ok := callMap["function"].(map[string]any); ok { + name = strings.TrimSpace(getString(fn["name"])) + args = getString(fn["arguments"]) + } + if name == "" && args == "" { + continue + } + item := map[string]any{ + "type": "tool_call", + } + if callID != "" { + item["call_id"] = callID + } + if name != "" { + item["name"] = name + } + if args != "" { + item["arguments"] = args + } + out = append(out, item) + } + } + + if fnCall, ok := msg["function_call"].(map[string]any); ok { + name := strings.TrimSpace(getString(fnCall["name"])) + args := getString(fnCall["arguments"]) + if name != "" || args != "" { + callID := strings.TrimSpace(getString(msg["tool_call_id"])) + if callID == "" { + callID = name + } + item := map[string]any{ + "type": "function_call", + } + if callID != "" { + item["call_id"] = callID + } + if name != "" { + item["name"] = name + } + if args != "" { + item["arguments"] = args + } + out = append(out, item) + } + } + + return out +} + +func extractMessageContentText(content any) string { + switch v := content.(type) { + case string: + return v + case []any: + parts := make([]string, 0, len(v)) + for _, part := range v { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + partType := strings.TrimSpace(getString(partMap["type"])) + if partType == "" || partType == "text" || partType == "output_text" || partType == "input_text" { + text := getString(partMap["text"]) + if text != "" { + parts = append(parts, text) + } + } + } + return strings.Join(parts, "") + default: + return "" + } +} + 
+func isEmptyContent(content any) bool { + switch v := content.(type) { + case nil: + return true + case string: + return strings.TrimSpace(v) == "" + case []any: + return len(v) == 0 + default: + return false + } +} + +func extractResponseTextAndToolCalls(resp map[string]any) (string, []any) { + output, ok := resp["output"].([]any) + if !ok { + if text, ok := resp["output_text"].(string); ok { + return text, nil + } + return "", nil + } + + textParts := make([]string, 0) + toolCalls := make([]any, 0) + + for _, item := range output { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType := strings.TrimSpace(getString(itemMap["type"])) + + if itemType == "tool_call" || itemType == "function_call" { + if tc := responseItemToChatToolCall(itemMap); tc != nil { + toolCalls = append(toolCalls, tc) + } + continue + } + + content := itemMap["content"] + switch v := content.(type) { + case string: + if v != "" { + textParts = append(textParts, v) + } + case []any: + for _, part := range v { + partMap, ok := part.(map[string]any) + if !ok { + continue + } + partType := strings.TrimSpace(getString(partMap["type"])) + switch partType { + case "output_text", "text", "input_text": + text := getString(partMap["text"]) + if text != "" { + textParts = append(textParts, text) + } + case "tool_call", "function_call": + if tc := responseItemToChatToolCall(partMap); tc != nil { + toolCalls = append(toolCalls, tc) + } + } + } + } + } + + return strings.Join(textParts, ""), toolCalls +} + +func responseItemToChatToolCall(item map[string]any) map[string]any { + callID := strings.TrimSpace(getString(item["call_id"])) + if callID == "" { + callID = strings.TrimSpace(getString(item["id"])) + } + name := strings.TrimSpace(getString(item["name"])) + arguments := getString(item["arguments"]) + if fn, ok := item["function"].(map[string]any); ok { + if name == "" { + name = strings.TrimSpace(getString(fn["name"])) + } + if arguments == "" { + arguments = 
getString(fn["arguments"]) + } + } + + if name == "" && arguments == "" && callID == "" { + return nil + } + + if callID == "" { + callID = "call_" + safeRandomHex(6) + } + + return map[string]any{ + "id": callID, + "type": "function", + "function": map[string]any{ + "name": name, + "arguments": arguments, + }, + } +} + +func extractResponseUsage(resp map[string]any) map[string]any { + usage, ok := resp["usage"].(map[string]any) + if !ok { + return nil + } + promptTokens := int(getNumber(usage["input_tokens"])) + completionTokens := int(getNumber(usage["output_tokens"])) + if promptTokens == 0 && completionTokens == 0 { + return nil + } + + return map[string]any{ + "prompt_tokens": promptTokens, + "completion_tokens": completionTokens, + "total_tokens": promptTokens + completionTokens, + } +} + +func getString(value any) string { + switch v := value.(type) { + case string: + return v + case []byte: + return string(v) + case json.Number: + return v.String() + default: + return "" + } +} + +func getNumber(value any) float64 { + switch v := value.(type) { + case float64: + return v + case float32: + return float64(v) + case int: + return float64(v) + case int64: + return float64(v) + case json.Number: + f, _ := v.Float64() + return f + default: + return 0 + } +} + +func getInt64(value any) int64 { + switch v := value.(type) { + case int64: + return v + case int: + return int64(v) + case float64: + return int64(v) + case json.Number: + i, _ := v.Int64() + return i + default: + return 0 + } +} + +func safeRandomHex(byteLength int) string { + value, err := randomHexString(byteLength) + if err != nil || value == "" { + return "000000" + } + return value +} diff --git a/backend/internal/service/openai_chat_completions_forward.go b/backend/internal/service/openai_chat_completions_forward.go new file mode 100644 index 000000000..703f3af17 --- /dev/null +++ b/backend/internal/service/openai_chat_completions_forward.go @@ -0,0 +1,486 @@ +package service + +import ( + "bufio" + 
"bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net/http" + "strings" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" +) + +type chatStreamingResult struct { + usage *OpenAIUsage + firstTokenMs *int +} + +func (s *OpenAIGatewayService) forwardChatCompletions(ctx context.Context, c *gin.Context, account *Account, body []byte, includeUsage bool, startTime time.Time) (*OpenAIForwardResult, error) { + // Parse request body once (avoid multiple parse/serialize cycles) + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + + reqModel, _ := reqBody["model"].(string) + reqStream, _ := reqBody["stream"].(bool) + originalModel := reqModel + + bodyModified := false + mappedModel := account.GetMappedModel(reqModel) + if mappedModel != reqModel { + log.Printf("[OpenAI Chat] Model mapping applied: %s -> %s (account: %s)", reqModel, mappedModel, account.Name) + reqBody["model"] = mappedModel + bodyModified = true + } + + if reqStream && includeUsage { + streamOptions, _ := reqBody["stream_options"].(map[string]any) + if streamOptions == nil { + streamOptions = map[string]any{} + } + if _, ok := streamOptions["include_usage"]; !ok { + streamOptions["include_usage"] = true + reqBody["stream_options"] = streamOptions + bodyModified = true + } + } + + if bodyModified { + var err error + body, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("serialize request body: %w", err) + } + } + + // Get access token + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + upstreamReq, err := s.buildChatCompletionsRequest(ctx, c, account, body, token) + if err != nil { + return nil, err + } + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + if c != nil { + c.Set(OpsUpstreamRequestBodyKey, 
string(body)) + } + + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + Kind: "request_error", + Message: safeErr, + }) + c.JSON(http.StatusBadGateway, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": "Upstream request failed", + }, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode >= 400 { + if s.shouldFailoverUpstreamError(resp.StatusCode) { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + upstreamDetail := "" + if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody { + maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes + if maxBytes <= 0 { + maxBytes = 2048 + } + upstreamDetail = truncateString(string(respBody), maxBytes) + } + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + Kind: "failover", + Message: upstreamMsg, + Detail: upstreamDetail, + }) + + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + } + return s.handleErrorResponse(ctx, resp, c, account, body) + } + + var usage *OpenAIUsage + var firstTokenMs *int + if reqStream { + streamResult, err := s.handleChatCompletionsStreamingResponse(ctx, resp, c, account, startTime, originalModel, mappedModel) + if err != nil { + return 
nil, err + } + usage = streamResult.usage + firstTokenMs = streamResult.firstTokenMs + } else { + usage, err = s.handleChatCompletionsNonStreamingResponse(resp, c, originalModel, mappedModel) + if err != nil { + return nil, err + } + } + + if usage == nil { + usage = &OpenAIUsage{} + } + + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: *usage, + Model: originalModel, + Stream: reqStream, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +func (s *OpenAIGatewayService) buildChatCompletionsRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token string) (*http.Request, error) { + var targetURL string + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiChatAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return nil, err + } + targetURL = validatedURL + "/chat/completions" + } + + req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body)) + if err != nil { + return nil, err + } + + req.Header.Set("authorization", "Bearer "+token) + + for key, values := range c.Request.Header { + lowerKey := strings.ToLower(key) + if openaiChatAllowedHeaders[lowerKey] { + for _, v := range values { + req.Header.Add(key, v) + } + } + } + + customUA := account.GetOpenAIUserAgent() + if customUA != "" { + req.Header.Set("user-agent", customUA) + } + + if req.Header.Get("content-type") == "" { + req.Header.Set("content-type", "application/json") + } + + return req, nil +} + +func (s *OpenAIGatewayService) handleChatCompletionsStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*chatStreamingResult, error) { + if s.cfg != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + } + + c.Header("Content-Type", "text/event-stream") + 
c.Header("Cache-Control", "no-cache")
+	c.Header("Connection", "keep-alive")
+	c.Header("X-Accel-Buffering", "no") // disable reverse-proxy buffering so SSE chunks flush immediately
+
+	if v := resp.Header.Get("x-request-id"); v != "" {
+		c.Header("x-request-id", v) // propagate the upstream request id to the client
+	}
+
+	w := c.Writer
+	flusher, ok := w.(http.Flusher)
+	if !ok {
+		return nil, errors.New("streaming not supported")
+	}
+
+	usage := &OpenAIUsage{}
+	var firstTokenMs *int
+
+	scanner := bufio.NewScanner(resp.Body)
+	maxLineSize := defaultMaxLineSize
+	if s.cfg != nil && s.cfg.Gateway.MaxLineSize > 0 {
+		maxLineSize = s.cfg.Gateway.MaxLineSize
+	}
+	scanner.Buffer(make([]byte, 64*1024), maxLineSize) // 64 KiB initial buffer, grows up to maxLineSize per line
+
+	type scanEvent struct { // one scanned SSE line, or a terminal read error
+		line string
+		err  error
+	}
+	events := make(chan scanEvent, 16)
+	done := make(chan struct{}) // closed on return so the reader goroutine cannot leak
+	sendEvent := func(ev scanEvent) bool {
+		select {
+		case events <- ev:
+			return true
+		case <-done:
+			return false
+		}
+	}
+	var lastReadAt int64 // unix nanos of the last upstream read, shared with the interval-timeout check
+	atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+	go func() { // reader goroutine: pumps scanner lines into events; closes events when the stream ends
+		defer close(events)
+		for scanner.Scan() {
+			atomic.StoreInt64(&lastReadAt, time.Now().UnixNano())
+			if !sendEvent(scanEvent{line: scanner.Text()}) {
+				return
+			}
+		}
+		if err := scanner.Err(); err != nil {
+			_ = sendEvent(scanEvent{err: err})
+		}
+	}()
+	defer close(done)
+
+	streamInterval := time.Duration(0) // zero disables the upstream-data interval timeout
+	if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
+		streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
+	}
+	var intervalTicker *time.Ticker
+	if streamInterval > 0 {
+		intervalTicker = time.NewTicker(streamInterval)
+		defer intervalTicker.Stop()
+	}
+	var intervalCh <-chan time.Time // nil channel blocks forever when the timeout is disabled
+	if intervalTicker != nil {
+		intervalCh = intervalTicker.C
+	}
+
+	keepaliveInterval := time.Duration(0) // zero disables client keepalive comments
+	if s.cfg != nil && s.cfg.Gateway.StreamKeepaliveInterval > 0 {
+		keepaliveInterval = time.Duration(s.cfg.Gateway.StreamKeepaliveInterval) * time.Second
+	}
+	var keepaliveTicker *time.Ticker
+	if keepaliveInterval > 0 {
+		keepaliveTicker = time.NewTicker(keepaliveInterval)
+		defer
keepaliveTicker.Stop() + } + var keepaliveCh <-chan time.Time + if keepaliveTicker != nil { + keepaliveCh = keepaliveTicker.C + } + lastDataAt := time.Now() + + errorEventSent := false + sendErrorEvent := func(reason string) { + if errorEventSent { + return + } + errorEventSent = true + _, _ = fmt.Fprintf(w, "event: error\ndata: {\"error\":\"%s\"}\n\n", reason) + flusher.Flush() + } + + needModelReplace := originalModel != mappedModel + + for { + select { + case ev, ok := <-events: + if !ok { + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + } + if ev.err != nil { + if errors.Is(ev.err, bufio.ErrTooLong) { + log.Printf("SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) + sendErrorEvent("response_too_large") + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err + } + sendErrorEvent("stream_read_error") + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) + } + + line := ev.line + lastDataAt = time.Now() + + if openaiSSEDataRe.MatchString(line) { + data := openaiSSEDataRe.ReplaceAllString(line, "") + + if needModelReplace { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { + line = "data: " + correctedData + } + + if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { + sendErrorEvent("write_failed") + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + + if firstTokenMs == nil { + if event := parseChatStreamEvent(data); event != nil { + if chatChunkHasDelta(event) { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + applyChatUsageFromEvent(event, usage) + } + } else { + if event := parseChatStreamEvent(data); event != nil { + applyChatUsageFromEvent(event, usage) + } + } + } else { + if _, err := fmt.Fprintf(w, "%s\n", line); 
err != nil { + sendErrorEvent("write_failed") + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + } + + case <-intervalCh: + lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) + if time.Since(lastRead) < streamInterval { + continue + } + log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) + if s.rateLimitService != nil { + s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) + } + sendErrorEvent("stream_timeout") + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + + case <-keepaliveCh: + if time.Since(lastDataAt) < keepaliveInterval { + continue + } + if _, err := fmt.Fprint(w, ":\n\n"); err != nil { + return &chatStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, err + } + flusher.Flush() + } + } +} + +func (s *OpenAIGatewayService) handleChatCompletionsNonStreamingResponse(resp *http.Response, c *gin.Context, originalModel, mappedModel string) (*OpenAIUsage, error) { + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + usage := &OpenAIUsage{} + var parsed map[string]any + if json.Unmarshal(body, &parsed) == nil { + if usageMap, ok := parsed["usage"].(map[string]any); ok { + applyChatUsageFromMap(usageMap, usage) + } + } + + if originalModel != mappedModel { + body = s.replaceModelInResponseBody(body, mappedModel, originalModel) + } + body = s.correctToolCallsInResponseBody(body) + + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + + contentType := "application/json" + if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { + if upstreamType := resp.Header.Get("Content-Type"); upstreamType != "" { + contentType = upstreamType + } + } + + c.Data(resp.StatusCode, contentType, body) + return usage, nil +} + +func parseChatStreamEvent(data string) map[string]any { + if data == 
"" || data == "[DONE]" {
+		return nil
+	}
+	var event map[string]any
+	if json.Unmarshal([]byte(data), &event) != nil {
+		return nil
+	}
+	return event
+}
+
+func chatChunkHasDelta(event map[string]any) bool { // reports whether any choice delta carries content, tool_calls, or function_call
+	choices, ok := event["choices"].([]any)
+	if !ok {
+		return false
+	}
+	for _, choice := range choices {
+		choiceMap, ok := choice.(map[string]any)
+		if !ok {
+			continue
+		}
+		delta, ok := choiceMap["delta"].(map[string]any)
+		if !ok {
+			continue
+		}
+		if content, ok := delta["content"].(string); ok && strings.TrimSpace(content) != "" { // whitespace-only content does not count
+			return true
+		}
+		if toolCalls, ok := delta["tool_calls"].([]any); ok && len(toolCalls) > 0 {
+			return true
+		}
+		if functionCall, ok := delta["function_call"].(map[string]any); ok && len(functionCall) > 0 {
+			return true
+		}
+	}
+	return false
+}
+
+func applyChatUsageFromEvent(event map[string]any, usage *OpenAIUsage) { // folds a chunk's optional "usage" object into usage; no-op otherwise
+	if event == nil || usage == nil {
+		return
+	}
+	usageMap, ok := event["usage"].(map[string]any)
+	if !ok {
+		return
+	}
+	applyChatUsageFromMap(usageMap, usage)
+}
+
+func applyChatUsageFromMap(usageMap map[string]any, usage *OpenAIUsage) { // copies positive prompt/completion token counts into usage; zeros never overwrite
+	if usageMap == nil || usage == nil {
+		return
+	}
+	promptTokens := int(getNumber(usageMap["prompt_tokens"]))
+	completionTokens := int(getNumber(usageMap["completion_tokens"]))
+	if promptTokens > 0 {
+		usage.InputTokens = promptTokens
+	}
+	if completionTokens > 0 {
+		usage.OutputTokens = completionTokens
+	}
+}
diff --git a/backend/internal/service/openai_chat_completions_test.go b/backend/internal/service/openai_chat_completions_test.go
new file mode 100644
index 000000000..635bda236
--- /dev/null
+++ b/backend/internal/service/openai_chat_completions_test.go
@@ -0,0 +1,132 @@
+package service
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestConvertChatCompletionsToResponses(t *testing.T) { // chat-completions request converts to Responses input, tools, and tool_choice
+	req := map[string]any{
+		"model": "gpt-4o",
+		"messages": []any{
+			map[string]any{
+				"role":    "user",
+				"content":
"hello",
+			},
+			map[string]any{
+				"role": "assistant",
+				"tool_calls": []any{
+					map[string]any{
+						"id":   "call_1",
+						"type": "function",
+						"function": map[string]any{
+							"name":      "ping",
+							"arguments": "{}",
+						},
+					},
+				},
+			},
+			map[string]any{
+				"role":          "tool",
+				"tool_call_id":  "call_1",
+				"content":       "ok",
+				"response":      "ignored",
+				"response_time": 1,
+			},
+		},
+		"functions": []any{
+			map[string]any{
+				"name":        "ping",
+				"description": "ping tool",
+				"parameters":  map[string]any{"type": "object"},
+			},
+		},
+		"function_call": map[string]any{"name": "ping"},
+	}
+
+	converted, err := ConvertChatCompletionsToResponses(req)
+	require.NoError(t, err)
+	require.Equal(t, "gpt-4o", converted["model"])
+
+	input, ok := converted["input"].([]any)
+	require.True(t, ok)
+	require.Len(t, input, 3) // user message, tool call, tool output
+
+	toolCall := findInputItemByType(input, "tool_call")
+	require.NotNil(t, toolCall)
+	require.Equal(t, "call_1", toolCall["call_id"])
+
+	toolOutput := findInputItemByType(input, "function_call_output")
+	require.NotNil(t, toolOutput)
+	require.Equal(t, "call_1", toolOutput["call_id"]) // output is linked to its tool call
+
+	tools, ok := converted["tools"].([]any)
+	require.True(t, ok)
+	require.Len(t, tools, 1)
+
+	require.Equal(t, map[string]any{"name": "ping"}, converted["tool_choice"])
+}
+
+func TestConvertResponsesToChatCompletion(t *testing.T) { // Responses output converts back to a chat.completion body
+	resp := map[string]any{
+		"id":         "resp_123",
+		"model":      "gpt-4o",
+		"created_at": 1700000000,
+		"output": []any{
+			map[string]any{
+				"type": "message",
+				"role": "assistant",
+				"content": []any{
+					map[string]any{
+						"type": "output_text",
+						"text": "hi",
+					},
+				},
+			},
+		},
+		"usage": map[string]any{
+			"input_tokens":  2,
+			"output_tokens": 3,
+		},
+	}
+	body, err := json.Marshal(resp)
+	require.NoError(t, err)
+
+	converted, err := ConvertResponsesToChatCompletion(body)
+	require.NoError(t, err)
+
+	var chat map[string]any
+	require.NoError(t, json.Unmarshal(converted, &chat))
+	require.Equal(t, "chat.completion", chat["object"])
+
+	choices, ok :=
chat["choices"].([]any)
+	require.True(t, ok)
+	require.Len(t, choices, 1)
+
+	choice, ok := choices[0].(map[string]any)
+	require.True(t, ok)
+	message, ok := choice["message"].(map[string]any)
+	require.True(t, ok)
+	require.Equal(t, "hi", message["content"])
+
+	usage, ok := chat["usage"].(map[string]any)
+	require.True(t, ok)
+	require.Equal(t, float64(2), usage["prompt_tokens"])
+	require.Equal(t, float64(3), usage["completion_tokens"])
+	require.Equal(t, float64(5), usage["total_tokens"]) // prompt + completion
+}
+
+func findInputItemByType(items []any, itemType string) map[string]any { // helper: first input item whose "type" matches itemType, nil if absent
+	for _, item := range items {
+		itemMap, ok := item.(map[string]any)
+		if !ok {
+			continue
+		}
+		if itemMap["type"] == itemType {
+			return itemMap
+		}
+	}
+	return nil
+}
diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go
index f26ce03f0..b47d2c4dd 100644
--- a/backend/internal/service/openai_gateway_service.go
+++ b/backend/internal/service/openai_gateway_service.go
@@ -11,6 +11,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"regexp"
 	"sort"
 	"strconv"
 	"strings"
@@ -33,6 +34,7 @@ const (
 	chatgptCodexURL = "https://chatgpt.com/backend-api/codex/responses"
 	// OpenAI Platform API for API Key accounts (fallback)
 	openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
+	openaiChatAPIURL     = "https://api.openai.com/v1/chat/completions"
 	openaiStickySessionTTL = time.Hour // 粘性会话TTL
 	codexCLIUserAgent = "codex_cli_rs/0.98.0"
 	// codex_cli_only 拒绝时单个请求头日志长度上限(字符)
@@ -42,6 +44,16 @@ (
 	OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
 )
 
+// OpenAIChatCompletionsBodyKey stores the original chat-completions payload in gin.Context.
+const OpenAIChatCompletionsBodyKey = "openai_chat_completions_body"
+
+// OpenAIChatCompletionsIncludeUsageKey stores stream_options.include_usage in gin.Context.
+const OpenAIChatCompletionsIncludeUsageKey = "openai_chat_completions_include_usage"
+
+// openaiSSEDataRe matches SSE data lines with optional whitespace after the colon.
+// Some upstream APIs return non-standard "data:" without a space (standard is "data: ").
+var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
+
 // OpenAI allowed headers whitelist (for non-passthrough).
 var openaiAllowedHeaders = map[string]bool{
 	"accept-language": true,
@@ -81,6 +93,19 @@ var codexCLIOnlyDebugHeaderWhitelist = []string{
 	"X-Real-IP",
 }
 
+// OpenAI chat-completions allowed headers (extends the responses whitelist).
+var openaiChatAllowedHeaders = map[string]bool{
+	"accept-language":     true,
+	"content-type":        true,
+	"conversation_id":     true,
+	"user-agent":          true,
+	"originator":          true,
+	"session_id":          true,
+	"openai-organization": true,
+	"openai-project":      true,
+	"openai-beta":         true,
+}
+
 // OpenAICodexUsageSnapshot represents Codex API usage limits from response headers
 type OpenAICodexUsageSnapshot struct {
 	PrimaryUsedPercent *float64 `json:"primary_used_percent,omitempty"`
@@ -1005,6 +1030,23 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
 		return nil, errors.New("codex_cli_only restriction: only codex official clients are allowed")
 	}
 
+	if c != nil && account != nil && account.Type == AccountTypeAPIKey { // chat-completions forwarding path applies to API-key accounts only
+		if raw, ok := c.Get(OpenAIChatCompletionsBodyKey); ok {
+			if rawBody, ok := raw.([]byte); ok && len(rawBody) > 0 {
+				includeUsage := false
+				if v, ok := c.Get(OpenAIChatCompletionsIncludeUsageKey); ok {
+					if flag, ok := v.(bool); ok {
+						includeUsage = flag
+					}
+				}
+				if passthroughWriter, ok := c.Writer.(interface{ SetPassthrough() }); ok {
+					passthroughWriter.SetPassthrough() // mark the wrapping response writer as passthrough when it supports it
+				}
+				return s.forwardChatCompletions(ctx, c, account, rawBody, includeUsage, startTime) // forward the preserved chat-completions payload directly
+			}
+		}
+	}
+
 	originalBody := body
 	reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
 	originalModel := reqModel