diff --git a/bluebox/agents/abstract_agent.py b/bluebox/agents/abstract_agent.py index 8fc0d731..b08814fe 100644 --- a/bluebox/agents/abstract_agent.py +++ b/bluebox/agents/abstract_agent.py @@ -61,24 +61,47 @@ get_workaround_for_error, ) from bluebox.utils.data_utils import format_bytes +from bluebox.utils.llm_serialization import serialize_tool_result, strip_llm_excluded from bluebox.utils.llm_utils import token_optimized as token_optimized_decorator from bluebox.utils.logger import get_logger logger = get_logger(name=__name__) +# Keep persisted tool previews small so iterative runs don't bloat context +PERSISTED_TOOL_PREVIEW_MAX_CHARS = 800 + + class ToolResultPersistMode(StrEnum): + """ + Policy controlling when a tool result is persisted to the workspace. + + Persistence saves the full result as a raw artifact and returns a + compact preview to the LLM, keeping context usage in check for + large payloads. + + Attributes: + NEVER: Never persist; the full result is returned inline. + ALWAYS: Always persist, regardless of size. + OVERFLOW: Persist only when the serialized result exceeds the + tool's ``max_characters`` threshold. + """ NEVER = "never" ALWAYS = "always" OVERFLOW = "overflow" -# Keep persisted tool previews small so iterative runs don't blow context. -PERSISTED_TOOL_PREVIEW_MAX_CHARS = 800 - - class AgentExecutionMode(StrEnum): - """Execution mode for agent loops.""" + """ + Execution mode for agent loops. + + Attributes: + CONVERSATIONAL: Interactive mode where the agent responds to user + messages one at a time via :meth:`process_new_message`. + AUTONOMOUS: Self-directed mode where the agent runs a tool-driven + loop until it calls a finalize tool or hits the iteration cap. + See :meth:`AbstractAgent.run_autonomous`. + """ CONVERSATIONAL = "conversational" AUTONOMOUS = "autonomous" @@ -108,25 +131,50 @@ class variable. Orchestrator agents use these cards to discover subagent @dataclass(frozen=True) class _ToolMeta: - """Metadata attached to a handler method by @agent_tool.""" - name: str # tool name registered with the LLM client - description: str # tool description shown to the LLM - parameters: dict[str, Any] # JSON Schema for tool parameters - availability: bool | Callable[..., bool] # whether the tool should be registered right now + """ + Metadata attached to a handler method by :func:`agent_tool`. + + Instances are stored on the decorated method as ``method._tool_meta`` + and collected at class-definition time by + :meth:`AbstractAgent._collect_tools`. + + Attributes: + name: Tool name registered with the LLM client (derived from the + method name by stripping leading underscores). + description: Human-readable description shown to the LLM. + parameters: JSON Schema ``object`` describing accepted parameters. + availability: Static boolean or a callable ``(self) -> bool`` + evaluated before each LLM call to gate tool registration. + persist: Result-persistence policy. See :class:`ToolResultPersistMode`. + max_characters: Character threshold used by + :attr:`ToolResultPersistMode.OVERFLOW` to decide when to + persist a result to the workspace. + token_optimized: If ``True``, the tool result is encoded with + the ``token_optimized`` decorator for reduced token usage. + """ + name: str + description: str + parameters: dict[str, Any] + availability: bool | Callable[..., bool] persist: ToolResultPersistMode = ToolResultPersistMode.NEVER max_characters: int = 10_000 token_optimized: bool = False -def _serialize_tool_result(tool_result: Any) -> tuple[str, str]: - try: - return json.dumps(tool_result, ensure_ascii=False, default=str, indent=2), "json" - except (TypeError, ValueError): - return str(tool_result), "text" +def _normalize_file_scope(scope: str) -> str: + """ + Normalize and validate a file-tool scope string. Strips whitespace, lowercases, + and ensures the value is one of the accepted scope literals. + Args: + scope: Raw scope value from a tool call (e.g. ``"Workspace"``). -def _normalize_file_scope(scope: str) -> str: - """Normalize and validate file tool scope.""" + Returns: + The normalized scope (``"workspace"`` or ``"docs"``). + + Raises: + ValueError: If *scope* is not a recognized value. + """ normalized_scope = scope.strip().lower() if normalized_scope not in {"workspace", "docs"}: raise ValueError("scope must be 'workspace' or 'docs'") @@ -134,10 +182,24 @@ def _normalize_file_scope(scope: str) -> str: def _parse_search_terms(query: str) -> list[str]: - """Split query text into distinct terms for terms-mode search.""" + """ + Split a query string into unique, order-preserving search terms. + + Tokens are split on commas and whitespace. Empty tokens and + duplicates are discarded while preserving first-occurrence order. + + Args: + query: Free-text search query (e.g. ``"foo, bar baz"``). + + Returns: + Deduplicated list of non-empty terms in original order. + """ seen: set[str] = set() terms: list[str] = [] - for token in re.split(r"[,\s]+", query): + for token in re.split( + pattern=r"[,\s]+", + string=query + ): term = token.strip() if term and term not in seen: seen.add(term) @@ -551,7 +613,7 @@ def _maybe_persist_tool_result( if persist_mode == ToolResultPersistMode.NEVER: return tool_result - serialized, content_type = _serialize_tool_result(tool_result) + serialized, content_type = serialize_tool_result(tool_result) char_count = len(serialized) if persist_mode == ToolResultPersistMode.OVERFLOW and char_count <= tool_meta.max_characters: @@ -1080,6 +1142,7 @@ def _execute_tool(self, tool_name: str, tool_arguments: dict[str, Any]) -> dict[ logger.debug("Executing tool %s with arguments: %s", tool_name, tool_arguments) # handler is unbound (from cls, not self) so pass self explicitly raw_result = handler(self, **validated_arguments) + raw_result = strip_llm_excluded(raw_result) # strip LLMExclude-annotated fields from any Pydantic models result_for_llm = self._maybe_persist_tool_result( tool_name=tool_name, tool_meta=tool_meta, diff --git a/bluebox/agents/specialists/interaction_specialist.py b/bluebox/agents/specialists/interaction_specialist.py index bf798581..92d46439 100644 --- a/bluebox/agents/specialists/interaction_specialist.py +++ b/bluebox/agents/specialists/interaction_specialist.py @@ -56,6 +56,7 @@ class InteractionSpecialist(AbstractAgent): "structural context (forms, inputs, buttons, links)." ), ) + SYSTEM_PROMPT: str = dedent("""\ You are a UI interaction analyst specializing in understanding what users did on web pages from recorded browser interaction events. diff --git a/bluebox/agents/specialists/network_specialist.py b/bluebox/agents/specialists/network_specialist.py index 7c5d55dd..fb7ec0d1 100644 --- a/bluebox/agents/specialists/network_specialist.py +++ b/bluebox/agents/specialists/network_specialist.py @@ -53,6 +53,7 @@ class NetworkSpecialist(AbstractAgent): "inspecting request/response data, and semantic search across captured traffic." ), ) + SYSTEM_PROMPT: str = dedent(f""" You are a network traffic analyst specializing in captured browser network data. diff --git a/bluebox/utils/llm_serialization.py b/bluebox/utils/llm_serialization.py new file mode 100644 index 00000000..b550ce50 --- /dev/null +++ b/bluebox/utils/llm_serialization.py @@ -0,0 +1,167 @@ +""" +bluebox/utils/llm_serialization.py + +Utilities for controlling what data gets sent to LLMs from tool results. + +The LLMExclude marker lets you annotate Pydantic model fields that should be +stripped before a tool result is serialized for the LLM — e.g. large blobs, +internal IDs, or raw data the model doesn't need. + +Usage on models:: + + from typing import Annotated + from pydantic import BaseModel + from bluebox.utils.llm_serialization import LLMExclude + + class NetworkTransaction(BaseModel): + url: str + method: str + response_body: Annotated[str, LLMExclude()] # stripped before LLM sees it + +Tool handlers can return these models (or dicts containing them) directly — +the agent infrastructure calls strip_llm_excluded() automatically. +""" + +from __future__ import annotations + +import functools +import json +from enum import StrEnum +from typing import Any, NamedTuple + +from pydantic import BaseModel + + +class SerializedContentType(StrEnum): + """ + Content type of a serialized tool result. + + Attributes: + JSON: Successfully serialized as JSON. + TEXT: Fell back to ``str()`` representation. + """ + JSON = "json" + TEXT = "text" + + +class SerializedToolResult(NamedTuple): + """ + Result of serializing a tool return value for the LLM. + + Attributes: + serialized: The serialized string (JSON or plain text). + content_type: How the value was serialized. + """ + serialized: str + content_type: SerializedContentType + + +class LLMExclude: + """ + Marker: exclude this field from LLM tool results. + + Attach via ``Annotated``:: + + name: str # included + raw_blob: Annotated[bytes, LLMExclude()] # excluded + """ + pass + + +def serialize_tool_result(tool_result: Any) -> SerializedToolResult: + """ + Serialize a tool result to a JSON or plain-text string for the LLM. + + Attempts JSON serialization first (using ``default=str`` for non-serializable + types). Falls back to ``str()`` if JSON encoding fails. + + Args: + tool_result: The value returned by a tool handler (typically a dict). + + Returns: + A :class:`SerializedToolResult` (also unpacks as a two-tuple). + """ + try: + return SerializedToolResult( + serialized=json.dumps( + tool_result, + ensure_ascii=False, + default=str, + indent=2 + ), + content_type=SerializedContentType.JSON, + ) + except (TypeError, ValueError): + return SerializedToolResult( + serialized=str(tool_result), + content_type=SerializedContentType.TEXT + ) + + +@functools.lru_cache(maxsize=256) +def _excluded_fields(model_cls: type[BaseModel]) -> frozenset[str]: + """ + Return the set of field names annotated with LLMExclude for a model class. + + Scans ``model_cls.model_fields`` and checks each field's ``metadata`` list + for an ``LLMExclude`` instance (attached via ``Annotated[Type, LLMExclude()]``). + + Results are cached per class via ``lru_cache``. Safe because Pydantic field + definitions are fixed at class creation time. + + Args: + model_cls: A Pydantic BaseModel subclass to inspect. + + Returns: + Frozen set of field names that should be excluded from LLM serialization. + Empty frozenset if the model has no LLMExclude annotations. + """ + return frozenset( + name + for name, info in model_cls.model_fields.items() + if any(isinstance(m, LLMExclude) for m in info.metadata) + ) + + +def strip_llm_excluded(obj: Any) -> Any: + """ + Recursively strip LLMExclude-annotated fields from Pydantic models. + + Walks the object tree and converts any ``BaseModel`` instance into a dict + with LLMExclude-annotated fields removed. Non-BaseModel values pass through + unchanged (just an ``isinstance`` check). + + Supported containers (recursed into): + - ``BaseModel``: fields filtered, remaining values recursed + - ``dict``: values recursed, keys preserved + - ``list`` / ``tuple``: elements recursed, container type preserved + + Args: + obj: Any object — typically a tool handler's return value. Can be a + BaseModel, dict, list, tuple, or primitive. + + Returns: + A plain-dict / list / tuple / primitive copy with all LLMExclude fields + removed from any BaseModel instances found at any nesting depth. + """ + if isinstance(obj, BaseModel): + cls = type(obj) + excluded = _excluded_fields(cls) + result = { + name: strip_llm_excluded(value) + for name in cls.model_fields + if name not in excluded + for value in (getattr(obj, name),) # bind to local for clarity + } + # include @computed_field properties (not in model_fields) + for name in cls.model_computed_fields: + if name not in excluded: + result[name] = strip_llm_excluded(getattr(obj, name)) + return result + if isinstance(obj, dict): + return {k: strip_llm_excluded(v) for k, v in obj.items()} + if isinstance(obj, list): + return [strip_llm_excluded(item) for item in obj] + if isinstance(obj, tuple): + return tuple(strip_llm_excluded(item) for item in obj) + return obj diff --git a/tests/unit/agents/test_abstract_agent.py b/tests/unit/agents/test_abstract_agent.py index a08ddc4b..51f7908f 100644 --- a/tests/unit/agents/test_abstract_agent.py +++ b/tests/unit/agents/test_abstract_agent.py @@ -18,11 +18,11 @@ import json import tempfile from pathlib import Path -from typing import Any +from typing import Annotated, Any from unittest.mock import MagicMock, patch import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, computed_field from bluebox.agents.abstract_agent import ( AbstractAgent, @@ -31,6 +31,7 @@ agent_tool, _ToolMeta, ) +from bluebox.utils.llm_serialization import LLMExclude from bluebox.workspace import LocalAgentWorkspace from bluebox.data_models.llms.interaction import ( Chat, @@ -58,6 +59,33 @@ class SearchParams(BaseModel): tags: list[str] = Field(default_factory=list, description="Tags to filter by.") +class TransactionModel(BaseModel): + """Model with LLMExclude fields for integration testing.""" + url: str + method: str + status_code: int + response_body: Annotated[str, LLMExclude()] + raw_headers: Annotated[dict, LLMExclude()] + + +class ComputedModel(BaseModel): + """Model with a computed field to verify it's preserved.""" + first: str + last: str + internal: Annotated[str, LLMExclude()] + + @computed_field + @property + def full_name(self) -> str: + return f"{self.first} {self.last}" + + +class CleanModel(BaseModel): + """Model with no LLMExclude annotations.""" + x: str + y: int + + class ConcreteAgent(AbstractAgent): """Minimal concrete AbstractAgent for testing.""" @@ -186,6 +214,64 @@ def _token_optimized_no_persist(self) -> dict[str, Any]: """Token-optimized tool with no persistence.""" return {"status": "ok"} + @agent_tool + def _get_transaction(self) -> dict[str, Any]: + """ + Return a transaction model with LLMExclude fields. + """ + tx = TransactionModel( + url="https://api.example.com/data", + method="GET", + status_code=200, + response_body='{"huge": "blob"}', + raw_headers={"x-internal": "secret"}, + ) + return {"transaction": tx, "count": 1} + + @agent_tool + def _get_transaction_direct(self) -> TransactionModel: + """ + Return a TransactionModel directly (not wrapped in a dict). + """ + return TransactionModel( + url="https://example.com", + method="POST", + status_code=201, + response_body="big", + raw_headers={"h": "v"}, + ) + + @agent_tool + def _get_computed(self) -> ComputedModel: + """ + Return a model with a computed field. + """ + return ComputedModel(first="Jane", last="Doe", internal="hidden") + + @agent_tool + def _get_transaction_list(self) -> dict[str, Any]: + """ + Return a list of models with LLMExclude fields. + """ + txs = [ + TransactionModel( + url=f"https://api.example.com/{i}", + method="GET", + status_code=200, + response_body=f"body_{i}", + raw_headers={"h": str(i)}, + ) + for i in range(3) + ] + return {"transactions": txs} + + @agent_tool + def _get_clean_model(self) -> CleanModel: + """ + Return a model with no LLMExclude annotations. + """ + return CleanModel(x="visible", y=42) + @pytest.fixture def mock_emit() -> MagicMock: @@ -516,6 +602,9 @@ def test_finds_all_decorated_methods(self) -> None: "no_params", "optional_params", "raises_error", "search", "persist_always", "persist_overflow", "persist_never", "persist_always_token_optimized", "token_optimized_no_persist", + # LLMExclude integration test tools + "get_transaction", "get_transaction_direct", "get_computed", + "get_transaction_list", "get_clean_model", "add_note", "finalize_with_output", "finalize_with_failure", "finalize_result", "finalize_failure", "execute_python", # Unified file tools from AbstractAgent @@ -877,6 +966,70 @@ def test_token_optimized_without_persist_has_no_note(self, agent: ConcreteAgent) assert "_token_optimized_note" not in result +# ============================================================================= +# _execute_tool — LLMExclude stripping +# ============================================================================= + + +class TestExecuteToolLLMExclude: + """Tests that _execute_tool strips LLMExclude-annotated fields from tool results.""" + + def test_model_in_dict_strips_excluded_fields(self, agent: ConcreteAgent) -> None: + """Model returned inside a dict has LLMExclude fields removed.""" + result = agent._execute_tool("get_transaction", {}) + tx = result["transaction"] + assert tx["url"] == "https://api.example.com/data" + assert tx["method"] == "GET" + assert tx["status_code"] == 200 + assert "response_body" not in tx + assert "raw_headers" not in tx + # non-model values in the dict are preserved + assert result["count"] == 1 + + def test_model_returned_directly_strips_excluded(self, agent: ConcreteAgent) -> None: + """Model returned as the top-level result is converted to dict with exclusions.""" + result = agent._execute_tool("get_transaction_direct", {}) + assert isinstance(result, dict) + assert result["url"] == "https://example.com" + assert result["method"] == "POST" + assert result["status_code"] == 201 + assert "response_body" not in result + assert "raw_headers" not in result + + def test_computed_field_preserved(self, agent: ConcreteAgent) -> None: + """@computed_field values are included in stripped output.""" + result = agent._execute_tool("get_computed", {}) + assert result["first"] == "Jane" + assert result["last"] == "Doe" + assert result["full_name"] == "Jane Doe" + assert "internal" not in result + + def test_list_of_models_strips_all(self, agent: ConcreteAgent) -> None: + """List of models inside a dict — each model has exclusions stripped.""" + result = agent._execute_tool("get_transaction_list", {}) + txs = result["transactions"] + assert len(txs) == 3 + for i, tx in enumerate(txs): + assert tx["url"] == f"https://api.example.com/{i}" + assert "response_body" not in tx + assert "raw_headers" not in tx + + def test_clean_model_unchanged(self, agent: ConcreteAgent) -> None: + """Model with no LLMExclude annotations returns all fields.""" + result = agent._execute_tool("get_clean_model", {}) + assert result == {"x": "visible", "y": 42} + + def test_auto_execute_tool_strips_excluded( + self, agent: ConcreteAgent, mock_emit: MagicMock, + ) -> None: + """LLMExclude stripping also works through _auto_execute_tool path.""" + result_json = agent._auto_execute_tool("get_transaction_direct", {}) + result = json.loads(result_json) + assert result["url"] == "https://example.com" + assert "response_body" not in result + assert "raw_headers" not in result + + # ============================================================================= # _auto_execute_tool # ============================================================================= diff --git a/tests/unit/agents/test_abstract_agent_helpers.py b/tests/unit/agents/test_abstract_agent_helpers.py index e9a994f3..1bfc52a3 100644 --- a/tests/unit/agents/test_abstract_agent_helpers.py +++ b/tests/unit/agents/test_abstract_agent_helpers.py @@ -2,109 +2,18 @@ tests/unit/agents/test_abstract_agent_helpers.py Unit tests for module-level helper functions in abstract_agent.py: - - _serialize_tool_result - _normalize_file_scope - _parse_search_terms """ -import json -from datetime import datetime -from typing import Any - import pytest from bluebox.agents.abstract_agent import ( _normalize_file_scope, _parse_search_terms, - _serialize_tool_result, ) -# ============================================================================= -# _serialize_tool_result -# ============================================================================= - - -class TestSerializeToolResult: - """Tests for _serialize_tool_result.""" - - def test_dict_returns_json(self) -> None: - result = {"key": "value", "count": 42} - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - assert json.loads(serialized) == result - - def test_list_returns_json(self) -> None: - result = [1, 2, 3] - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - assert json.loads(serialized) == result - - def test_string_returns_json(self) -> None: - result = "hello" - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - assert json.loads(serialized) == result - - def test_nested_dict_returns_json(self) -> None: - result = {"a": {"b": [1, 2]}, "c": None} - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - assert json.loads(serialized) == result - - def test_none_returns_json(self) -> None: - serialized, content_type = _serialize_tool_result(None) - assert content_type == "json" - assert json.loads(serialized) is None - - def test_non_ascii_preserved(self) -> None: - result = {"emoji": "🔥", "text": "café"} - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - assert "🔥" in serialized - assert "café" in serialized - - def test_datetime_uses_default_str(self) -> None: - dt = datetime(2025, 1, 15, 12, 30, 0) - result = {"timestamp": dt} - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - parsed = json.loads(serialized) - assert "2025-01-15" in parsed["timestamp"] - - def test_non_serializable_falls_back_to_text(self) -> None: - # An object whose __str__ works but json.dumps with default=str - # should still handle it — need something truly unserializable. - # Actually, default=str handles most things. Let's verify str fallback - # by using an object that raises in __repr__/__str__ during json encoding. - - class BadObj: - def __repr__(self) -> str: - return "BadObj()" - - # default=str calls str() on non-serializable, so this should still - # produce json via the default handler - result = {"obj": BadObj()} - serialized, content_type = _serialize_tool_result(result) - assert content_type == "json" - assert "BadObj()" in serialized - - def test_integer_returns_json(self) -> None: - serialized, content_type = _serialize_tool_result(42) - assert content_type == "json" - assert json.loads(serialized) == 42 - - def test_boolean_returns_json(self) -> None: - serialized, content_type = _serialize_tool_result(True) - assert content_type == "json" - assert json.loads(serialized) is True - - def test_empty_dict_returns_json(self) -> None: - serialized, content_type = _serialize_tool_result({}) - assert content_type == "json" - assert json.loads(serialized) == {} - - # ============================================================================= # _normalize_file_scope # ============================================================================= diff --git a/tests/unit/utils/test_llm_serialization.py b/tests/unit/utils/test_llm_serialization.py new file mode 100644 index 00000000..1ce4b08d --- /dev/null +++ b/tests/unit/utils/test_llm_serialization.py @@ -0,0 +1,492 @@ +""" +tests/unit/utils/test_llm_serialization.py + +Tests for bluebox.utils.llm_serialization: + - LLMExclude marker + - _excluded_fields (cached introspection) + - strip_llm_excluded (recursive stripping) + - SerializedContentType / SerializedToolResult + - serialize_tool_result +""" + +import json +from datetime import datetime +from typing import Annotated, Optional + +from pydantic import BaseModel, Field, computed_field + +from bluebox.utils.llm_serialization import ( + LLMExclude, + SerializedContentType, + SerializedToolResult, + _excluded_fields, + serialize_tool_result, + strip_llm_excluded, +) + + +# --------------------------------------------------------------------------- +# Test models — defined at module level so _excluded_fields cache works +# --------------------------------------------------------------------------- + + +class SimpleModel(BaseModel): + visible: str + hidden: Annotated[str, LLMExclude()] + + +class AllExcluded(BaseModel): + a: Annotated[str, LLMExclude()] + b: Annotated[int, LLMExclude()] + + +class NoExclusions(BaseModel): + x: str + y: int + + +class NestedOuter(BaseModel): + name: str + inner: SimpleModel + secret: Annotated[str, LLMExclude()] + + +class WithOptional(BaseModel): + required: str + maybe: Annotated[Optional[str], LLMExclude()] = None + + +class WithDefault(BaseModel): + keep: str + drop: Annotated[str, LLMExclude()] = "default_val" + + +class WithFieldAndExclude(BaseModel): + """LLMExclude combined with Pydantic Field metadata.""" + name: str + internal_id: Annotated[int, Field(description="DB primary key"), LLMExclude()] + + +class Parent(BaseModel): + children: list[SimpleModel] + tag: Annotated[str, LLMExclude()] + + +class TupleContainer(BaseModel): + items: tuple[SimpleModel, ...] + removed: Annotated[str, LLMExclude()] + + +class WithComputedField(BaseModel): + first: str + last: str + secret: Annotated[str, LLMExclude()] + + @computed_field + @property + def full_name(self) -> str: + return f"{self.first} {self.last}" + + +class ComputedOnly(BaseModel): + """Model where only computed fields exist alongside excluded regular fields.""" + raw: Annotated[str, LLMExclude()] + + @computed_field + @property + def derived(self) -> str: + return self.raw.upper() + + +class NestedWithComputed(BaseModel): + label: str + item: WithComputedField + + +# ============================================================================= +# LLMExclude marker +# ============================================================================= + + +class TestLLMExclude: + """Basic tests for the LLMExclude marker class.""" + + def test_instantiable(self) -> None: + marker = LLMExclude() + assert isinstance(marker, LLMExclude) + + def test_stored_in_pydantic_metadata(self) -> None: + info = SimpleModel.model_fields["hidden"] + assert any(isinstance(m, LLMExclude) for m in info.metadata) + + def test_not_on_regular_field(self) -> None: + info = SimpleModel.model_fields["visible"] + assert not any(isinstance(m, LLMExclude) for m in info.metadata) + + +# ============================================================================= +# _excluded_fields +# ============================================================================= + + +class TestExcludedFields: + """Tests for the cached _excluded_fields helper.""" + + def test_returns_marked_fields(self) -> None: + assert _excluded_fields(SimpleModel) == frozenset({"hidden"}) + + def test_all_excluded(self) -> None: + assert _excluded_fields(AllExcluded) == frozenset({"a", "b"}) + + def test_none_excluded(self) -> None: + assert _excluded_fields(NoExclusions) == frozenset() + + def test_nested_model_own_exclusions(self) -> None: + # NestedOuter's own exclusions, not inner model's + assert _excluded_fields(NestedOuter) == frozenset({"secret"}) + + def test_with_optional_field(self) -> None: + assert _excluded_fields(WithOptional) == frozenset({"maybe"}) + + def test_with_default_field(self) -> None: + assert _excluded_fields(WithDefault) == frozenset({"drop"}) + + def test_combined_with_pydantic_field(self) -> None: + assert _excluded_fields(WithFieldAndExclude) == frozenset({"internal_id"}) + + def test_result_is_cached(self) -> None: + result1 = _excluded_fields(SimpleModel) + result2 = _excluded_fields(SimpleModel) + assert result1 is result2 # same object from cache + + +# ============================================================================= +# strip_llm_excluded — BaseModel inputs +# ============================================================================= + + +class TestStripLLMExcludedModels: + """Tests for strip_llm_excluded with BaseModel instances.""" + + def test_simple_model(self) -> None: + obj = SimpleModel(visible="yes", hidden="no") + result = strip_llm_excluded(obj) + assert result == {"visible": "yes"} + assert "hidden" not in result + + def test_all_fields_excluded_returns_empty_dict(self) -> None: + obj = AllExcluded(a="x", b=42) + assert strip_llm_excluded(obj) == {} + + def test_no_exclusions_returns_all_fields(self) -> None: + obj = NoExclusions(x="hello", y=99) + result = strip_llm_excluded(obj) + assert result == {"x": "hello", "y": 99} + + def test_nested_model_both_levels_stripped(self) -> None: + inner = SimpleModel(visible="kept", hidden="dropped") + outer = NestedOuter(name="test", inner=inner, secret="shh") + result = strip_llm_excluded(outer) + assert result == { + "name": "test", + "inner": {"visible": "kept"}, + } + assert "secret" not in result + assert "hidden" not in result["inner"] + + def test_optional_none_excluded(self) -> None: + obj = WithOptional(required="yes") + result = strip_llm_excluded(obj) + assert result == {"required": "yes"} + + def test_optional_with_value_excluded(self) -> None: + obj = WithOptional(required="yes", maybe="should vanish") + result = strip_llm_excluded(obj) + assert result == {"required": "yes"} + + def test_default_value_excluded(self) -> None: + obj = WithDefault(keep="visible") + result = strip_llm_excluded(obj) + assert result == {"keep": "visible"} + + def test_field_with_pydantic_field_and_exclude(self) -> None: + obj = WithFieldAndExclude(name="widget", internal_id=12345) + result = strip_llm_excluded(obj) + assert result == {"name": "widget"} + + def test_list_of_models_in_parent(self) -> None: + children = [ + SimpleModel(visible="a", hidden="x"), + SimpleModel(visible="b", hidden="y"), + ] + obj = Parent(children=children, tag="remove_me") + result = strip_llm_excluded(obj) + assert result == { + "children": [{"visible": "a"}, {"visible": "b"}], + } + + def test_tuple_of_models_in_parent(self) -> None: + items = ( + SimpleModel(visible="a", hidden="x"), + SimpleModel(visible="b", hidden="y"), + ) + obj = TupleContainer(items=items, removed="gone") + result = strip_llm_excluded(obj) + assert result == { + "items": ({"visible": "a"}, {"visible": "b"}), + } + + def test_returns_dict_not_model(self) -> None: + """strip_llm_excluded always converts BaseModel to dict.""" + obj = NoExclusions(x="a", y=1) + result = strip_llm_excluded(obj) + assert isinstance(result, dict) + assert not isinstance(result, BaseModel) + + def test_computed_field_preserved(self) -> None: + """@computed_field values are included in stripped output.""" + obj = WithComputedField(first="Jane", last="Doe", secret="shh") + result = strip_llm_excluded(obj) + assert result == {"first": "Jane", "last": "Doe", "full_name": "Jane Doe"} + assert "secret" not in result + + def test_computed_field_with_all_regular_excluded(self) -> None: + """Computed fields survive even when all regular fields are excluded.""" + obj = ComputedOnly(raw="hello") + result = strip_llm_excluded(obj) + assert result == {"derived": "HELLO"} + assert "raw" not in result + + def test_nested_model_with_computed_field(self) -> None: + """Nested model's computed fields are preserved recursively.""" + inner = WithComputedField(first="Jane", last="Doe", secret="x") + obj = NestedWithComputed(label="test", item=inner) + result = strip_llm_excluded(obj) + assert result == { + "label": "test", + "item": {"first": "Jane", "last": "Doe", "full_name": "Jane Doe"}, + } + + +# ============================================================================= +# strip_llm_excluded — dict / list / tuple / primitive inputs +# ============================================================================= + + +class TestStripLLMExcludedContainers: + """Tests for strip_llm_excluded with non-model containers.""" + + def test_plain_dict_passthrough(self) -> None: + d = {"a": 1, "b": "two"} + assert strip_llm_excluded(d) == d + + def test_dict_with_model_value(self) -> None: + obj = SimpleModel(visible="yes", hidden="no") + result = strip_llm_excluded({"data": obj, "count": 1}) + assert result == {"data": {"visible": "yes"}, "count": 1} + + def test_dict_nested_dicts_with_models(self) -> None: + obj = SimpleModel(visible="v", hidden="h") + result = strip_llm_excluded({"level1": {"level2": obj}}) + assert result == {"level1": {"level2": {"visible": "v"}}} + + def test_list_of_models(self) -> None: + models = [ + SimpleModel(visible="a", hidden="1"), + SimpleModel(visible="b", hidden="2"), + ] + result = strip_llm_excluded(models) + assert result == [{"visible": "a"}, {"visible": "b"}] + + def test_list_of_primitives_passthrough(self) -> None: + lst = [1, "two", 3.0, None, True] + assert strip_llm_excluded(lst) == lst + + def test_tuple_of_models(self) -> None: + models = ( + SimpleModel(visible="x", hidden="y"), + ) + result = strip_llm_excluded(models) + assert result == ({"visible": "x"},) + assert isinstance(result, tuple) + + def test_tuple_of_primitives_passthrough(self) -> None: + t = (1, 2, 3) + result = strip_llm_excluded(t) + assert result == (1, 2, 3) + assert isinstance(result, tuple) + + def test_empty_dict(self) -> None: + assert strip_llm_excluded({}) == {} + + def test_empty_list(self) -> None: + assert strip_llm_excluded([]) == [] + + def test_empty_tuple(self) -> None: + assert strip_llm_excluded(()) == () + + def test_mixed_list(self) -> None: + """List with models, dicts, and primitives.""" + obj = SimpleModel(visible="v", hidden="h") + result = strip_llm_excluded([obj, {"plain": "dict"}, 42, None]) + assert result == [{"visible": "v"}, {"plain": "dict"}, 42, None] + + +# ============================================================================= +# strip_llm_excluded — primitives +# ============================================================================= + + +class TestStripLLMExcludedPrimitives: + """Primitives pass through unchanged.""" + + def test_string(self) -> None: + assert strip_llm_excluded("hello") == "hello" + + def test_int(self) -> None: + assert strip_llm_excluded(42) == 42 + + def test_float(self) -> None: + assert strip_llm_excluded(3.14) == 3.14 + + def test_bool(self) -> None: + assert strip_llm_excluded(True) is True + + def test_none(self) -> None: + assert strip_llm_excluded(None) is None + + +# ============================================================================= +# SerializedContentType +# ============================================================================= + + +class TestSerializedContentType: + """Tests for the SerializedContentType enum.""" + + def test_json_value(self) -> None: + assert SerializedContentType.JSON == "json" + + def test_text_value(self) -> None: + assert SerializedContentType.TEXT == "text" + + def test_is_str_enum(self) -> None: + assert isinstance(SerializedContentType.JSON, str) + + +# ============================================================================= +# SerializedToolResult +# ============================================================================= + + +class TestSerializedToolResult: + """Tests for the SerializedToolResult named tuple.""" + + def test_tuple_unpacking(self) -> None: + r = SerializedToolResult(serialized="{}", content_type=SerializedContentType.JSON) + s, ct = r + assert s == "{}" + assert ct == SerializedContentType.JSON + + def test_named_access(self) -> None: + r = SerializedToolResult(serialized="hi", content_type=SerializedContentType.TEXT) + assert r.serialized == "hi" + assert r.content_type == SerializedContentType.TEXT + + def test_equality_with_plain_tuple(self) -> None: + r = SerializedToolResult(serialized="x", content_type=SerializedContentType.JSON) + assert r == ("x", "json") # NamedTuple is tuple-compatible + + +# ============================================================================= +# serialize_tool_result +# ============================================================================= + + +class TestSerializeToolResult: + """Tests for serialize_tool_result.""" + + def test_dict_returns_json(self) -> None: + result = serialize_tool_result({"key": "value", "count": 42}) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == {"key": "value", "count": 42} + + def test_list_returns_json(self) -> None: + result = serialize_tool_result([1, 2, 3]) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == [1, 2, 3] + + def test_string_returns_json(self) -> None: + result = serialize_tool_result("hello") + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == "hello" + + def test_nested_dict_returns_json(self) -> None: + data = {"a": {"b": [1, 2]}, "c": None} + result = serialize_tool_result(data) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == data + + def test_none_returns_json(self) -> None: + result = serialize_tool_result(None) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) is None + + def test_non_ascii_preserved(self) -> None: + data = {"emoji": "\U0001f525", "text": "caf\u00e9"} + result = serialize_tool_result(data) + assert result.content_type == SerializedContentType.JSON + assert "\U0001f525" in result.serialized + assert "caf\u00e9" in result.serialized + + def test_datetime_uses_default_str(self) -> None: + dt = datetime(2025, 1, 15, 12, 30, 0) + result = serialize_tool_result({"timestamp": dt}) + assert result.content_type == SerializedContentType.JSON + parsed = json.loads(result.serialized) + assert "2025-01-15" in parsed["timestamp"] + + def test_custom_object_handled_by_default_str(self) -> None: + class Widget: + def __str__(self) -> str: + return "Widget()" + + result = serialize_tool_result({"obj": Widget()}) + assert result.content_type == SerializedContentType.JSON + assert "Widget()" in result.serialized + + def test_integer_returns_json(self) -> None: + result = serialize_tool_result(42) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == 42 + + def test_boolean_returns_json(self) -> None: + result = serialize_tool_result(True) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) is True + + def test_empty_dict_returns_json(self) -> None: + result = serialize_tool_result({}) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == {} + + def test_result_is_named_tuple(self) -> None: + result = serialize_tool_result({"a": 1}) + assert isinstance(result, SerializedToolResult) + assert isinstance(result, tuple) + + def test_tuple_unpacking_still_works(self) -> None: + """Backwards compat: callers that do `s, ct = serialize_tool_result(...)` still work.""" + serialized, content_type = serialize_tool_result({"x": 1}) + assert isinstance(serialized, str) + assert content_type == "json" + + def test_output_is_indented(self) -> None: + result = serialize_tool_result({"a": 1}) + assert "\n" in result.serialized # indent=2 produces newlines + + def test_large_payload(self) -> None: + data = {f"key_{i}": f"value_{i}" for i in range(500)} + result = serialize_tool_result(data) + assert result.content_type == SerializedContentType.JSON + assert json.loads(result.serialized) == data