diff --git a/truffile/infer/prompts.py b/truffile/infer/prompts.py
index 5d8c04b..56daf62 100644
--- a/truffile/infer/prompts.py
+++ b/truffile/infer/prompts.py
@@ -51,7 +51,12 @@ def get_tag_for_tool(tool: Tool) -> dict:
             "stop_after_first": not allow_parallel,
         },
     }
-    req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
+    try:
+        fmt = ResponseFormat.STRUCTURAL_TAG
+    except AttributeError:
+        # older proto or server; fall back to prompt-only tool guidance.
+        return
+    req.cfg.response_format.format = fmt
     req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
 
 
@@ -66,7 +71,6 @@ def get_tag_for_tool(tool: Tool) -> dict:
         "content": {"type": "json_schema", "json_schema": tool.input_schema},
         "end": end,
     }
-
     structural_tag = {
         "type": "structural_tag",
         "format": {
@@ -87,5 +91,10 @@ def get_tag_for_tool(tool: Tool) -> dict:
             ],
         },
     }
-    req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
+    try:
+        fmt = ResponseFormat.STRUCTURAL_TAG
+    except AttributeError:
+        # older proto or server; fall back to prompt-only tool guidance.
+        return
+    req.cfg.response_format.format = fmt
     req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
diff --git a/truffile/infer/proxy.py b/truffile/infer/proxy.py
index e003af2..2d921d3 100644
--- a/truffile/infer/proxy.py
+++ b/truffile/infer/proxy.py
@@ -8,6 +8,7 @@
 import threading
 import time
 import uuid
+import os
 
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from typing import Any, Dict, List, Optional, Tuple
@@ -219,6 +220,17 @@ def _usage_to_openai(usage: Any) -> Dict[str, int]:
     }
 
 
+def _set_structural_tag(req: IRequest, structural_tag: Dict[str, Any]) -> bool:
+    try:
+        fmt = ResponseFormat.STRUCTURAL_TAG
+    except AttributeError:
+        # older proto or server; fall back to prompt-only constraints.
+        return False
+    req.cfg.response_format.format = fmt
+    req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
+    return True
+
+
 class _StreamFilter:
     def __init__(self, hide_cot: bool = False) -> None:
         self._buffer = ""
@@ -370,8 +382,7 @@ def build_request(self, payload: Dict[str, Any]) -> Tuple[IRequest, Model, bool,
                     ],
                 },
             }
-            req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
-            req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
+            _set_structural_tag(req, structural_tag)
         else:
             req.cfg.response_format.format = ResponseFormat.JSON
             req.cfg.response_format.schema = json.dumps(schema)
@@ -464,6 +475,7 @@ def do_POST(self) -> None:
         try:
             payload = self._read_body()
         except Exception as e:
+            print(f"\tError reading request body: {e}")
             self._send_json(400, {"error": {"message": str(e), "type": "invalid_request_error"}})
             return
 
@@ -472,6 +484,7 @@ def do_POST(self) -> None:
         try:
             req, model, is_reasoner, _tools, stream = proxy.build_request(payload)
         except Exception as e:
+            print(f"\tError building request: {e}")
             self._send_json(400, {"error": {"message": str(e), "type": "invalid_request_error"}})
             return
 
@@ -502,9 +515,13 @@ def do_POST(self) -> None:
             raw_content = ""
             last_finish = None
             filter_state = _StreamFilter(hide_cot=is_reasoner)
-
+            log_output = os.getenv("TRUFFLE_PROXY_LOG_STREAM_OUTPUT", "0") == "1"
+            if log_output:
+                print("Streaming output:")
             for ir in proxy.run_stream(req):
                 raw_content += ir.content
+                if log_output:
+                    print(ir.content, end="", flush=True)
                 if ir.HasField("finish_reason") and ir.finish_reason != FinishReason.FINISH_UNSPECIFIED:
                     last_finish = ir.finish_reason
                 visible = filter_state.feed(ir.content)
@@ -554,6 +571,7 @@ def do_POST(self) -> None:
                 {
                     "id": f"call_{i+1}",
                     "type": "function",
+                    "index": i,
                     "function": {"name": name, "arguments": args},
                 }
             )