diff --git a/packages/modal-infra/src/sandbox/bridge.py b/packages/modal-infra/src/sandbox/bridge.py index 362a7a1..9fe9f03 100644 --- a/packages/modal-infra/src/sandbox/bridge.py +++ b/packages/modal-infra/src/sandbox/bridge.py @@ -534,6 +534,16 @@ async def _create_opencode_session(self) -> None: await self._save_session_id() + @staticmethod + def _extract_error_message(error: Any) -> str | None: + """Extract message from OpenCode NamedError: { "name": "...", "data": { "message": "..." } }.""" + if isinstance(error, dict): + data = error.get("data") + if isinstance(data, dict) and "message" in data: + return data["message"] + return error.get("message") or error.get("name") + return str(error) if error else None + def _transform_part_to_event( self, part: dict[str, Any], @@ -705,6 +715,8 @@ async def _stream_opencode_response_sse( 2. OpenCode creates assistant messages with parentID = our ascending ID 3. Filter events to only process parts from our assistant messages 4. Use control plane's message_id for events sent back + 5. Track child sessions (sub-tasks) and forward their non-text events + with isSubtask=True The ascending ID ensures our user message ID is lexicographically greater than any previous assistant message IDs, preventing the early exit condition @@ -728,6 +740,9 @@ async def _stream_opencode_response_sse( pending_parts_total = 0 pending_drop_logged = False + # Child session tracking (sub-tasks) + tracked_child_session_ids: set[str] = set() + start_time = time.time() loop = asyncio.get_running_loop() @@ -746,12 +761,19 @@ def buffer_part(oc_msg_id: str, part: dict[str, Any], delta: Any) -> None: pending_parts.setdefault(oc_msg_id, []).append((part, delta)) pending_parts_total += 1 - def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: + def handle_part( + part: dict[str, Any], + delta: Any, + *, + is_subtask: bool = False, + ) -> list[dict[str, Any]]: part_type = part.get("type", "") part_id = part.get("id", "") events: list[dict[str, Any]] = [] if part_type == "text": + if is_subtask: + return events # Don't forward child text tokens text = part.get("text", "") if delta: cumulative_text[part_id] = cumulative_text.get(part_id, "") + delta @@ -773,7 +795,8 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: state = part.get("state", {}) status = state.get("status", "") call_id = part.get("callID", "") - tool_key = f"tool:{call_id}:{status}" + part_sid = part.get("sessionID", "") + tool_key = f"tool:{part_sid}:{call_id}:{status}" if tool_key not in emitted_tool_states: emitted_tool_states.add(tool_key) @@ -798,6 +821,9 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: } ) + if is_subtask: + for ev in events: + ev["isSubtask"] = True return events try: @@ -837,10 +863,31 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: if event_type == "server.connected": pass elif event_type != "server.heartbeat": + # Track direct child sessions before filtering + if event_type == "session.created": + info = props.get("info", {}) + child_id = info.get("id") + child_parent = info.get("parentID") + if child_id and child_parent == self.opencode_session_id: + tracked_child_session_ids.add(child_id) + self.log.info( + "bridge.child_session_detected", + child_session_id=child_id, + source="session.created", + ) + # Always continue: no downstream handler processes session.created, + # and non-matching events would just fall through to no-op. + continue + event_session_id = props.get("sessionID") or props.get("part", {}).get( "sessionID" ) - if not event_session_id or event_session_id == self.opencode_session_id: + is_child = event_session_id in tracked_child_session_ids + if ( + not event_session_id + or event_session_id == self.opencode_session_id + or is_child + ): if event_type == "message.updated": info = props.get("info", {}) msg_session_id = info.get("sessionID") @@ -876,18 +923,59 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: finish=finish, ) + elif msg_session_id in tracked_child_session_ids: + # Child session: authorize all assistant messages + oc_msg_id = info.get("id", "") + role = info.get("role", "") + if role == "assistant" and oc_msg_id: + allowed_assistant_msg_ids.add(oc_msg_id) + pending = pending_parts.pop(oc_msg_id, []) + if pending: + pending_parts_total -= len(pending) + for part, delta in pending: + for ev in handle_part( + part, delta, is_subtask=True + ): + yield ev + elif event_type == "message.part.updated": part = props.get("part", {}) delta = props.get("delta") oc_msg_id = part.get("messageID", "") + part_session_id = part.get("sessionID", "") + + # Discover child sessions from task tool metadata (covers task_id resume) + if ( + part.get("tool") == "task" + and part_session_id == self.opencode_session_id + ): + metadata = part.get("metadata") + child_sid = ( + metadata.get("sessionId") + if isinstance(metadata, dict) + else None + ) + if child_sid and child_sid not in tracked_child_session_ids: + tracked_child_session_ids.add(child_sid) + self.log.info( + "bridge.child_session_detected", + child_session_id=child_sid, + source="task_metadata", + ) + if oc_msg_id in allowed_assistant_msg_ids: - for part_event in handle_part(part, delta): - yield part_event + if part_session_id in tracked_child_session_ids: + for ev in handle_part(part, delta, is_subtask=True): + yield ev + else: + for part_event in handle_part(part, delta): + yield part_event elif oc_msg_id: buffer_part(oc_msg_id, part, delta) elif event_type == "session.idle": idle_session_id = props.get("sessionID") + # Only parent idle terminates the stream if idle_session_id == self.opencode_session_id: elapsed = time.time() - start_time self.log.debug( @@ -907,6 +995,7 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: elif event_type == "session.status": status_session_id = props.get("sessionID") status = props.get("status", {}) + # Only parent status=idle terminates the stream if ( status_session_id == self.opencode_session_id and status.get("type") == "idle" @@ -929,18 +1018,9 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: elif event_type == "session.error": error_session_id = props.get("sessionID") if error_session_id == self.opencode_session_id: - error = props.get("error", {}) - # OpenCode NamedError structure: { "name": "...", "data": { "message": "..." } } - if isinstance(error, dict): - data = error.get("data") - if isinstance(data, dict) and "message" in data: - error_msg = data["message"] - else: - error_msg = error.get("message") or error.get( - "name" - ) - else: - error_msg = str(error) if error else None + error_msg = self._extract_error_message( + props.get("error", {}) + ) self.log.error("bridge.session_error", error_msg=error_msg) yield { "type": "error", @@ -948,6 +1028,22 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: "messageId": message_id, } return + elif error_session_id in tracked_child_session_ids: + error_msg = self._extract_error_message( + props.get("error", {}) + ) + self.log.error( + "bridge.child_session_error", + error_msg=error_msg, + child_session_id=error_session_id, + ) + yield { + "type": "error", + "error": error_msg or "Sub-task error", + "messageId": message_id, + "isSubtask": True, + } + # No return — parent stream continues if loop.time() > prompt_start + self.PROMPT_MAX_DURATION: elapsed = time.time() - start_time diff --git a/packages/modal-infra/tests/test_bridge_sse.py b/packages/modal-infra/tests/test_bridge_sse.py index 7d00250..5046439 100644 --- a/packages/modal-infra/tests/test_bridge_sse.py +++ b/packages/modal-infra/tests/test_bridge_sse.py @@ -720,6 +720,37 @@ async def test_skips_user_messages(self, bridge_with_mock_client: AgentBridge): assert events[0]["content"] == "Assistant response" +class TestExtractErrorMessage: + """Tests for _extract_error_message static method.""" + + def test_named_error_with_data_message(self): + """Should extract message from NamedError data.message.""" + error = {"name": "SomeError", "data": {"message": "Something broke"}} + assert AgentBridge._extract_error_message(error) == "Something broke" + + def test_dict_with_message_key(self): + """Should fall back to error.message when no data.message.""" + error = {"message": "Direct message"} + assert AgentBridge._extract_error_message(error) == "Direct message" + + def test_dict_with_name_key_only(self): + """Should fall back to error.name when no message key.""" + error = {"name": "TimeoutError"} + assert AgentBridge._extract_error_message(error) == "TimeoutError" + + def test_non_dict_error(self): + """Should stringify non-dict errors.""" + assert AgentBridge._extract_error_message("raw error string") == "raw error string" + + def test_none_error(self): + """Should return None for falsy error.""" + assert AgentBridge._extract_error_message(None) is None + + def test_empty_dict(self): + """Should return None for empty dict (no message or name).""" + assert AgentBridge._extract_error_message({}) is None + + class TestSSEFollowUpMessageBug: """Integration tests for the follow-up message bug fix. @@ -1161,5 +1192,676 @@ async def test_prompt_max_duration_timeout(self): assert any(url.endswith("/message") for url in http_client.get_urls) +class TestSubtaskStreaming: + """Tests for child session (sub-task) event streaming through the bridge.""" + + @pytest.mark.asyncio + async def test_child_session_tool_events_streamed( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child session tool events should be forwarded with isSubtask=True.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + # Parent message.updated to authorize parent messages + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + # Child session created + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + # Child message.updated + create_sse_event( + "message.updated", + { + "info": { + "id": "child-msg-1", + "role": "assistant", + "sessionID": "child-1", + } + }, + ), + # Child tool running + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "child-part-1", + "sessionID": "child-1", + "messageID": "child-msg-1", + "tool": "Bash", + "callID": "child-call-1", + "state": { + "status": "running", + "input": {"command": "ls"}, + "output": "", + }, + } + }, + ), + # Child tool completed + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "child-part-1", + "sessionID": "child-1", + "messageID": "child-msg-1", + "tool": "Bash", + "callID": "child-call-1", + "state": { + "status": "completed", + "input": {"command": "ls"}, + "output": "file.txt", + }, + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + tool_events = [e for e in events if e["type"] == "tool_call"] + assert len(tool_events) == 2 + assert tool_events[0]["status"] == "running" + assert tool_events[0]["isSubtask"] is True + assert tool_events[0]["messageId"] == "cp-msg-1" + assert tool_events[1]["status"] == "completed" + assert tool_events[1]["isSubtask"] is True + + @pytest.mark.asyncio + async def test_child_text_events_not_forwarded( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child session text events should NOT be forwarded.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + create_sse_event( + "message.updated", + { + "info": { + "id": "child-msg-1", + "role": "assistant", + "sessionID": "child-1", + } + }, + ), + # Child text event — should be suppressed + create_sse_event( + "message.part.updated", + { + "part": { + "type": "text", + "id": "child-text-1", + "sessionID": "child-1", + "messageID": "child-msg-1", + "text": "I am thinking...", + }, + "delta": "I am thinking...", + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + token_events = [e for e in events if e["type"] == "token"] + assert len(token_events) == 0 + + @pytest.mark.asyncio + async def test_child_idle_does_not_terminate_stream( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child session.idle should NOT terminate the parent stream.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + # Child goes idle — should NOT terminate + create_sse_event("session.idle", {"sessionID": "child-1"}), + # Parent text event after child idle + create_sse_event( + "message.part.updated", + { + "part": { + "type": "text", + "id": "part-1", + "sessionID": "oc-session-123", + "messageID": "oc-msg-1", + "text": "Task result", + } + }, + ), + # Parent goes idle — terminates + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + token_events = [e for e in events if e["type"] == "token"] + assert len(token_events) == 1 + assert token_events[0]["content"] == "Task result" + + @pytest.mark.asyncio + async def test_child_session_status_idle_does_not_terminate_stream( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child session.status with type=idle should NOT terminate the parent stream.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + # Child session.status idle — should NOT terminate + create_sse_event( + "session.status", + {"sessionID": "child-1", "status": {"type": "idle"}}, + ), + # Parent text event after child status idle + create_sse_event( + "message.part.updated", + { + "part": { + "type": "text", + "id": "part-1", + "sessionID": "oc-session-123", + "messageID": "oc-msg-1", + "text": "Still going", + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + token_events = [e for e in events if e["type"] == "token"] + assert len(token_events) == 1 + assert token_events[0]["content"] == "Still going" + + @pytest.mark.asyncio + async def test_child_session_error_forwarded_without_termination( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child session errors should be forwarded with isSubtask=True but not terminate stream.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + # Child session error + create_sse_event( + "session.error", + { + "sessionID": "child-1", + "error": {"data": {"message": "Sub-task failed"}}, + }, + ), + # Parent text after child error — should still be received + create_sse_event( + "message.part.updated", + { + "part": { + "type": "text", + "id": "part-1", + "sessionID": "oc-session-123", + "messageID": "oc-msg-1", + "text": "Recovered from sub-task error", + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + error_events = [e for e in events if e["type"] == "error"] + assert len(error_events) == 1 + assert error_events[0]["error"] == "Sub-task failed" + assert error_events[0]["isSubtask"] is True + + token_events = [e for e in events if e["type"] == "token"] + assert len(token_events) == 1 + assert token_events[0]["content"] == "Recovered from sub-task error" + + @pytest.mark.asyncio + async def test_child_message_buffering_race_condition( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child parts arriving before message.updated should be buffered and flushed.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + # Child tool event BEFORE message.updated — should be buffered + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "child-part-1", + "sessionID": "child-1", + "messageID": "child-msg-1", + "tool": "Read", + "callID": "child-call-1", + "state": { + "status": "running", + "input": {"path": "/file.txt"}, + "output": "", + }, + } + }, + ), + # Now child message.updated arrives — should flush buffered part + create_sse_event( + "message.updated", + { + "info": { + "id": "child-msg-1", + "role": "assistant", + "sessionID": "child-1", + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + tool_events = [e for e in events if e["type"] == "tool_call"] + assert len(tool_events) == 1 + assert tool_events[0]["isSubtask"] is True + assert tool_events[0]["tool"] == "Read" + + @pytest.mark.asyncio + async def test_resumed_child_session_discovered_via_metadata( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Child sessions resumed via task_id should be discovered from task tool metadata.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + # NO session.created — child was resumed via task_id + # Parent task tool part with metadata.sessionId + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "parent-part-1", + "sessionID": "oc-session-123", + "messageID": "oc-msg-1", + "tool": "task", + "callID": "task-call-1", + "metadata": {"sessionId": "child-1"}, + "state": { + "status": "running", + "input": {"prompt": "do something"}, + "output": "", + }, + } + }, + ), + # Now child events arrive + create_sse_event( + "message.updated", + { + "info": { + "id": "child-msg-1", + "role": "assistant", + "sessionID": "child-1", + } + }, + ), + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "child-part-1", + "sessionID": "child-1", + "messageID": "child-msg-1", + "tool": "Bash", + "callID": "child-call-1", + "state": { + "status": "running", + "input": {"command": "echo hello"}, + "output": "", + }, + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + tool_events = [e for e in events if e["type"] == "tool_call"] + # Should have: parent task tool (running) + child Bash tool (running) + parent_tools = [e for e in tool_events if not e.get("isSubtask")] + child_tools = [e for e in tool_events if e.get("isSubtask")] + assert len(parent_tools) == 1 + assert parent_tools[0]["tool"] == "task" + assert len(child_tools) == 1 + assert child_tools[0]["tool"] == "Bash" + assert child_tools[0]["isSubtask"] is True + + @pytest.mark.asyncio + async def test_parent_child_callid_collision( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Parent and child using same callID should both emit events (session-scoped dedupe).""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + create_sse_event( + "message.updated", + { + "info": { + "id": "child-msg-1", + "role": "assistant", + "sessionID": "child-1", + } + }, + ), + # Parent tool with callID "abc" + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "parent-part-1", + "sessionID": "oc-session-123", + "messageID": "oc-msg-1", + "tool": "Bash", + "callID": "abc", + "state": { + "status": "running", + "input": {"command": "echo parent"}, + "output": "", + }, + } + }, + ), + # Child tool with same callID "abc" + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "child-part-1", + "sessionID": "child-1", + "messageID": "child-msg-1", + "tool": "Bash", + "callID": "abc", + "state": { + "status": "running", + "input": {"command": "echo child"}, + "output": "", + }, + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + tool_events = [e for e in events if e["type"] == "tool_call"] + # Both should be emitted despite same callID (session-scoped dedupe) + assert len(tool_events) == 2 + parent_tools = [e for e in tool_events if not e.get("isSubtask")] + child_tools = [e for e in tool_events if e.get("isSubtask")] + assert len(parent_tools) == 1 + assert len(child_tools) == 1 + + @pytest.mark.asyncio + async def test_grandchild_session_not_tracked( + self, bridge: AgentBridge, opencode_message_id: str + ): + """Grandchild sessions (parentID != opencode_session_id) should NOT be tracked.""" + http_client = bridge.http_client + + http_client.sse_events = [ + create_sse_event("server.connected", {}), + create_sse_event( + "message.updated", + { + "info": { + "id": "oc-msg-1", + "role": "assistant", + "sessionID": "oc-session-123", + "parentID": opencode_message_id, + } + }, + ), + # Direct child + create_sse_event( + "session.created", + { + "info": { + "id": "child-1", + "parentID": "oc-session-123", + } + }, + ), + # Grandchild — parentID is child-1, NOT oc-session-123 + create_sse_event( + "session.created", + { + "info": { + "id": "grandchild-1", + "parentID": "child-1", + } + }, + ), + # Grandchild message + tool — should be filtered out + create_sse_event( + "message.updated", + { + "info": { + "id": "gc-msg-1", + "role": "assistant", + "sessionID": "grandchild-1", + } + }, + ), + create_sse_event( + "message.part.updated", + { + "part": { + "type": "tool", + "id": "gc-part-1", + "sessionID": "grandchild-1", + "messageID": "gc-msg-1", + "tool": "Bash", + "callID": "gc-call-1", + "state": { + "status": "running", + "input": {"command": "echo grandchild"}, + "output": "", + }, + } + }, + ), + create_sse_event("session.idle", {"sessionID": "oc-session-123"}), + ] + + events = [] + async for event in bridge._stream_opencode_response_sse("cp-msg-1", "Test prompt"): + events.append(event) + + tool_events = [e for e in events if e["type"] == "tool_call"] + assert len(tool_events) == 0 # Grandchild events should be filtered out + + if __name__ == "__main__": pytest.main([__file__, "-v"])