Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion burr/core/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,7 +1309,7 @@ def pydantic(
writes: List[str],
state_input_type: Type["BaseModel"],
state_output_type: Type["BaseModel"],
stream_type: Union[Type["BaseModel"], Type[dict]],
stream_type: Union[Type["BaseModel"], Type[dict], type, "types.UnionType"],
tags: Optional[List[str]] = None,
) -> Callable:
"""Creates a streaming action that uses pydantic models.
Expand Down
8 changes: 3 additions & 5 deletions burr/integrations/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ async def async_action_function(state: State, **kwargs) -> State:
return decorator


PartialType = Union[Type[pydantic.BaseModel], Type[dict]]
PartialType = Union[Type[pydantic.BaseModel], Type[dict], type, "types.UnionType"]

PydanticStreamingActionFunctionSync = Callable[
..., Generator[Tuple[Union[pydantic.BaseModel, dict], Optional[pydantic.BaseModel]], None, None]
Expand All @@ -285,12 +285,10 @@ async def async_action_function(state: State, **kwargs) -> State:

def _validate_and_extract_signature_types_streaming(
fn: PydanticStreamingActionFunction,
stream_type: Optional[Union[Type[pydantic.BaseModel], Type[dict]]],
stream_type: Optional[PartialType],
state_input_type: Optional[Type[pydantic.BaseModel]] = None,
state_output_type: Optional[Type[pydantic.BaseModel]] = None,
) -> Tuple[
Type[pydantic.BaseModel], Type[pydantic.BaseModel], Union[Type[dict], Type[pydantic.BaseModel]]
]:
) -> Tuple[Type[pydantic.BaseModel], Type[pydantic.BaseModel], PartialType]:
if stream_type is None:
# TODO -- derive from the signature
raise ValueError(f"stream_type is required for function: {fn.__qualname__}")
Expand Down
51 changes: 50 additions & 1 deletion tests/integrations/test_burr_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import asyncio
from typing import AsyncGenerator, Generator, List, Optional, Tuple
from typing import AsyncGenerator, Generator, List, Optional, Tuple, Union

import pydantic
import pytest
Expand Down Expand Up @@ -758,3 +758,52 @@ async def final_result_streamed(
assert state.data.count == 20
assert isinstance(result, IntermediateModel)
assert result.result == 20


class StreamTypeModel1(BaseModel):
value: int


class StreamTypeModel2(BaseModel):
message: str


def test_streaming_pydantic_action_union_stream_type():
"""Test that stream_type accepts union types like Union[Model1, Model2]."""
union_type = Union[StreamTypeModel1, StreamTypeModel2]

@pydantic_streaming_action(
reads=["count", "times_called"],
writes=["count", "times_called"],
stream_type=union_type,
state_input_type=AppStateModel,
state_output_type=AppStateModel,
)
def act(
state: AppStateModel, total_count: int
) -> Generator[
Tuple[Union[StreamTypeModel1, StreamTypeModel2], Optional[AppStateModel]], None, None
]:
for i in range(total_count):
if i % 2 == 0:
yield StreamTypeModel1(value=i), None
else:
yield StreamTypeModel2(message=f"step_{i}"), None
state.count = i
state.times_called += 1
yield StreamTypeModel1(value=state.count), state

assert hasattr(act, "bind")
action_function = getattr(act, FunctionBasedAction.ACTION_FUNCTION, None)
assert action_function is not None
gen = action_function.fn(
State(dict(count=0, times_called=0), typing_system=PydanticTypingSystem(AppStateModel)),
total_count=4,
)
result = list(gen)
assert len(result) == 5
assert isinstance(result[0][0], StreamTypeModel1)
assert isinstance(result[1][0], StreamTypeModel2)
assert isinstance(result[2][0], StreamTypeModel1)
assert isinstance(result[3][0], StreamTypeModel2)
assert isinstance(result[-1][1], State)