diff --git a/pyproject.toml b/pyproject.toml
index 846b92e..7cb2321 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
[project.scripts]
truffile = "truffile.cli:main"
+truffleinferproxy = "truffile.infer.proxy:main"
[project.optional-dependencies]
dev = [
diff --git a/scripts/test_oai_proxy.py b/scripts/test_oai_proxy.py
new file mode 100644
index 0000000..8981b7a
--- /dev/null
+++ b/scripts/test_oai_proxy.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+"""Smoke test for the local OpenAI-compatible proxy."""
+
+from __future__ import annotations
+
+import argparse
+import os
+from typing import Any, Dict, List
+try:
+ from openai import OpenAI
+except ImportError:
+ raise ImportError("Please install the 'openai' package to run this test script.")
+
+def _print_header(title: str) -> None:
+ print("\n" + "=" * 8 + f" {title} " + "=" * 8)
+
+
+def test_basic(client: OpenAI, model: str) -> None:
+ _print_header("basic")
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": "Say hello in one sentence."}],
+ max_tokens=2048,
+ temperature=0.7,
+ top_p=0.9,
+ )
+ msg = resp.choices[0].message
+ print("content:", msg.content)
+
+
+def test_json_schema(client: OpenAI, model: str) -> None:
+ _print_header("json_schema")
+ schema: Dict[str, Any] = {
+ "type": "object",
+ "properties": {
+ "answer": {"type": "string"},
+ "confidence": {"type": "number"},
+ },
+ "required": ["answer", "confidence"],
+ }
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": "What is 2+2? Respond as JSON."}],
+ response_format={"type": "json_schema", "json_schema": schema},
+ max_tokens=2048,
+ )
+ msg = resp.choices[0].message
+ print("content:", msg.content)
+
+
+def test_tools(client: OpenAI, model: str) -> None:
+ _print_header("tools")
+ tools: List[Dict[str, Any]] = [
+ {
+ "type": "function",
+ "function": {
+ "name": "get_time",
+ "description": "Return the current time in ISO-8601",
+ "parameters": {
+ "type": "object",
+ "properties": {"tz": {"type": "string"}},
+ "required": [],
+ },
+ },
+ }
+ ]
+ resp = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": "What time is it? Use the tool."}],
+ tools=tools,
+ tool_choice="auto",
+ max_tokens=2048,
+ )
+ msg = resp.choices[0].message
+ print("tool_calls:", msg.tool_calls)
+ print("content:", msg.content)
+
+
+def test_stream(client: OpenAI, model: str) -> None:
+ _print_header("stream")
+ stream = client.chat.completions.create(
+ model=model,
+ messages=[{"role": "user", "content": "Stream a short haiku."}],
+ max_tokens=2048,
+ stream=True,
+ )
+ parts: List[str] = []
+ for chunk in stream:
+ delta = chunk.choices[0].delta
+ if delta and delta.content:
+ parts.append(delta.content)
+ print("content:", "".join(parts))
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Smoke test for OpenAI proxy")
+ parser.add_argument("--base-url", default="http://127.0.0.1:8080/v1", help="Proxy base URL")
+ parser.add_argument("--model", default="auto", help="Model name or UUID")
+ parser.add_argument("--no-stream", action="store_true", help="Skip streaming test")
+ args = parser.parse_args()
+
+ api_key = os.getenv("OPENAI_API_KEY", "test")
+ client = OpenAI(base_url=args.base_url, api_key=api_key)
+
+ test_basic(client, args.model)
+ test_json_schema(client, args.model)
+ test_tools(client, args.model)
+ if not args.no_stream:
+ test_stream(client, args.model)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/truffile/infer/README.md b/truffile/infer/README.md
new file mode 100644
index 0000000..f267b20
--- /dev/null
+++ b/truffile/infer/README.md
@@ -0,0 +1,18 @@
+# Accessing Inference APIs on the Truffle
+
+The Truffle currently uses its own non-standard set of APIs for inference.
+
+Provided here is a proxy that both demonstrates the usage of these APIs and allows for easier compatibility with existing clients.
+
+This is experimental and may not be fully API compatible, but should serve as a good starting point for exploring the Truffle while core software improves.
+
+### Usage
+
+```bash
+truffleinferproxy --truffle truffle-5970 --host 127.0.0.1 --port 8080
+
+truffleinferproxy --help
+```
+
+
+
diff --git a/truffile/infer/__init__.py b/truffile/infer/__init__.py
new file mode 100644
index 0000000..d9b5eaa
--- /dev/null
+++ b/truffile/infer/__init__.py
@@ -0,0 +1 @@
+"""Standalone OpenAI-compatible proxy for Truffle gRPC inference."""
diff --git a/truffile/infer/common.py b/truffile/infer/common.py
new file mode 100644
index 0000000..53537dd
--- /dev/null
+++ b/truffile/infer/common.py
@@ -0,0 +1,7 @@
+from __future__ import annotations
+
+THINK_TAGS = ["", ""]
+
+
+def clean_response(response: str) -> str:
+ return response.strip().replace("�", "")
diff --git a/truffile/infer/prompts.py b/truffile/infer/prompts.py
new file mode 100644
index 0000000..5d8c04b
--- /dev/null
+++ b/truffile/infer/prompts.py
@@ -0,0 +1,91 @@
+from __future__ import annotations
+
+from typing import List, Tuple
+import json
+import re
+
+from truffle.infer.gencfg_pb2 import ResponseFormat
+
+from .common import THINK_TAGS
+from .tooling import Tool
+
+
+TOOL_TAGS = ["", ""]
+tool_tag_pattern = re.compile(f"{TOOL_TAGS[0]}(.*?){TOOL_TAGS[1]}", re.DOTALL)
+
+
+class AgentPromptBuilder:
+ def extract_tool_calls(self, response: str) -> Tuple[List[dict], str]:
+ tool_calls: List[dict] = []
+ matches = tool_tag_pattern.findall(response)
+ if not matches:
+ return tool_calls, response
+ for match in matches:
+ try:
+ tool_call = json.loads(match.strip())
+ tool_calls.append(tool_call)
+ except json.JSONDecodeError:
+ continue
+ clean_response = tool_tag_pattern.sub("", response).strip()
+ return tool_calls, clean_response
+
+
+def _build_tool_call_response_format_non_reasoning(
+ req, available_tools: List[Tool], allow_parallel: bool = False
+) -> None:
+ def get_tag_for_tool(tool: Tool) -> dict:
+ begin = f"{TOOL_TAGS[0]}\n" + '{"tool": ' + f'"{tool.name}", "args": '
+ end = "}" + f"{TOOL_TAGS[1]}\n"
+ return {
+ "begin": begin,
+ "content": {"type": "json_schema", "json_schema": tool.input_schema},
+ "end": end,
+ }
+
+ structural_tag = {
+ "type": "structural_tag",
+ "format": {
+ "type": "triggered_tags",
+ "triggers": [TOOL_TAGS[0]],
+ "tags": [get_tag_for_tool(tool) for tool in available_tools],
+ "stop_after_first": not allow_parallel,
+ },
+ }
+ req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
+ req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
+
+
+def _build_tool_call_response_format(
+ req, available_tools: List[Tool], allow_parallel: bool = False
+) -> None:
+ def get_tag_for_tool(tool: Tool) -> dict:
+ begin = f"{TOOL_TAGS[0]}\n" + '{"tool": ' + f'"{tool.name}", "args": '
+ end = "}" + f"{TOOL_TAGS[1]}\n"
+ return {
+ "begin": begin,
+ "content": {"type": "json_schema", "json_schema": tool.input_schema},
+ "end": end,
+ }
+
+ structural_tag = {
+ "type": "structural_tag",
+ "format": {
+ "type": "sequence",
+ "elements": [
+ {
+ "type": "tag",
+ "begin": "",
+ "content": {"type": "any_text"},
+ "end": THINK_TAGS[1],
+ },
+ {
+ "type": "triggered_tags",
+ "triggers": [TOOL_TAGS[0]],
+ "tags": [get_tag_for_tool(tool) for tool in available_tools],
+ "stop_after_first": not allow_parallel,
+ },
+ ],
+ },
+ }
+ req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
+ req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
diff --git a/truffile/infer/proxy.py b/truffile/infer/proxy.py
new file mode 100644
index 0000000..e003af2
--- /dev/null
+++ b/truffile/infer/proxy.py
@@ -0,0 +1,675 @@
+#!/usr/bin/env python3
+"""Minimal OpenAI-compatible /v1/chat/completions proxy for Truffle gRPC inference."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import threading
+import time
+import uuid
+from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
+from typing import Any, Dict, List, Optional, Tuple
+
+import grpc
+
+from .common import THINK_TAGS, clean_response
+from .prompts import (
+ TOOL_TAGS,
+ AgentPromptBuilder,
+ _build_tool_call_response_format,
+ _build_tool_call_response_format_non_reasoning,
+)
+from .tooling import Tool
+
+from truffle.infer.convo.conversation_pb2 import Conversation, Message
+from truffle.infer.finishreason_pb2 import FinishReason
+from truffle.infer.gencfg_pb2 import ResponseFormat
+from truffle.infer.irequest_pb2 import IRequest
+from truffle.infer.infer_pb2_grpc import InferenceServiceStub
+from truffle.infer.model_pb2 import GetModelListRequest, Model
+
+
+_MODEL_LOCK = threading.Lock()
+_MODEL_CACHE: Dict[str, Model] = {}
+_MODEL_LIST: List[Model] = []
+
+
+def _now_ts() -> int:
+ return int(time.time())
+
+
+def _gen_id(prefix: str) -> str:
+ return f"{prefix}-{uuid.uuid4().hex}"
+
+
+def _load_models(stub: InferenceServiceStub) -> None:
+ global _MODEL_CACHE, _MODEL_LIST
+ model_list = stub.GetModelList(GetModelListRequest(use_filter=False))
+ models = [m for m in model_list.models if m.state == Model.MODEL_STATE_LOADED]
+ cache: Dict[str, Model] = {}
+ for m in models:
+ cache[m.uuid] = m
+ cache[m.name.lower()] = m
+ _MODEL_LIST = models
+ _MODEL_CACHE = cache
+
+
+def _get_models(stub: InferenceServiceStub) -> List[Model]:
+ with _MODEL_LOCK:
+ if not _MODEL_LIST:
+ _load_models(stub)
+ return list(_MODEL_LIST)
+
+
+def _resolve_model(stub: InferenceServiceStub, model_str: Optional[str]) -> Tuple[Model, bool]:
+ models = _get_models(stub)
+ model_key = (model_str or "").strip()
+ if model_key and model_key.lower() not in {"auto", "default"}:
+ with _MODEL_LOCK:
+ m = _MODEL_CACHE.get(model_key) or _MODEL_CACHE.get(model_key.lower())
+ if m is not None:
+ return m, bool(m.config.info.has_chain_of_thought)
+ for m in models:
+ if m.config.info.has_chain_of_thought:
+ return m, True
+ if not models:
+ raise RuntimeError("No loaded models available")
+ return models[0], bool(models[0].config.info.has_chain_of_thought)
+
+
+def _flatten_content(content: Any) -> str:
+ if content is None:
+ return ""
+ if isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ parts: List[str] = []
+ for p in content:
+ if isinstance(p, dict) and p.get("type") == "text":
+ parts.append(p.get("text") or "")
+ return "".join(parts)
+ return str(content)
+
+
+def _build_tool_list(tools_spec: List[Dict[str, Any]]) -> List[Tool]:
+ tools: List[Tool] = []
+ for t in tools_spec:
+ if t.get("type") != "function":
+ continue
+ fn = t.get("function", {})
+ name = fn.get("name")
+ if not name:
+ continue
+ tools.append(
+ Tool(
+ name=name,
+ description=fn.get("description") or "",
+ input_schema=fn.get("parameters") or {"type": "object"},
+ display_name=name,
+ )
+ )
+ return tools
+
+
+def _tool_system_prompt(tools: List[Tool]) -> str:
+ tool_desc = "\n".join([t.get_for_system_prompt() for t in tools])
+ return (
+ "You have access to the following tools:\n"
+ f"{tool_desc}\n"
+ f"When you decide to use a tool, respond with a JSON object enclosed by {TOOL_TAGS[0]} and {TOOL_TAGS[1]} tags in this format:\n"
+ f"{TOOL_TAGS[0]}\n{{\n \"tool\": \"\",\n \"args\": {{}}\n}}\n{TOOL_TAGS[1]}\n"
+ "Only use tools listed above, and ensure your JSON is valid."
+ )
+
+
+def _apply_tool_prompt(messages: List[Dict[str, Any]], prompt: str) -> None:
+ for msg in messages:
+ if msg.get("role") == "system":
+ content = _flatten_content(msg.get("content"))
+ msg["content"] = content + "\n\n" + prompt
+ return
+ messages.insert(0, {"role": "system", "content": prompt})
+
+
+def _serialize_tool_calls(tool_calls: List[Dict[str, Any]]) -> str:
+ chunks: List[str] = []
+ for tc in tool_calls:
+ if tc.get("type") != "function":
+ continue
+ fn = tc.get("function", {})
+ name = fn.get("name")
+ args_raw = fn.get("arguments")
+ args: Any
+ if isinstance(args_raw, str):
+ try:
+ args = json.loads(args_raw)
+ except json.JSONDecodeError:
+ args = {"_raw": args_raw}
+ else:
+ args = args_raw or {}
+ payload = {"tool": name, "args": args}
+ chunks.append(f"{TOOL_TAGS[0]}\n{json.dumps(payload)}\n{TOOL_TAGS[1]}")
+ return "\n".join(chunks)
+
+
+def _build_conversation(messages: List[Dict[str, Any]]) -> Conversation:
+ convo = Conversation()
+ tool_name_by_id: Dict[str, str] = {}
+ for msg in messages:
+ if msg.get("role") == "assistant":
+ for tc in msg.get("tool_calls", []) or []:
+ tc_id = tc.get("id")
+ fn = (tc.get("function") or {})
+ if tc_id and fn.get("name"):
+ tool_name_by_id[tc_id] = fn["name"]
+
+ for msg in messages:
+ role = msg.get("role")
+ content = _flatten_content(msg.get("content"))
+ if role == "assistant" and msg.get("tool_calls"):
+ tool_blob = _serialize_tool_calls(msg.get("tool_calls") or [])
+ content = (content + "\n" + tool_blob).strip()
+ elif role == "tool":
+ tool_name = msg.get("name") or tool_name_by_id.get(msg.get("tool_call_id"), "")
+ content = f" \"tool\" : \"{tool_name}\" \"output\": \"{content}\" "
+
+ if role == "system":
+ convo.messages.add(role=Message.ROLE_SYSTEM, content=content)
+ elif role == "user":
+ convo.messages.add(role=Message.ROLE_USER, content=content)
+ elif role == "assistant":
+ convo.messages.add(role=Message.ROLE_ASSISTANT, content=content)
+ elif role == "tool":
+ convo.messages.add(role=Message.ROLE_TOOL, content=content)
+
+ return convo
+
+
+def _safe_parse_cot(raw: str) -> Tuple[str, str]:
+ if THINK_TAGS[1] in raw:
+ pre, post = raw.split(THINK_TAGS[1], 1)
+ cot = pre.replace(THINK_TAGS[0], "").replace(THINK_TAGS[1], "").strip()
+ return cot, post
+ return "", raw
+
+
+def _map_finish_reason(fr: Optional[int]) -> Optional[str]:
+ if fr is None:
+ return None
+ if fr == FinishReason.FINISH_STOP:
+ return "stop"
+ if fr == FinishReason.FINISH_LENGTH:
+ return "length"
+ if fr == FinishReason.FINISH_TOOLCALLS:
+ return "tool_calls"
+ return "stop"
+
+
+def _usage_to_openai(usage: Any) -> Dict[str, int]:
+ tokens = getattr(usage, "tokens", None)
+ if tokens is None:
+ return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
+ prompt = int(getattr(tokens, "prompt", 0))
+ completion = int(getattr(tokens, "completion", 0))
+ return {
+ "prompt_tokens": prompt,
+ "completion_tokens": completion,
+ "total_tokens": prompt + completion,
+ }
+
+
+class _StreamFilter:
+ def __init__(self, hide_cot: bool = False) -> None:
+ self._buffer = ""
+ self._mode = "normal" # normal | think | toolcall
+ self._max_tag = max(len(""), len(""), len(""), len(""))
+ self._hide_cot = hide_cot
+ self._passed_cot = not hide_cot
+ def finalize(self) -> str:
+ if self._mode != "normal":
+ self._buffer = ""
+ return ""
+ tail = self._buffer
+ self._buffer = ""
+ return tail
+ def feed(self, chunk: str) -> str:
+ if not chunk:
+ return ""
+ buf = self._buffer + chunk
+ if not self._passed_cot:
+ end = buf.find("")
+ if end == -1:
+ # Keep only enough to detect a split closing tag.
+ keep = len("") - 1
+ self._buffer = buf[-keep:] if keep > 0 else ""
+ return ""
+ buf = buf[end + len("") :]
+ self._passed_cot = True
+ out_parts: List[str] = []
+ while buf:
+ if self._mode == "think":
+ end = buf.find("")
+ if end == -1:
+ self._buffer = buf[-(self._max_tag - 1):]
+ return "".join(out_parts)
+ buf = buf[end + len("") :]
+ self._mode = "normal"
+ continue
+ if self._mode == "toolcall":
+ end = buf.find("")
+ if end == -1:
+ self._buffer = buf[-(self._max_tag - 1):]
+ return "".join(out_parts)
+ buf = buf[end + len("") :]
+ self._mode = "normal"
+ continue
+
+ next_think = buf.find("")
+ next_tool = buf.find("")
+ if next_think == -1 and next_tool == -1:
+ if len(buf) >= self._max_tag:
+ out_parts.append(buf[: -(self._max_tag - 1)])
+ self._buffer = buf[-(self._max_tag - 1) :]
+ else:
+ self._buffer = buf
+ return "".join(out_parts)
+
+ if next_think == -1 or (next_tool != -1 and next_tool < next_think):
+ if next_tool > 0:
+ out_parts.append(buf[:next_tool])
+ buf = buf[next_tool + len("") :]
+ self._mode = "toolcall"
+ continue
+
+ if next_think > 0:
+ out_parts.append(buf[:next_think])
+ buf = buf[next_think + len("") :]
+ self._mode = "think"
+
+ self._buffer = ""
+ return "".join(out_parts)
+
+
+class OpenAIProxy:
+ def __init__(self, grpc_address: str, include_debug: bool = False) -> None:
+ self.grpc_address = grpc_address
+ self.include_debug = include_debug
+ self.channel = grpc.insecure_channel(grpc_address)
+ self.stub = InferenceServiceStub(self.channel)
+ self.prompt_builder = AgentPromptBuilder()
+
+ def build_request(self, payload: Dict[str, Any]) -> Tuple[IRequest, Model, bool, List[Tool], bool]:
+ model_name = payload.get("model")
+ model, is_reasoner = _resolve_model(self.stub, model_name)
+
+ messages = list(payload.get("messages") or [])
+ tools_spec = list(payload.get("tools") or [])
+ tool_choice = payload.get("tool_choice")
+ tool_choice_name = None
+ if isinstance(tool_choice, dict):
+ fn = tool_choice.get("function") or {}
+ tool_choice_name = fn.get("name")
+ allow_tools = tool_choice != "none"
+
+ tools = _build_tool_list(tools_spec) if allow_tools else []
+ if tool_choice_name:
+ tools = [t for t in tools if t.name == tool_choice_name]
+
+ if tools:
+ _apply_tool_prompt(messages, _tool_system_prompt(tools))
+
+ convo = _build_conversation(messages)
+ convo.model_uuid = model.uuid
+
+ req = IRequest()
+ req.id = _gen_id("openai-proxy")
+ req.model_uuid = model.uuid
+ req.convo.CopyFrom(convo)
+
+ if payload.get("max_tokens", 0) > 0:
+ req.cfg.max_tokens = int(payload["max_tokens"])
+ else:
+ req.cfg.max_tokens = 16384
+ if payload.get("temperature") is not None:
+ req.cfg.temp = float(payload["temperature"])
+ if payload.get("top_p") is not None:
+ req.cfg.top_p = float(payload["top_p"])
+
+ response_format = payload.get("response_format") or {"type": "text"}
+ rf_type = response_format.get("type") if isinstance(response_format, dict) else "text"
+
+ if tools:
+ if is_reasoner:
+ _build_tool_call_response_format(req, tools)
+ else:
+ _build_tool_call_response_format_non_reasoning(req, tools)
+ elif rf_type in {"json_schema", "json_object"}:
+ if rf_type == "json_schema":
+ schema = response_format.get("json_schema")
+ else:
+ schema = {"type": "object"}
+ if is_reasoner:
+ structural_tag = {
+ "type": "structural_tag",
+ "format": {
+ "type": "sequence",
+ "elements": [
+ {
+ "type": "tag",
+ "begin": "",
+ "content": {"type": "any_text"},
+ "end": THINK_TAGS[1],
+ },
+ {
+ "type": "tag",
+ "begin": "",
+ "content": {"type": "json_schema", "json_schema": schema},
+ "end": "",
+ },
+ ],
+ },
+ }
+ req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
+ req.cfg.response_format.schema = json.dumps(structural_tag, indent=0)
+ else:
+ req.cfg.response_format.format = ResponseFormat.JSON
+ req.cfg.response_format.schema = json.dumps(schema)
+
+ stream = bool(payload.get("stream"))
+ return req, model, is_reasoner, tools, stream
+
+ def run_sync(self, req: IRequest) -> Any:
+ return self.stub.GenerateSync(req)
+
+ def run_stream(self, req: IRequest):
+ return self.stub.Generate(req)
+
+
+class OpenAIProxyHandler(BaseHTTPRequestHandler):
+ server_version = "TruffleOpenAIProxy/0.1"
+ def _set_cors_headers(self) -> None:
+ self.send_header("Access-Control-Allow-Origin", "*")
+ self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+ self.send_header("Access-Control-Allow-Headers", "Content-Type, Authorization")
+
+ def _send_json(self, status: int, payload: Dict[str, Any]) -> None:
+ data = json.dumps(payload).encode("utf-8")
+ self.send_response(status)
+ self._set_cors_headers()
+ self.send_header("Content-Type", "application/json")
+ self.send_header("Content-Length", str(len(data)))
+ self.end_headers()
+ self.wfile.write(data)
+ def do_OPTIONS(self) -> None:
+ self.send_response(204)
+ self._set_cors_headers()
+ self.end_headers()
+
+ def _read_body(self) -> Dict[str, Any]:
+ length = int(self.headers.get("Content-Length", "0"))
+ if length <= 0:
+ return {}
+ raw = self.rfile.read(length)
+ return json.loads(raw.decode("utf-8"))
+
+ def _send_sse(self, payload: Dict[str, Any]) -> bool:
+ data = json.dumps(payload)
+ try:
+ self.wfile.write(f"data: {data}\n\n".encode("utf-8"))
+ self.wfile.flush()
+ except (BrokenPipeError, ConnectionResetError, OSError):
+ # Client disconnected; stop streaming gracefully.
+ self.close_connection = True
+ return False
+ return True
+
+ def do_GET(self) -> None:
+ if self.path == "/health":
+ self._send_json(200, {"status": "ok"})
+ return
+ if self.path in {"/v1/models", "/models"}:
+ proxy: OpenAIProxy = self.server.proxy # type: ignore[attr-defined]
+ models = _get_models(proxy.stub)
+ data = [
+ {"id": m.uuid, "object": "model", "owned_by": m.provider or "truffle", "name": m.name}
+ for m in models
+ ]
+ self._send_json(200, {"object": "list", "data": data})
+ return
+ if self.path.startswith("/v1/models/"):
+ proxy: OpenAIProxy = self.server.proxy # type: ignore[attr-defined]
+ model_id = self.path.split("/v1/models/", 1)[1]
+ models = _get_models(proxy.stub)
+ model = next((m for m in models if m.uuid == model_id or m.name == model_id), None)
+ if model is None:
+ self._send_json(404, {"error": {"message": "model not found", "type": "not_found_error"}})
+ return
+ self._send_json(
+ 200,
+ {
+ "id": model.uuid,
+ "object": "model",
+ "owned_by": model.provider or "truffle",
+ "name": model.name,
+ },
+ )
+ return
+ self.send_error(404, "Not Found")
+
+ def do_POST(self) -> None:
+ if self.path != "/v1/chat/completions":
+ self.send_error(404, "Not Found")
+ return
+ try:
+ payload = self._read_body()
+ except Exception as e:
+ self._send_json(400, {"error": {"message": str(e), "type": "invalid_request_error"}})
+ return
+
+ proxy: OpenAIProxy = self.server.proxy # type: ignore[attr-defined]
+
+ try:
+ req, model, is_reasoner, _tools, stream = proxy.build_request(payload)
+ except Exception as e:
+ self._send_json(400, {"error": {"message": str(e), "type": "invalid_request_error"}})
+ return
+
+ if stream:
+ self.send_response(200)
+ self._set_cors_headers()
+ self.send_header("Content-Type", "text/event-stream; charset=utf-8")
+ self.send_header("Cache-Control", "no-cache, no-transform")
+ self.send_header("Connection", "keep-alive")
+ self.send_header("X-Accel-Buffering", "no")
+ self.end_headers()
+
+ stream_id = _gen_id("chatcmpl")
+ created = _now_ts()
+ if not self._send_sse(
+ {
+ "id": stream_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": model.name,
+ "choices": [
+ {"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}
+ ],
+ }
+ ):
+ return
+
+ raw_content = ""
+ last_finish = None
+ filter_state = _StreamFilter(hide_cot=is_reasoner)
+
+ for ir in proxy.run_stream(req):
+ raw_content += ir.content
+ if ir.HasField("finish_reason") and ir.finish_reason != FinishReason.FINISH_UNSPECIFIED:
+ last_finish = ir.finish_reason
+ visible = filter_state.feed(ir.content)
+ if visible:
+ if not self._send_sse(
+ {
+ "id": stream_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": model.name,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"content": visible},
+ "finish_reason": None,
+ }
+ ],
+ }
+ ):
+ return
+ tail = filter_state.finalize()
+ if tail:
+ if not self._send_sse(
+ {
+ "id": stream_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": model.name,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"content": tail},
+ "finish_reason": None,
+ }
+ ],
+ }
+ ):
+ return
+ _cot, after_cot = _safe_parse_cot(raw_content)
+ tool_calls, _clean = proxy.prompt_builder.extract_tool_calls(after_cot)
+ if tool_calls:
+ tc_list = []
+ for i, tc in enumerate(tool_calls):
+ name = tc.get("tool") or ""
+ args = json.dumps(tc.get("args") or {}, separators=(",", ":"))
+ tc_list.append(
+ {
+ "id": f"call_{i+1}",
+ "type": "function",
+ "function": {"name": name, "arguments": args},
+ }
+ )
+ if not self._send_sse(
+ {
+ "id": stream_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": model.name,
+ "choices": [
+ {
+ "index": 0,
+ "delta": {"tool_calls": tc_list},
+ "finish_reason": None,
+ }
+ ],
+ }
+ ):
+ return
+ finish_reason = _map_finish_reason(last_finish) or "stop"
+ if not self._send_sse(
+ {
+ "id": stream_id,
+ "object": "chat.completion.chunk",
+ "created": created,
+ "model": model.name,
+ "choices": [
+ {"index": 0, "delta": {}, "finish_reason": finish_reason}
+ ],
+ }
+ ):
+ return
+ try:
+ self.wfile.write(b"data: [DONE]\n\n")
+ self.wfile.flush()
+ except (BrokenPipeError, ConnectionResetError, OSError):
+ self.close_connection = True
+ else:
+ self.close_connection = True
+ return
+
+ resp = proxy.run_sync(req)
+ raw = resp.content
+ cot, after_cot = _safe_parse_cot(raw)
+ tool_calls, clean = proxy.prompt_builder.extract_tool_calls(after_cot)
+ message = clean_response(clean)
+
+ finish_reason = _map_finish_reason(resp.finish_reason if resp.HasField("finish_reason") else None)
+ openai_tool_calls = []
+ for i, tc in enumerate(tool_calls):
+ name = tc.get("tool") or ""
+ args = json.dumps(tc.get("args") or {}, separators=(",", ":"))
+ openai_tool_calls.append(
+ {
+ "id": f"call_{i+1}",
+ "type": "function",
+ "function": {"name": name, "arguments": args},
+ }
+ )
+
+ msg: Dict[str, Any] = {"role": "assistant", "content": message}
+ if openai_tool_calls:
+ msg["tool_calls"] = openai_tool_calls
+ if not message:
+ msg["content"] = None
+
+ response = {
+ "id": _gen_id("chatcmpl"),
+ "object": "chat.completion",
+ "created": _now_ts(),
+ "model": model.name,
+ "choices": [
+ {"index": 0, "message": msg, "finish_reason": finish_reason}
+ ],
+ "usage": _usage_to_openai(resp.usage if resp.HasField("usage") else None),
+ }
+
+ debug_req = bool(payload.get("debug") or payload.get("debug_reasoning"))
+ if proxy.include_debug or debug_req:
+ response["debug"] = {"reasoning": cot}
+
+ self._send_json(200, response)
+
+def normalize_grpc_address(address: str, default_port : int = 80) -> str:
+ import socket
+ if '.local' in address:
+ try:
+ ip = socket.gethostbyname(address)
+ address = ip
+ except socket.gaierror as e:
+ raise RuntimeError(f"Failed to resolve mDNS address {address}: {e}")
+ if ':' not in address:
+ address += f":{default_port}"
+ return address
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="OpenAI-compatible proxy for Truffle gRPC inference")
+ parser.add_argument("--truffle", default="truffle-1234", help="truffle id: e.g. truffle-1234")
+ parser.add_argument("--host", default="127.0.0.1", help="HTTP host")
+ parser.add_argument("--port", type=int, default=8080, help="HTTP port")
+ parser.add_argument("--debug", action="store_true", help="Include debug.reasoning in responses")
+ args = parser.parse_args()
+ print(f"Connecting to {args.truffle}")
+ grpc_address = normalize_grpc_address(f"{args.truffle}.local", default_port=80)
+ print(f"Found {args.truffle} at {grpc_address}")
+ proxy = OpenAIProxy(grpc_address, include_debug=args.debug)
+
+ class _Server(ThreadingHTTPServer):
+ def __init__(self, server_address, handler_cls):
+ super().__init__(server_address, handler_cls)
+ self.proxy = proxy
+
+ server = _Server((args.host, args.port), OpenAIProxyHandler)
+ print(f"OpenAI proxy listening on http://{args.host}:{args.port} -> Truffle @ {grpc_address}")
+ server.serve_forever()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/truffile/infer/tooling.py b/truffile/infer/tooling.py
new file mode 100644
index 0000000..8d13889
--- /dev/null
+++ b/truffile/infer/tooling.py
@@ -0,0 +1,19 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Dict
+import json
+
+
+@dataclass
+class Tool:
+ name: str
+ description: str
+ input_schema: Dict
+ display_name: str
+
+ def schema_str(self, indent: int = 2) -> str:
+ return json.dumps(self.input_schema, indent=indent)
+
+ def get_for_system_prompt(self) -> str:
+ return f"{self.name}: {self.description}\nArg Schema: {self.schema_str(indent=2)}"