diff --git a/chatkit/agents.py b/chatkit/agents.py index 5ace8d2..f1f79dd 100644 --- a/chatkit/agents.py +++ b/chatkit/agents.py @@ -25,6 +25,7 @@ EasyInputMessageParam, ResponseFunctionToolCallParam, ResponseInputContentParam, + ResponseInputImageParam, ResponseInputMessageContentListParam, ResponseInputTextParam, ResponseOutputText, @@ -55,6 +56,9 @@ DurationSummary, EndOfTurnItem, FileSource, + GeneratedImage, + GeneratedImageItem, + GeneratedImageUpdated, HiddenContextItem, SDKHiddenContextItem, Task, @@ -105,6 +109,7 @@ class AgentContext(BaseModel, Generic[TContext]): previous_response_id: str | None = None client_tool_call: ClientToolCall | None = None workflow_item: WorkflowItem | None = None + generated_image_item: GeneratedImageItem | None = None _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue() def generate_id( @@ -356,10 +361,97 @@ class StreamingThoughtTracker(BaseModel): task: ThoughtTask +class ResponseStreamConverter: + """Used by `stream_agent_response` to convert streamed Agents SDK output + into values used by ChatKit thread items and thread stream events. + + Defines overridable methods for adapting streamed data (such as image + generation results and partial updates) into the forms expected by ChatKit. + """ + + partial_images: int | None = None + """ + The expected number of partial image updates for an image generation result. + + When set, this value is used to normalize partial image indices into a + progress value in the range [0, 1]. If unset, all partial image updates are + assigned a progress value of 0. + """ + + def __init__(self, *, partial_images: int | None = None): + """ + Args: + partial_images: The expected number of partial image updates for image + generation results, or None if no progress normalization should + be performed. + """ + self.partial_images = partial_images + + async def base64_image_to_url( + self, + image_id: str, + base64_image: str, + partial_image_index: int | None = None, + ) -> str: + """ + Convert a base64-encoded image into a URL. + + This method is used to produce the URL stored on thread items for image + generation results. + + Args: + image_id: The ID of the image generation call. This stays stable across partial image updates. + base64_image: The base64-encoded image. + partial_image_index: The index of the partial image update, starting from 0. + + Returns: + A URL string. + """ + return f"data:image/png;base64,{base64_image}" + + def partial_image_index_to_progress(self, partial_image_index: int) -> float: + """ + Convert a partial image index into a normalized progress value. + + Args: + partial_image_index: The index of the partial image update, starting from 0. + + Returns: + A float between 0 and 1 representing progress for the image + generation result. + """ + if self.partial_images is None or self.partial_images <= 0: + return 0.0 + + return min(1.0, partial_image_index / self.partial_images) + + +_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter() + + async def stream_agent_response( - context: AgentContext, result: RunResultStreaming + context: AgentContext, + result: RunResultStreaming, + *, + converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER, ) -> AsyncIterator[ThreadStreamEvent]: - """Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents.""" + """ + Convert a streamed Agents SDK run into ChatKit thread stream events. + + This function consumes a streaming run result and yields `ThreadStreamEvent` + objects as the run progresses. + + Args: + context: The AgentContext to use for the stream. + result: The RunResultStreaming to convert. + image_generation_stream_converter: Controls how streamed image generation output + is converted into URLs and progress updates. The default converter stores the + generated base64 image and assigns a progress value of 0 to all partial image + updates. + + Returns: + An async iterator that yields thread stream events representing the run result. + """ current_item_id = None current_tool_call = None ctx = context @@ -527,6 +619,38 @@ def end_workflow(item: WorkflowItem): created_at=datetime.now(), ), ) + elif item.type == "image_generation_call": + ctx.generated_image_item = GeneratedImageItem( + id=ctx.generate_id("message"), + thread_id=thread.id, + created_at=datetime.now(), + image=None, + ) + produced_items.add(ctx.generated_image_item.id) + yield ThreadItemAddedEvent(item=ctx.generated_image_item) + elif event.type == "response.image_generation_call.partial_image": + if not ctx.generated_image_item: + continue + + url = await converter.base64_image_to_url( + image_id=event.item_id, + base64_image=event.partial_image_b64, + partial_image_index=event.partial_image_index, + ) + progress = converter.partial_image_index_to_progress( + event.partial_image_index + ) + + ctx.generated_image_item.image = GeneratedImage( + id=event.item_id, url=url + ) + + yield ThreadItemUpdatedEvent( + item_id=ctx.generated_image_item.id, + update=GeneratedImageUpdated( + image=ctx.generated_image_item.image, progress=progress + ), + ) elif event.type == "response.reasoning_summary_text.delta": if not ctx.workflow_item: continue @@ -604,6 +728,20 @@ def end_workflow(item: WorkflowItem): created_at=datetime.now(), ), ) + elif item.type == "image_generation_call" and item.result: + if not ctx.generated_image_item: + continue + + url = await converter.base64_image_to_url( + image_id=item.id, + base64_image=item.result, + ) + image = GeneratedImage(id=item.id, url=url) + + ctx.generated_image_item.image = image + yield ThreadItemDoneEvent(item=ctx.generated_image_item) + + ctx.generated_image_item = None except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered): for item_id in produced_items: @@ -694,6 +832,33 @@ async def tag_to_message_content( "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented" ) + async def generated_image_to_input( + self, item: GeneratedImageItem + ) -> TResponseInputItem | list[TResponseInputItem] | None: + """ + Convert a GeneratedImageItem into input item(s) to send to the model. + Override this method to customize the conversion of generated images, such as when your + generated image url is not publicly reachable. + """ + if not item.image: + return None + + return Message( + type="message", + content=[ + ResponseInputTextParam( + type="input_text", + text="The following image was generated by the agent.", + ), + ResponseInputImageParam( + type="input_image", + detail="auto", + image_url=item.image.url, + ), + ], + role="user", + ) + async def hidden_context_to_input( self, item: HiddenContextItem ) -> TResponseInputItem | list[TResponseInputItem] | None: @@ -984,6 +1149,9 @@ async def _thread_item_to_input_item( case SDKHiddenContextItem(): out = await self.sdk_hidden_context_to_input(item) or [] return out if isinstance(out, list) else [out] + case GeneratedImageItem(): + out = await self.generated_image_to_input(item) or [] + return out if isinstance(out, list) else [out] case _: assert_never(item) diff --git a/chatkit/types.py b/chatkit/types.py index c84c3db..5a18ec9 100644 --- a/chatkit/types.py +++ b/chatkit/types.py @@ -471,6 +471,14 @@ class WorkflowTaskUpdated(BaseModel): task: Task +class GeneratedImageUpdated(BaseModel): + """Event emitted when a generated image is updated.""" + + type: Literal["generated_image.updated"] = "generated_image.updated" + image: GeneratedImage + progress: float | None = None + + ThreadItemUpdate = ( AssistantMessageContentPartAdded | AssistantMessageContentPartTextDelta @@ -481,6 +489,7 @@ class WorkflowTaskUpdated(BaseModel): | WidgetRootUpdated | WorkflowTaskAdded | WorkflowTaskUpdated + | GeneratedImageUpdated ) """Union of possible updates applied to thread items.""" @@ -579,6 +588,20 @@ class WidgetItem(ThreadItemBase): copy_text: str | None = None +class GeneratedImage(BaseModel): + """Generated image.""" + + id: str + url: str + + +class GeneratedImageItem(ThreadItemBase): + """Thread item containing a generated image.""" + + type: Literal["generated_image"] = "generated_image" + image: GeneratedImage | None = None + + class TaskItem(ThreadItemBase): """Thread item containing a task.""" @@ -624,6 +647,7 @@ class SDKHiddenContextItem(ThreadItemBase): | AssistantMessageItem | ClientToolCallItem | WidgetItem + | GeneratedImageItem | WorkflowItem | TaskItem | HiddenContextItem diff --git a/tests/test_agents.py b/tests/test_agents.py index 045b90e..7cdcfa9 100644 --- a/tests/test_agents.py +++ b/tests/test_agents.py @@ -59,6 +59,7 @@ from chatkit.agents import ( AgentContext, + ResponseStreamConverter, ThreadItemConverter, accumulate_text, simple_to_agent_input, @@ -78,6 +79,9 @@ CustomTask, DurationSummary, FileSource, + GeneratedImage, + GeneratedImageItem, + GeneratedImageUpdated, HiddenContextItem, InferenceOptions, Page, @@ -544,6 +548,45 @@ async def test_input_item_converter_user_input_with_tags_throws_by_default(): await simple_to_agent_input(items) +async def test_input_item_converter_generated_image_item(): + items = [ + GeneratedImageItem( + id="img_item_1", + thread_id=thread.id, + created_at=datetime.now(), + image=GeneratedImage(id="img_1", url="https://example.com/img.png"), + ) + ] + + input_items = await simple_to_agent_input(items) + assert len(input_items) == 1 + + message = cast(dict, input_items[0]) + assert message.get("type") == "message" + assert message.get("role") == "user" + + content = cast(list, message.get("content")) + assert content[0].get("type") == "input_text" + assert content[0].get("text") == "The following image was generated by the agent." + assert content[1].get("type") == "input_image" + assert content[1].get("file_id") is None + assert content[1].get("image_url") == "https://example.com/img.png" + assert content[1].get("detail") == "auto" + + +async def test_input_item_converter_generated_image_item_without_image(): + items = [ + GeneratedImageItem( + id="img_item_1", + thread_id=thread.id, + created_at=datetime.now(), + ) + ] + + input_items = await simple_to_agent_input(items) + assert input_items == [] + + async def test_input_item_converter_for_hidden_context_with_string_content(): items = [ HiddenContextItem( @@ -1191,6 +1234,192 @@ async def test_stream_agent_response_assistant_message_content_types(): assert message.id == "1" +async def test_stream_agent_response_image_generation_events(): + context = AgentContext( + previous_response_id=None, thread=thread, store=mock_store, request_context=None + ) + result = make_result() + + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_item.added", + item=Mock(type="image_generation_call", id="img_call_1"), + output_index=0, + sequence_number=0, + ), + ) + ) + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_item.done", + item=Mock( + type="image_generation_call", id="img_call_1", result="dGVzdA==" + ), + output_index=0, + sequence_number=1, + ), + ) + ) + result.done() + + stream = stream_agent_response(context, result) + event1 = await stream.__anext__() + assert isinstance(event1, ThreadItemAddedEvent) + assert isinstance(event1.item, GeneratedImageItem) + assert event1.item.type == "generated_image" + assert event1.item.id == "message_id" + assert event1.item.image is None + + event2 = await stream.__anext__() + assert isinstance(event2, ThreadItemDoneEvent) + assert isinstance(event2.item, GeneratedImageItem) + assert event2.item.id == event1.item.id + assert event2.item.image == GeneratedImage( + id="img_call_1", url="data:image/png;base64,dGVzdA==" + ) + + with pytest.raises(StopAsyncIteration): + await stream.__anext__() + + +async def test_stream_agent_response_image_generation_events_with_custom_converter(): + context = AgentContext( + previous_response_id=None, thread=thread, store=mock_store, request_context=None + ) + result = make_result() + + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_item.added", + item=Mock(type="image_generation_call", id="img_call_1"), + output_index=0, + sequence_number=0, + ), + ) + ) + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_item.done", + item=Mock( + type="image_generation_call", id="img_call_1", result="dGVzdA==" + ), + output_index=0, + sequence_number=1, + ), + ) + ) + result.done() + + class CustomResponseStreamConverter(ResponseStreamConverter): + def __init__(self): + super().__init__() + self.calls: list[tuple[str, str, int | None]] = [] + + async def base64_image_to_url( + self, + image_id: str, + base64_image: str, + partial_image_index: int | None = None, + ) -> str: + self.calls.append((image_id, base64_image, partial_image_index)) + return f"https://example.com/{image_id}" + + converter = CustomResponseStreamConverter() + stream = stream_agent_response(context, result, converter=converter) + event1 = await stream.__anext__() + assert isinstance(event1, ThreadItemAddedEvent) + assert isinstance(event1.item, GeneratedImageItem) + assert event1.item.image is None + + event2 = await stream.__anext__() + assert isinstance(event2, ThreadItemDoneEvent) + assert isinstance(event2.item, GeneratedImageItem) + assert converter.calls == [("img_call_1", "dGVzdA==", None)] + assert event2.item.image == GeneratedImage( + id="img_call_1", url="https://example.com/img_call_1" + ) + with pytest.raises(StopAsyncIteration): + await stream.__anext__() + + +async def test_stream_agent_response_image_generation_partial_progress(): + context = AgentContext( + previous_response_id=None, thread=thread, store=mock_store, request_context=None + ) + result = make_result() + + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_item.added", + item=Mock(type="image_generation_call", id="img_call_1"), + output_index=0, + sequence_number=0, + ), + ) + ) + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.image_generation_call.partial_image", + partial_image_b64="dGVzdA==", + partial_image_index=1, + item_id="img_call_1", + output_index=0, + sequence_number=1, + ), + ) + ) + result.add_event( + RawResponsesStreamEvent( + type="raw_response_event", + data=Mock( + type="response.output_item.done", + item=Mock( + type="image_generation_call", id="img_call_1", result="dGVzdA==" + ), + output_index=0, + sequence_number=2, + ), + ) + ) + result.done() + + converter = ResponseStreamConverter(partial_images=3) + events = await all_events( + stream_agent_response(context, result, converter=converter) + ) + + assert len(events) == 3 + added_event, partial_event, done_event = events + + assert isinstance(added_event, ThreadItemAddedEvent) + assert isinstance(added_event.item, GeneratedImageItem) + + assert isinstance(partial_event, ThreadItemUpdatedEvent) + assert isinstance(partial_event.update, GeneratedImageUpdated) + assert partial_event.update.progress == pytest.approx(1 / 3) + assert partial_event.update.image == GeneratedImage( + id="img_call_1", url="data:image/png;base64,dGVzdA==" + ) + + assert isinstance(done_event, ThreadItemDoneEvent) + assert isinstance(done_event.item, GeneratedImageItem) + assert done_event.item.image == GeneratedImage( + id="img_call_1", url="data:image/png;base64,dGVzdA==" + ) + + async def test_workflow_streams_first_thought(): context = AgentContext( previous_response_id=None, thread=thread, store=mock_store, request_context=None