diff --git a/bluebox/utils/llm_serialization.py b/bluebox/utils/llm_serialization.py index b550ce50..7b3b59ff 100644 --- a/bluebox/utils/llm_serialization.py +++ b/bluebox/utils/llm_serialization.py @@ -119,7 +119,7 @@ def _excluded_fields(model_cls: type[BaseModel]) -> frozenset[str]: return frozenset( name for name, info in model_cls.model_fields.items() - if any(isinstance(m, LLMExclude) for m in info.metadata) + if any(m is LLMExclude or isinstance(m, LLMExclude) for m in info.metadata) ) diff --git a/tests/unit/utils/test_llm_serialization.py b/tests/unit/utils/test_llm_serialization.py index 1ce4b08d..ec8d818c 100644 --- a/tests/unit/utils/test_llm_serialization.py +++ b/tests/unit/utils/test_llm_serialization.py @@ -103,6 +103,33 @@ class NestedWithComputed(BaseModel): item: WithComputedField +# --- Bare-class (LLMExclude without parens) variants ------------------------- + + +class BareSimple(BaseModel): + """Uses bare LLMExclude class instead of LLMExclude().""" + visible: str + hidden: Annotated[str, LLMExclude] + + +class BareWithInstance(BaseModel): + """Mix of bare class and instance in the same model.""" + keep: str + drop_bare: Annotated[str, LLMExclude] + drop_instance: Annotated[str, LLMExclude()] + + +class BareNested(BaseModel): + name: str + inner: BareSimple + secret: Annotated[str, LLMExclude] + + +class BareWithField(BaseModel): + name: str + internal_id: Annotated[int, Field(description="DB key"), LLMExclude] + + # ============================================================================= # LLMExclude marker # ============================================================================= @@ -123,6 +150,11 @@ def test_not_on_regular_field(self) -> None: info = SimpleModel.model_fields["visible"] assert not any(isinstance(m, LLMExclude) for m in info.metadata) + def test_bare_class_stored_in_metadata(self) -> None: + """Annotated[str, LLMExclude] (no parens) stores the class itself.""" + info = BareSimple.model_fields["hidden"] + assert any(m is LLMExclude for m in info.metadata) + # ============================================================================= # _excluded_fields @@ -159,6 +191,15 @@ def test_result_is_cached(self) -> None: result2 = _excluded_fields(SimpleModel) assert result1 is result2 # same object from cache + def test_bare_class_detected(self) -> None: + assert _excluded_fields(BareSimple) == frozenset({"hidden"}) + + def test_bare_and_instance_both_detected(self) -> None: + assert _excluded_fields(BareWithInstance) == frozenset({"drop_bare", "drop_instance"}) + + def test_bare_with_pydantic_field(self) -> None: + assert _excluded_fields(BareWithField) == frozenset({"internal_id"}) + # ============================================================================= # strip_llm_excluded — BaseModel inputs @@ -267,6 +308,33 @@ def test_nested_model_with_computed_field(self) -> None: "item": {"first": "Jane", "last": "Doe", "full_name": "Jane Doe"}, } + def test_bare_class_simple_model(self) -> None: + """Annotated[str, LLMExclude] (no parens) strips correctly.""" + obj = BareSimple(visible="yes", hidden="no") + result = strip_llm_excluded(obj) + assert result == {"visible": "yes"} + assert "hidden" not in result + + def test_bare_and_instance_mixed(self) -> None: + """Both bare-class and instance annotations strip in the same model.""" + obj = BareWithInstance(keep="ok", drop_bare="gone1", drop_instance="gone2") + result = strip_llm_excluded(obj) + assert result == {"keep": "ok"} + + def test_bare_class_nested(self) -> None: + inner = BareSimple(visible="kept", hidden="dropped") + obj = BareNested(name="test", inner=inner, secret="shh") + result = strip_llm_excluded(obj) + assert result == { + "name": "test", + "inner": {"visible": "kept"}, + } + + def test_bare_class_with_pydantic_field(self) -> None: + obj = BareWithField(name="widget", internal_id=999) + result = strip_llm_excluded(obj) + assert result == {"name": "widget"} + # ============================================================================= # strip_llm_excluded — dict / list / tuple / primitive inputs