diff --git a/scripts/test_oai_proxy.py b/scripts/test_oai_proxy.py index 8981b7a..dcdf6a3 100644 --- a/scripts/test_oai_proxy.py +++ b/scripts/test_oai_proxy.py @@ -25,6 +25,9 @@ def test_basic(client: OpenAI, model: str) -> None: top_p=0.9, ) msg = resp.choices[0].message + reasoning = getattr(msg, "reasoning_content", None) + if reasoning: + print("reasoning_content:", reasoning[:200], "..." if len(reasoning) > 200 else "") print("content:", msg.content) @@ -85,10 +88,18 @@ def test_stream(client: OpenAI, model: str) -> None: stream=True, ) parts: List[str] = [] + reasoning_parts: List[str] = [] for chunk in stream: delta = chunk.choices[0].delta - if delta and delta.content: - parts.append(delta.content) + if delta: + if delta.content: + parts.append(delta.content) + reasoning = getattr(delta, "reasoning_content", None) + if reasoning: + reasoning_parts.append(reasoning) + if reasoning_parts: + full_reasoning = "".join(reasoning_parts) + print("reasoning_content:", full_reasoning[:200], "..." if len(full_reasoning) > 200 else "") print("content:", "".join(parts)) diff --git a/truffile/infer/proxy.py b/truffile/infer/proxy.py index 2d921d3..262d1f1 100644 --- a/truffile/infer/proxy.py +++ b/truffile/infer/proxy.py @@ -232,48 +232,99 @@ def _set_structural_tag(req: IRequest, structural_tag: Dict[str, Any]) -> bool: class _StreamFilter: + """Streaming filter that separates visible content, reasoning, and toolcall tags. + + ``feed()`` and ``finalize()`` return a ``(visible, reasoning)`` tuple so + that callers can emit ``delta.reasoning_content`` alongside ``delta.content`` + in OpenAI-compatible SSE chunks (matching DeepSeek / OpenAI convention). 
+ """ + def __init__(self, hide_cot: bool = False) -> None: self._buffer = "" self._mode = "normal" # normal | think | toolcall self._max_tag = max(len(""), len(""), len(""), len("")) - self._hide_cot = hide_cot self._passed_cot = not hide_cot - def finalize(self) -> str: - if self._mode != "normal": + self._cot_open_stripped = False # whether we've consumed the opening in phase 1 + + def finalize(self) -> Tuple[str, str]: + """Flush remaining buffer. Returns ``(visible, reasoning)``.""" + if not self._passed_cot: + # Stream ended mid-thinking (never saw ). + reasoning = self._buffer self._buffer = "" - return "" + return "", reasoning + if self._mode == "think": + reasoning = self._buffer + self._buffer = "" + return "", reasoning + if self._mode == "toolcall": + self._buffer = "" + return "", "" tail = self._buffer self._buffer = "" - return tail - def feed(self, chunk: str) -> str: + return tail, "" + + def feed(self, chunk: str) -> Tuple[str, str]: + """Process *chunk* and return ``(visible, reasoning)`` text.""" if not chunk: - return "" + return "", "" buf = self._buffer + chunk + reasoning_parts: List[str] = [] + + # Phase 1: skip initial CoT block for reasoner models, capturing it. if not self._passed_cot: + # Strip opening tag once, before emitting any reasoning. + if not self._cot_open_stripped: + tag_pos = buf.find("") + if tag_pos != -1: + buf = buf[tag_pos + len(""):] + buf = buf.lstrip("\n") # drop leading newline after + self._cot_open_stripped = True + else: + # Haven't seen the full opening tag yet — could be split. + # Keep buffering without emitting anything as reasoning. + keep = len("") - 1 + self._buffer = buf[-keep:] if keep > 0 else "" + return "", "" + end = buf.find("") if end == -1: - # Keep only enough to detect a split closing tag. keep = len("") - 1 + # Everything except the safety buffer is reasoning. 
+ if len(buf) > keep: + text = buf[:-keep] if keep > 0 else buf + reasoning_parts.append(text) self._buffer = buf[-keep:] if keep > 0 else "" - return "" - buf = buf[end + len("</think>") :] + return "", "".join(reasoning_parts) + # Capture everything before </think> as reasoning. + cot_text = buf[:end] + if cot_text: + reasoning_parts.append(cot_text) + buf = buf[end + len("</think>"):] self._passed_cot = True + + # Phase 2: state-machine pass over visible / think / toolcall segments. out_parts: List[str] = [] while buf: if self._mode == "think": end = buf.find("</think>") if end == -1: - self._buffer = buf[-(self._max_tag - 1):] - return "".join(out_parts) - buf = buf[end + len("</think>") :] + keep = self._max_tag - 1 + if len(buf) > keep: + reasoning_parts.append(buf[:-keep]) + self._buffer = buf[-keep:] + return "".join(out_parts), "".join(reasoning_parts) + reasoning_parts.append(buf[:end]) + buf = buf[end + len("</think>"):] self._mode = "normal" continue + if self._mode == "toolcall": end = buf.find("</toolcall>") if end == -1: self._buffer = buf[-(self._max_tag - 1):] - return "".join(out_parts) - buf = buf[end + len("</toolcall>") :] + return "".join(out_parts), "".join(reasoning_parts) + buf = buf[end + len("</toolcall>"):] self._mode = "normal" continue @@ -281,26 +332,26 @@ def feed(self, chunk: str) -> str: next_tool = buf.find("<toolcall>") if next_think == -1 and next_tool == -1: if len(buf) >= self._max_tag: - out_parts.append(buf[: -(self._max_tag - 1)]) - self._buffer = buf[-(self._max_tag - 1) :] + out_parts.append(buf[:-(self._max_tag - 1)]) + self._buffer = buf[-(self._max_tag - 1):] else: self._buffer = buf - return "".join(out_parts) + return "".join(out_parts), "".join(reasoning_parts) if next_think == -1 or (next_tool != -1 and next_tool < next_think): if next_tool > 0: out_parts.append(buf[:next_tool]) - buf = buf[next_tool + len("<toolcall>") :] + buf = buf[next_tool + len("<toolcall>"):] self._mode = "toolcall" continue if next_think > 0: out_parts.append(buf[:next_think]) - buf = buf[next_think + len("<think>") :] 
self._mode = "think" self._buffer = "" - return "".join(out_parts) + return "".join(out_parts), "".join(reasoning_parts) class OpenAIProxy: @@ -524,8 +575,13 @@ def do_POST(self) -> None: print(ir.content, end="", flush=True) if ir.HasField("finish_reason") and ir.finish_reason != FinishReason.FINISH_UNSPECIFIED: last_finish = ir.finish_reason - visible = filter_state.feed(ir.content) + visible, reasoning = filter_state.feed(ir.content) + delta: Dict[str, Any] = {} if visible: + delta["content"] = visible + if reasoning: + delta["reasoning_content"] = reasoning + if delta: if not self._send_sse( { "id": stream_id, @@ -535,15 +591,20 @@ def do_POST(self) -> None: "choices": [ { "index": 0, - "delta": {"content": visible}, + "delta": delta, "finish_reason": None, } ], } ): return - tail = filter_state.finalize() + tail, tail_reasoning = filter_state.finalize() + tail_delta: Dict[str, Any] = {} if tail: + tail_delta["content"] = tail + if tail_reasoning: + tail_delta["reasoning_content"] = tail_reasoning + if tail_delta: if not self._send_sse( { "id": stream_id, @@ -553,7 +614,7 @@ def do_POST(self) -> None: "choices": [ { "index": 0, - "delta": {"content": tail}, + "delta": tail_delta, "finish_reason": None, } ], @@ -633,6 +694,8 @@ def do_POST(self) -> None: ) msg: Dict[str, Any] = {"role": "assistant", "content": message} + if cot: + msg["reasoning_content"] = cot if openai_tool_calls: msg["tool_calls"] = openai_tool_calls if not message: