From 14d7d105d47436e9fabffe9d79703af98074239b Mon Sep 17 00:00:00 2001 From: majiayu000 <1835304752@qq.com> Date: Tue, 30 Dec 2025 14:23:46 +0800 Subject: [PATCH] fix(pydantic): support union type hints for stream_type parameter Updated type hints to accept Union types for stream_type parameter in streaming_action.pydantic decorator. This allows users to specify union types like Union[Model1, Model2] for stream_type. Closes #607 Signed-off-by: majiayu000 <1835304752@qq.com> --- burr/core/action.py | 2 +- burr/integrations/pydantic.py | 8 ++-- tests/integrations/test_burr_pydantic.py | 51 +++++++++++++++++++++++- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/burr/core/action.py b/burr/core/action.py index 69a7c75b6..8b42c8817 100644 --- a/burr/core/action.py +++ b/burr/core/action.py @@ -1309,7 +1309,7 @@ def pydantic( writes: List[str], state_input_type: Type["BaseModel"], state_output_type: Type["BaseModel"], - stream_type: Union[Type["BaseModel"], Type[dict]], + stream_type: Union[Type["BaseModel"], Type[dict], type, "types.UnionType"], tags: Optional[List[str]] = None, ) -> Callable: """Creates a streaming action that uses pydantic models. diff --git a/burr/integrations/pydantic.py b/burr/integrations/pydantic.py index 5354537a2..bca543c40 100644 --- a/burr/integrations/pydantic.py +++ b/burr/integrations/pydantic.py @@ -264,7 +264,7 @@ async def async_action_function(state: State, **kwargs) -> State: return decorator -PartialType = Union[Type[pydantic.BaseModel], Type[dict]] +PartialType = Union[Type[pydantic.BaseModel], Type[dict], type, "types.UnionType"] PydanticStreamingActionFunctionSync = Callable[ ..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None] @@ -285,12 +285,10 @@ async def async_action_function(state: State, **kwargs) -> State: def _validate_and_extract_signature_types_streaming( fn: PydanticStreamingActionFunction, - stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]], + stream_type: Optional[PartialType], state_input_type: Optional[Type[pydantic.BaseModel]] = None, state_output_type: Optional[Type[pydantic.BaseModel]] = None, -) -> Tuple[ - Type[pydantic.BaseModel], Type[pydantic.BaseModel], Union[Type[dict], Type[pydantic.BaseModel]] -]: +) -> Tuple[Type[pydantic.BaseModel], Type[pydantic.BaseModel], PartialType]: if stream_type is None: # TODO -- derive from the signature raise ValueError(f"stream_type is required for function: {fn.__qualname__}") diff --git a/tests/integrations/test_burr_pydantic.py b/tests/integrations/test_burr_pydantic.py index 6c25520f6..16dc1a3f6 100644 --- a/tests/integrations/test_burr_pydantic.py +++ b/tests/integrations/test_burr_pydantic.py @@ -16,7 +16,7 @@ # under the License. import asyncio -from typing import AsyncGenerator, Generator, List, Optional, Tuple +from typing import AsyncGenerator, Generator, List, Optional, Tuple, Union import pydantic import pytest @@ -758,3 +758,52 @@ async def final_result_streamed( assert state.data.count == 20 assert isinstance(result, IntermediateModel) assert result.result == 20 + + +class StreamTypeModel1(BaseModel): + value: int + + +class StreamTypeModel2(BaseModel): + message: str + + +def test_streaming_pydantic_action_union_stream_type(): + """Test that stream_type accepts union types like Union[Model1, Model2].""" + union_type = Union[StreamTypeModel1, StreamTypeModel2] + + @pydantic_streaming_action( + reads=["count", "times_called"], + writes=["count", "times_called"], + stream_type=union_type, + state_input_type=AppStateModel, + state_output_type=AppStateModel, + ) + def act( + state: AppStateModel, total_count: int + ) -> Generator[ + Tuple[Union[StreamTypeModel1, StreamTypeModel2], Optional[AppStateModel]], None, None + ]: + for i in range(total_count): + if i % 2 == 0: + yield StreamTypeModel1(value=i), None + else: + yield StreamTypeModel2(message=f"step_{i}"), None + state.count = i + state.times_called += 1 + yield StreamTypeModel1(value=state.count), state + + assert hasattr(act, "bind") + action_function = getattr(act, FunctionBasedAction.ACTION_FUNCTION, None) + assert action_function is not None + gen = action_function.fn( + State(dict(count=0, times_called=0), typing_system=PydanticTypingSystem(AppStateModel)), + total_count=4, + ) + result = list(gen) + assert len(result) == 5 + assert isinstance(result[0][0], StreamTypeModel1) + assert isinstance(result[1][0], StreamTypeModel2) + assert isinstance(result[2][0], StreamTypeModel1) + assert isinstance(result[3][0], StreamTypeModel2) + assert isinstance(result[-1][1], State)