Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions scripts/test_oai_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def test_basic(client: OpenAI, model: str) -> None:
top_p=0.9,
)
msg = resp.choices[0].message
reasoning = getattr(msg, "reasoning_content", None)
if reasoning:
print("reasoning_content:", reasoning[:200], "..." if len(reasoning) > 200 else "")
print("content:", msg.content)


Expand Down Expand Up @@ -85,10 +88,18 @@ def test_stream(client: OpenAI, model: str) -> None:
stream=True,
)
parts: List[str] = []
reasoning_parts: List[str] = []
for chunk in stream:
delta = chunk.choices[0].delta
if delta and delta.content:
parts.append(delta.content)
if delta:
if delta.content:
parts.append(delta.content)
reasoning = getattr(delta, "reasoning_content", None)
if reasoning:
reasoning_parts.append(reasoning)
if reasoning_parts:
full_reasoning = "".join(reasoning_parts)
print("reasoning_content:", full_reasoning[:200], "..." if len(full_reasoning) > 200 else "")
print("content:", "".join(parts))


Expand Down
113 changes: 88 additions & 25 deletions truffile/infer/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,75 +232,126 @@ def _set_structural_tag(req: IRequest, structural_tag: Dict[str, Any]) -> bool:


class _StreamFilter:
"""Streaming filter that separates visible content, reasoning, and toolcall tags.

``feed()`` and ``finalize()`` return a ``(visible, reasoning)`` tuple so
that callers can emit ``delta.reasoning_content`` alongside ``delta.content``
in OpenAI-compatible SSE chunks (matching DeepSeek / OpenAI convention).
"""

def __init__(self, hide_cot: bool = False) -> None:
self._buffer = ""
self._mode = "normal" # normal | think | toolcall
self._max_tag = max(len("<think>"), len("</think>"), len("<toolcall>"), len("</toolcall>"))
self._hide_cot = hide_cot
self._passed_cot = not hide_cot
def finalize(self) -> str:
if self._mode != "normal":
self._cot_open_stripped = False # whether we've consumed the opening <think> in phase 1

def finalize(self) -> Tuple[str, str]:
"""Flush remaining buffer. Returns ``(visible, reasoning)``."""
if not self._passed_cot:
# Stream ended mid-thinking (never saw </think>).
reasoning = self._buffer
self._buffer = ""
return ""
return "", reasoning
if self._mode == "think":
reasoning = self._buffer
self._buffer = ""
return "", reasoning
if self._mode == "toolcall":
self._buffer = ""
return "", ""
tail = self._buffer
self._buffer = ""
return tail
def feed(self, chunk: str) -> str:
return tail, ""

def feed(self, chunk: str) -> Tuple[str, str]:
"""Process *chunk* and return ``(visible, reasoning)`` text."""
if not chunk:
return ""
return "", ""
buf = self._buffer + chunk
reasoning_parts: List[str] = []

# Phase 1: skip initial CoT block for reasoner models, capturing it.
if not self._passed_cot:
# Strip opening <think> tag once, before emitting any reasoning.
if not self._cot_open_stripped:
tag_pos = buf.find("<think>")
if tag_pos != -1:
buf = buf[tag_pos + len("<think>"):]
buf = buf.lstrip("\n") # drop leading newline after <think>
self._cot_open_stripped = True
else:
# Haven't seen the full opening tag yet — could be split.
# Keep buffering without emitting anything as reasoning.
keep = len("<think>") - 1
self._buffer = buf[-keep:] if keep > 0 else ""
return "", ""

end = buf.find("</think>")
if end == -1:
# Keep only enough to detect a split closing tag.
keep = len("</think>") - 1
# Everything except the safety buffer is reasoning.
if len(buf) > keep:
text = buf[:-keep] if keep > 0 else buf
reasoning_parts.append(text)
self._buffer = buf[-keep:] if keep > 0 else ""
return ""
buf = buf[end + len("</think>") :]
return "", "".join(reasoning_parts)
# Capture everything before </think> as reasoning.
cot_text = buf[:end]
if cot_text:
reasoning_parts.append(cot_text)
buf = buf[end + len("</think>"):]
self._passed_cot = True

# Phase 2: state-machine pass over visible / think / toolcall segments.
out_parts: List[str] = []
while buf:
if self._mode == "think":
end = buf.find("</think>")
if end == -1:
self._buffer = buf[-(self._max_tag - 1):]
return "".join(out_parts)
buf = buf[end + len("</think>") :]
keep = self._max_tag - 1
if len(buf) > keep:
reasoning_parts.append(buf[:-keep])
self._buffer = buf[-keep:]
return "".join(out_parts), "".join(reasoning_parts)
reasoning_parts.append(buf[:end])
buf = buf[end + len("</think>"):]
self._mode = "normal"
continue

if self._mode == "toolcall":
end = buf.find("</toolcall>")
if end == -1:
self._buffer = buf[-(self._max_tag - 1):]
return "".join(out_parts)
buf = buf[end + len("</toolcall>") :]
return "".join(out_parts), "".join(reasoning_parts)
buf = buf[end + len("</toolcall>"):]
self._mode = "normal"
continue

next_think = buf.find("<think>")
next_tool = buf.find("<toolcall>")
if next_think == -1 and next_tool == -1:
if len(buf) >= self._max_tag:
out_parts.append(buf[: -(self._max_tag - 1)])
self._buffer = buf[-(self._max_tag - 1) :]
out_parts.append(buf[:-(self._max_tag - 1)])
self._buffer = buf[-(self._max_tag - 1):]
else:
self._buffer = buf
return "".join(out_parts)
return "".join(out_parts), "".join(reasoning_parts)

if next_think == -1 or (next_tool != -1 and next_tool < next_think):
if next_tool > 0:
out_parts.append(buf[:next_tool])
buf = buf[next_tool + len("<toolcall>") :]
buf = buf[next_tool + len("<toolcall>"):]
self._mode = "toolcall"
continue

if next_think > 0:
out_parts.append(buf[:next_think])
buf = buf[next_think + len("<think>") :]
buf = buf[next_think + len("<think>"):]
self._mode = "think"

self._buffer = ""
return "".join(out_parts)
return "".join(out_parts), "".join(reasoning_parts)


class OpenAIProxy:
Expand Down Expand Up @@ -524,8 +575,13 @@ def do_POST(self) -> None:
print(ir.content, end="", flush=True)
if ir.HasField("finish_reason") and ir.finish_reason != FinishReason.FINISH_UNSPECIFIED:
last_finish = ir.finish_reason
visible = filter_state.feed(ir.content)
visible, reasoning = filter_state.feed(ir.content)
delta: Dict[str, Any] = {}
if visible:
delta["content"] = visible
if reasoning:
delta["reasoning_content"] = reasoning
if delta:
if not self._send_sse(
{
"id": stream_id,
Expand All @@ -535,15 +591,20 @@ def do_POST(self) -> None:
"choices": [
{
"index": 0,
"delta": {"content": visible},
"delta": delta,
"finish_reason": None,
}
],
}
):
return
tail = filter_state.finalize()
tail, tail_reasoning = filter_state.finalize()
tail_delta: Dict[str, Any] = {}
if tail:
tail_delta["content"] = tail
if tail_reasoning:
tail_delta["reasoning_content"] = tail_reasoning
if tail_delta:
if not self._send_sse(
{
"id": stream_id,
Expand All @@ -553,7 +614,7 @@ def do_POST(self) -> None:
"choices": [
{
"index": 0,
"delta": {"content": tail},
"delta": tail_delta,
"finish_reason": None,
}
],
Expand Down Expand Up @@ -633,6 +694,8 @@ def do_POST(self) -> None:
)

msg: Dict[str, Any] = {"role": "assistant", "content": message}
if cot:
msg["reasoning_content"] = cot
if openai_tool_calls:
msg["tool_calls"] = openai_tool_calls
if not message:
Expand Down