diff --git a/scripts/test_oai_proxy.py b/scripts/test_oai_proxy.py
index 8981b7a..dcdf6a3 100644
--- a/scripts/test_oai_proxy.py
+++ b/scripts/test_oai_proxy.py
@@ -25,6 +25,9 @@ def test_basic(client: OpenAI, model: str) -> None:
top_p=0.9,
)
msg = resp.choices[0].message
+ reasoning = getattr(msg, "reasoning_content", None)
+ if reasoning:
+ print("reasoning_content:", reasoning[:200], "..." if len(reasoning) > 200 else "")
print("content:", msg.content)
@@ -85,10 +88,18 @@ def test_stream(client: OpenAI, model: str) -> None:
stream=True,
)
parts: List[str] = []
+ reasoning_parts: List[str] = []
for chunk in stream:
delta = chunk.choices[0].delta
- if delta and delta.content:
- parts.append(delta.content)
+ if delta:
+ if delta.content:
+ parts.append(delta.content)
+ reasoning = getattr(delta, "reasoning_content", None)
+ if reasoning:
+ reasoning_parts.append(reasoning)
+ if reasoning_parts:
+ full_reasoning = "".join(reasoning_parts)
+ print("reasoning_content:", full_reasoning[:200], "..." if len(full_reasoning) > 200 else "")
print("content:", "".join(parts))
diff --git a/truffile/infer/proxy.py b/truffile/infer/proxy.py
index 2d921d3..262d1f1 100644
--- a/truffile/infer/proxy.py
+++ b/truffile/infer/proxy.py
@@ -232,48 +232,99 @@ def _set_structural_tag(req: IRequest, structural_tag: Dict[str, Any]) -> bool:
class _StreamFilter:
+ """Streaming filter that separates visible content, reasoning, and toolcall tags.
+
+ ``feed()`` and ``finalize()`` return a ``(visible, reasoning)`` tuple so
+ that callers can emit ``delta.reasoning_content`` alongside ``delta.content``
+ in OpenAI-compatible SSE chunks (matching DeepSeek / OpenAI convention).
+ """
+
def __init__(self, hide_cot: bool = False) -> None:
self._buffer = ""
self._mode = "normal" # normal | think | toolcall
self._max_tag = max(len("<think>"), len("</think>"), len("<tool_call>"), len("</tool_call>"))
- self._hide_cot = hide_cot
self._passed_cot = not hide_cot
- def finalize(self) -> str:
- if self._mode != "normal":
+ self._cot_open_stripped = False # whether we've consumed the opening <think> in phase 1
+
+ def finalize(self) -> Tuple[str, str]:
+ """Flush remaining buffer. Returns ``(visible, reasoning)``."""
+ if not self._passed_cot:
+ # Stream ended mid-thinking (never saw </think>).
+ reasoning = self._buffer
self._buffer = ""
- return ""
+ return "", reasoning
+ if self._mode == "think":
+ reasoning = self._buffer
+ self._buffer = ""
+ return "", reasoning
+ if self._mode == "toolcall":
+ self._buffer = ""
+ return "", ""
tail = self._buffer
self._buffer = ""
- return tail
- def feed(self, chunk: str) -> str:
+ return tail, ""
+
+ def feed(self, chunk: str) -> Tuple[str, str]:
+ """Process *chunk* and return ``(visible, reasoning)`` text."""
if not chunk:
- return ""
+ return "", ""
buf = self._buffer + chunk
+ reasoning_parts: List[str] = []
+
+ # Phase 1: skip initial CoT block for reasoner models, capturing it.
if not self._passed_cot:
+ # Strip opening <think> tag once, before emitting any reasoning.
+ if not self._cot_open_stripped:
+ tag_pos = buf.find("<think>")
+ if tag_pos != -1:
+ buf = buf[tag_pos + len("<think>"):]
+ buf = buf.lstrip("\n") # drop leading newline after <think>
+ self._cot_open_stripped = True
+ else:
+ # Haven't seen the full opening tag yet — could be split.
+ # Keep buffering without emitting anything as reasoning.
+ keep = len("<think>") - 1
+ self._buffer = buf[-keep:] if keep > 0 else ""
+ return "", ""
+
end = buf.find("</think>")
if end == -1:
- # Keep only enough to detect a split closing tag.
keep = len("</think>") - 1
+ # Everything except the safety buffer is reasoning.
+ if len(buf) > keep:
+ text = buf[:-keep] if keep > 0 else buf
+ reasoning_parts.append(text)
self._buffer = buf[-keep:] if keep > 0 else ""
- return ""
- buf = buf[end + len("</think>") :]
+ return "", "".join(reasoning_parts)
+ # Capture everything before </think> as reasoning.
+ cot_text = buf[:end]
+ if cot_text:
+ reasoning_parts.append(cot_text)
+ buf = buf[end + len("</think>"):]
self._passed_cot = True
+
+ # Phase 2: state-machine pass over visible / think / toolcall segments.
out_parts: List[str] = []
while buf:
if self._mode == "think":
end = buf.find("</think>")
if end == -1:
- self._buffer = buf[-(self._max_tag - 1):]
- return "".join(out_parts)
- buf = buf[end + len("</think>") :]
+ keep = self._max_tag - 1
+ if len(buf) > keep:
+ reasoning_parts.append(buf[:-keep])
+ self._buffer = buf[-keep:]
+ return "".join(out_parts), "".join(reasoning_parts)
+ reasoning_parts.append(buf[:end])
+ buf = buf[end + len("</think>"):]
self._mode = "normal"
continue
+
if self._mode == "toolcall":
end = buf.find("</tool_call>")
if end == -1:
self._buffer = buf[-(self._max_tag - 1):]
- return "".join(out_parts)
- buf = buf[end + len("</tool_call>") :]
+ return "".join(out_parts), "".join(reasoning_parts)
+ buf = buf[end + len("</tool_call>"):]
self._mode = "normal"
continue
@@ -281,26 +332,26 @@ def feed(self, chunk: str) -> str:
next_tool = buf.find("<tool_call>")
if next_think == -1 and next_tool == -1:
if len(buf) >= self._max_tag:
- out_parts.append(buf[: -(self._max_tag - 1)])
- self._buffer = buf[-(self._max_tag - 1) :]
+ out_parts.append(buf[:-(self._max_tag - 1)])
+ self._buffer = buf[-(self._max_tag - 1):]
else:
self._buffer = buf
- return "".join(out_parts)
+ return "".join(out_parts), "".join(reasoning_parts)
if next_think == -1 or (next_tool != -1 and next_tool < next_think):
if next_tool > 0:
out_parts.append(buf[:next_tool])
- buf = buf[next_tool + len("<tool_call>") :]
+ buf = buf[next_tool + len("<tool_call>"):]
self._mode = "toolcall"
continue
if next_think > 0:
out_parts.append(buf[:next_think])
- buf = buf[next_think + len("<think>") :]
+ buf = buf[next_think + len("<think>"):]
self._mode = "think"
self._buffer = ""
- return "".join(out_parts)
+ return "".join(out_parts), "".join(reasoning_parts)
class OpenAIProxy:
@@ -524,8 +575,13 @@ def do_POST(self) -> None:
print(ir.content, end="", flush=True)
if ir.HasField("finish_reason") and ir.finish_reason != FinishReason.FINISH_UNSPECIFIED:
last_finish = ir.finish_reason
- visible = filter_state.feed(ir.content)
+ visible, reasoning = filter_state.feed(ir.content)
+ delta: Dict[str, Any] = {}
if visible:
+ delta["content"] = visible
+ if reasoning:
+ delta["reasoning_content"] = reasoning
+ if delta:
if not self._send_sse(
{
"id": stream_id,
@@ -535,15 +591,20 @@ def do_POST(self) -> None:
"choices": [
{
"index": 0,
- "delta": {"content": visible},
+ "delta": delta,
"finish_reason": None,
}
],
}
):
return
- tail = filter_state.finalize()
+ tail, tail_reasoning = filter_state.finalize()
+ tail_delta: Dict[str, Any] = {}
if tail:
+ tail_delta["content"] = tail
+ if tail_reasoning:
+ tail_delta["reasoning_content"] = tail_reasoning
+ if tail_delta:
if not self._send_sse(
{
"id": stream_id,
@@ -553,7 +614,7 @@ def do_POST(self) -> None:
"choices": [
{
"index": 0,
- "delta": {"content": tail},
+ "delta": tail_delta,
"finish_reason": None,
}
],
@@ -633,6 +694,8 @@ def do_POST(self) -> None:
)
msg: Dict[str, Any] = {"role": "assistant", "content": message}
+ if cot:
+ msg["reasoning_content"] = cot
if openai_tool_calls:
msg["tool_calls"] = openai_tool_calls
if not message: