Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ dependencies = [

[project.scripts]
truffile = "truffile.cli:main"
truffleinferproxy = "truffile.infer.proxy:main"

[project.optional-dependencies]
dev = [
Expand Down
113 changes: 113 additions & 0 deletions scripts/test_oai_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python3
"""Smoke test for the local OpenAI-compatible proxy."""

from __future__ import annotations

import argparse
import os
from typing import Any, Dict, List
try:
    from openai import OpenAI
except ImportError as exc:
    # Chain the original exception so the real import failure (wrong env,
    # broken install, etc.) stays visible in the traceback.
    raise ImportError(
        "Please install the 'openai' package to run this test script."
    ) from exc

def _print_header(title: str) -> None:
print("\n" + "=" * 8 + f" {title} " + "=" * 8)


def test_basic(client: OpenAI, model: str) -> None:
    """Run a single non-streaming chat completion and print the reply text."""
    _print_header("basic")
    reply = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        max_tokens=2048,
        temperature=0.7,
        top_p=0.9,
    )
    print("content:", reply.choices[0].message.content)


def test_json_schema(client: OpenAI, model: str) -> None:
    """Request a structured reply constrained by a json_schema response_format."""
    _print_header("json_schema")
    # NOTE(review): the raw schema is passed directly as "json_schema"; the
    # official OpenAI API wraps it as {"name": ..., "schema": ...}. Assumed the
    # proxy accepts the bare schema — confirm against the proxy implementation.
    answer_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "answer": {"type": "string"},
            "confidence": {"type": "number"},
        },
        "required": ["answer", "confidence"],
    }
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "What is 2+2? Respond as JSON."}],
        response_format={"type": "json_schema", "json_schema": answer_schema},
        max_tokens=2048,
    )
    print("content:", completion.choices[0].message.content)


def test_tools(client: OpenAI, model: str) -> None:
    """Offer one function tool and print any tool calls the model emits."""
    _print_header("tools")
    get_time_tool: Dict[str, Any] = {
        "type": "function",
        "function": {
            "name": "get_time",
            "description": "Return the current time in ISO-8601",
            "parameters": {
                "type": "object",
                "properties": {"tz": {"type": "string"}},
                "required": [],
            },
        },
    }
    completion = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "What time is it? Use the tool."}],
        tools=[get_time_tool],
        tool_choice="auto",
        max_tokens=2048,
    )
    choice_msg = completion.choices[0].message
    print("tool_calls:", choice_msg.tool_calls)
    print("content:", choice_msg.content)


def test_stream(client: OpenAI, model: str) -> None:
    """Stream a completion chunk-by-chunk, then print the assembled text."""
    _print_header("stream")
    events = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Stream a short haiku."}],
        max_tokens=2048,
        stream=True,
    )
    pieces: List[str] = []
    for event in events:
        delta = event.choices[0].delta
        # Some chunks (e.g. role-only or finish markers) carry no content.
        if delta and delta.content:
            pieces.append(delta.content)
    print("content:", "".join(pieces))


def main() -> None:
    """Parse CLI flags, build a client, and run each smoke test in order."""
    parser = argparse.ArgumentParser(description="Smoke test for OpenAI proxy")
    parser.add_argument("--base-url", default="http://127.0.0.1:8080/v1", help="Proxy base URL")
    parser.add_argument("--model", default="auto", help="Model name or UUID")
    parser.add_argument("--no-stream", action="store_true", help="Skip streaming test")
    opts = parser.parse_args()

    # Falls back to a dummy key ("test") when OPENAI_API_KEY is unset.
    client = OpenAI(base_url=opts.base_url, api_key=os.getenv("OPENAI_API_KEY", "test"))

    for check in (test_basic, test_json_schema, test_tools):
        check(client, opts.model)
    if not opts.no_stream:
        test_stream(client, opts.model)


if __name__ == "__main__":
    main()
18 changes: 18 additions & 0 deletions truffile/infer/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Accessing Inference APIs on the Truffle

The Truffle currently uses its own non-standard set of APIs for inference.

Provided here is a proxy that demonstrates how to use these APIs and makes it easier to connect existing clients.

This is experimental and may not be fully API compatible, but should serve as a good starting point for exploring the Truffle while core software improves.

### Usage

```bash
truffleinferproxy --truffle truffle-5970 --host 127.0.0.1 --port 8080

truffleinferproxy --help
```



1 change: 1 addition & 0 deletions truffile/infer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Standalone OpenAI-compatible proxy for Truffle gRPC inference."""
7 changes: 7 additions & 0 deletions truffile/infer/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from __future__ import annotations

# Open/close markers wrapping a model's <think> reasoning section.
THINK_TAGS = ["<think>", "</think>"]


def clean_response(response: str) -> str:
    """Trim surrounding whitespace and drop Unicode replacement chars (U+FFFD)."""
    trimmed = response.strip()
    return "".join(ch for ch in trimmed if ch != "\ufffd")
91 changes: 91 additions & 0 deletions truffile/infer/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from __future__ import annotations

from typing import List, Tuple
import json
import re

from truffle.infer.gencfg_pb2 import ResponseFormat

from .common import THINK_TAGS
from .tooling import Tool


# Delimiters the model wraps tool invocations in, plus a compiled matcher
# that captures the (possibly multi-line) payload between them.
TOOL_TAGS = ["<toolcall>", "</toolcall>"]
tool_tag_pattern = re.compile(f"{TOOL_TAGS[0]}(.*?){TOOL_TAGS[1]}", re.DOTALL)


class AgentPromptBuilder:
    """Parses model output for tool-call payloads."""

    def extract_tool_calls(self, response: str) -> Tuple[List[dict], str]:
        """Return (parsed tool-call dicts, response with tool-call tags removed).

        If the response contains no tool-call tags, it is returned unchanged
        alongside an empty list.
        """
        payloads = tool_tag_pattern.findall(response)
        if not payloads:
            return [], response
        calls: List[dict] = []
        for raw in payloads:
            try:
                calls.append(json.loads(raw.strip()))
            except json.JSONDecodeError:
                # Malformed payloads are skipped rather than failing the parse.
                continue
        remainder = tool_tag_pattern.sub("", response).strip()
        return calls, remainder


def _build_tool_call_response_format_non_reasoning(
    req, available_tools: List[Tool], allow_parallel: bool = False
) -> None:
    """Set a STRUCTURAL_TAG response format on *req* for tool calling.

    Each tool gets a tag that forces the output shape
    ``<toolcall>\\n{"tool": "<name>", "args": <schema-valid JSON>}</toolcall>``.
    Unless ``allow_parallel`` is set, generation stops after the first call.
    Mutates ``req.cfg.response_format`` in place; returns nothing.
    """
    tags = [
        {
            "begin": TOOL_TAGS[0] + "\n" + '{"tool": "' + t.name + '", "args": ',
            "content": {"type": "json_schema", "json_schema": t.input_schema},
            "end": "}" + TOOL_TAGS[1] + "\n",
        }
        for t in available_tools
    ]
    spec = {
        "type": "structural_tag",
        "format": {
            "type": "triggered_tags",
            # The open tag itself triggers constrained decoding.
            "triggers": [TOOL_TAGS[0]],
            "tags": tags,
            "stop_after_first": not allow_parallel,
        },
    }
    req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
    req.cfg.response_format.schema = json.dumps(spec, indent=0)


def _build_tool_call_response_format(
    req, available_tools: List[Tool], allow_parallel: bool = False
) -> None:
    """Set a STRUCTURAL_TAG response format on *req* for reasoning models.

    The format is a two-element sequence: free text terminated by the closing
    think tag, followed by triggered tool-call tags of the shape
    ``<toolcall>\\n{"tool": "<name>", "args": <schema-valid JSON>}</toolcall>``.
    Unless ``allow_parallel`` is set, generation stops after the first call.
    Mutates ``req.cfg.response_format`` in place; returns nothing.
    """
    tags = [
        {
            "begin": TOOL_TAGS[0] + "\n" + '{"tool": "' + t.name + '", "args": ',
            "content": {"type": "json_schema", "json_schema": t.input_schema},
            "end": "}" + TOOL_TAGS[1] + "\n",
        }
        for t in available_tools
    ]
    # Element 1: unconstrained reasoning text, closed by </think>.
    reasoning_segment = {
        "type": "tag",
        "begin": "",
        "content": {"type": "any_text"},
        "end": THINK_TAGS[1],
    }
    # Element 2: constrained tool calls, triggered by the open tag.
    tool_segment = {
        "type": "triggered_tags",
        "triggers": [TOOL_TAGS[0]],
        "tags": tags,
        "stop_after_first": not allow_parallel,
    }
    spec = {
        "type": "structural_tag",
        "format": {
            "type": "sequence",
            "elements": [reasoning_segment, tool_segment],
        },
    }
    req.cfg.response_format.format = ResponseFormat.STRUCTURAL_TAG
    req.cfg.response_format.schema = json.dumps(spec, indent=0)
Loading