Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 113 additions & 17 deletions packages/modal-infra/src/sandbox/bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,16 @@ async def _create_opencode_session(self) -> None:

await self._save_session_id()

@staticmethod
def _extract_error_message(error: Any) -> str | None:
"""Extract message from OpenCode NamedError: { "name": "...", "data": { "message": "..." } }."""
if isinstance(error, dict):
data = error.get("data")
if isinstance(data, dict) and "message" in data:
return data["message"]
return error.get("message") or error.get("name")
return str(error) if error else None

def _transform_part_to_event(
self,
part: dict[str, Any],
Expand Down Expand Up @@ -705,6 +715,8 @@ async def _stream_opencode_response_sse(
2. OpenCode creates assistant messages with parentID = our ascending ID
3. Filter events to only process parts from our assistant messages
4. Use control plane's message_id for events sent back
5. Track child sessions (sub-tasks) and forward their non-text events
with isSubtask=True

The ascending ID ensures our user message ID is lexicographically greater
than any previous assistant message IDs, preventing the early exit condition
Expand All @@ -728,6 +740,9 @@ async def _stream_opencode_response_sse(
pending_parts_total = 0
pending_drop_logged = False

# Child session tracking (sub-tasks)
tracked_child_session_ids: set[str] = set()

start_time = time.time()
loop = asyncio.get_running_loop()

Expand All @@ -746,12 +761,19 @@ def buffer_part(oc_msg_id: str, part: dict[str, Any], delta: Any) -> None:
pending_parts.setdefault(oc_msg_id, []).append((part, delta))
pending_parts_total += 1

def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
def handle_part(
part: dict[str, Any],
delta: Any,
*,
is_subtask: bool = False,
) -> list[dict[str, Any]]:
part_type = part.get("type", "")
part_id = part.get("id", "")
events: list[dict[str, Any]] = []

if part_type == "text":
if is_subtask:
return events # Don't forward child text tokens
text = part.get("text", "")
if delta:
cumulative_text[part_id] = cumulative_text.get(part_id, "") + delta
Expand All @@ -773,7 +795,8 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
state = part.get("state", {})
status = state.get("status", "")
call_id = part.get("callID", "")
tool_key = f"tool:{call_id}:{status}"
part_sid = part.get("sessionID", "")
tool_key = f"tool:{part_sid}:{call_id}:{status}"

if tool_key not in emitted_tool_states:
emitted_tool_states.add(tool_key)
Expand All @@ -798,6 +821,9 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
}
)

if is_subtask:
for ev in events:
ev["isSubtask"] = True
return events

try:
Expand Down Expand Up @@ -837,10 +863,31 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
if event_type == "server.connected":
pass
elif event_type != "server.heartbeat":
# Track direct child sessions before filtering
if event_type == "session.created":
info = props.get("info", {})
child_id = info.get("id")
child_parent = info.get("parentID")
if child_id and child_parent == self.opencode_session_id:
tracked_child_session_ids.add(child_id)
self.log.info(
"bridge.child_session_detected",
child_session_id=child_id,
source="session.created",
)
# Always continue: no downstream handler processes session.created,
# and non-matching events would just fall through to no-op.
continue

event_session_id = props.get("sessionID") or props.get("part", {}).get(
"sessionID"
)
if not event_session_id or event_session_id == self.opencode_session_id:
is_child = event_session_id in tracked_child_session_ids
if (
not event_session_id
or event_session_id == self.opencode_session_id
or is_child
):
if event_type == "message.updated":
info = props.get("info", {})
msg_session_id = info.get("sessionID")
Expand Down Expand Up @@ -876,18 +923,59 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
finish=finish,
)

elif msg_session_id in tracked_child_session_ids:
# Child session: authorize all assistant messages
oc_msg_id = info.get("id", "")
role = info.get("role", "")
if role == "assistant" and oc_msg_id:
allowed_assistant_msg_ids.add(oc_msg_id)
pending = pending_parts.pop(oc_msg_id, [])
if pending:
pending_parts_total -= len(pending)
for part, delta in pending:
for ev in handle_part(
part, delta, is_subtask=True
):
yield ev

elif event_type == "message.part.updated":
part = props.get("part", {})
delta = props.get("delta")
oc_msg_id = part.get("messageID", "")
part_session_id = part.get("sessionID", "")

# Discover child sessions from task tool metadata (covers task_id resume)
if (
part.get("tool") == "task"
and part_session_id == self.opencode_session_id
):
metadata = part.get("metadata")
child_sid = (
metadata.get("sessionId")
if isinstance(metadata, dict)
else None
)
if child_sid and child_sid not in tracked_child_session_ids:
tracked_child_session_ids.add(child_sid)
self.log.info(
"bridge.child_session_detected",
child_session_id=child_sid,
source="task_metadata",
)

if oc_msg_id in allowed_assistant_msg_ids:
for part_event in handle_part(part, delta):
yield part_event
if part_session_id in tracked_child_session_ids:
for ev in handle_part(part, delta, is_subtask=True):
yield ev
else:
for part_event in handle_part(part, delta):
yield part_event
elif oc_msg_id:
buffer_part(oc_msg_id, part, delta)

elif event_type == "session.idle":
idle_session_id = props.get("sessionID")
# Only parent idle terminates the stream
if idle_session_id == self.opencode_session_id:
elapsed = time.time() - start_time
self.log.debug(
Expand All @@ -907,6 +995,7 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
elif event_type == "session.status":
status_session_id = props.get("sessionID")
status = props.get("status", {})
# Only parent status=idle terminates the stream
if (
status_session_id == self.opencode_session_id
and status.get("type") == "idle"
Expand All @@ -929,25 +1018,32 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]:
elif event_type == "session.error":
error_session_id = props.get("sessionID")
if error_session_id == self.opencode_session_id:
error = props.get("error", {})
# OpenCode NamedError structure: { "name": "...", "data": { "message": "..." } }
if isinstance(error, dict):
data = error.get("data")
if isinstance(data, dict) and "message" in data:
error_msg = data["message"]
else:
error_msg = error.get("message") or error.get(
"name"
)
else:
error_msg = str(error) if error else None
error_msg = self._extract_error_message(
props.get("error", {})
)
self.log.error("bridge.session_error", error_msg=error_msg)
yield {
"type": "error",
"error": error_msg or "Unknown error",
"messageId": message_id,
}
return
elif error_session_id in tracked_child_session_ids:
error_msg = self._extract_error_message(
props.get("error", {})
)
self.log.error(
"bridge.child_session_error",
error_msg=error_msg,
child_session_id=error_session_id,
)
yield {
"type": "error",
"error": error_msg or "Sub-task error",
"messageId": message_id,
"isSubtask": True,
}
Comment on lines 1031 to 1045
Copy link

Copilot AI Feb 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Child session.error events are forwarded as {type: 'error', isSubtask: true}. In _handle_prompt(), any event with type == 'error' currently flips had_error=True and will make the final execution_complete.success false even if the parent stream recovers and completes normally. If the intent is “child errors don’t fail the parent execution”, consider emitting a different event type for subtask errors (or adjusting the parent error accounting to ignore isSubtask errors).

Copilot uses AI. Check for mistakes.
# No return — parent stream continues

if loop.time() > prompt_start + self.PROMPT_MAX_DURATION:
elapsed = time.time() - start_time
Expand Down
Loading
Loading