-
Notifications
You must be signed in to change notification settings - Fork 77
feat: stream sub-task (child session) events through the bridge #110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a2935a2
e3ad5f3
b146a6e
1c570f1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -534,6 +534,16 @@ async def _create_opencode_session(self) -> None: | |
|
|
||
| await self._save_session_id() | ||
|
|
||
| @staticmethod | ||
| def _extract_error_message(error: Any) -> str | None: | ||
| """Extract message from OpenCode NamedError: { "name": "...", "data": { "message": "..." } }.""" | ||
| if isinstance(error, dict): | ||
| data = error.get("data") | ||
| if isinstance(data, dict) and "message" in data: | ||
| return data["message"] | ||
| return error.get("message") or error.get("name") | ||
| return str(error) if error else None | ||
ColeMurray marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def _transform_part_to_event( | ||
| self, | ||
| part: dict[str, Any], | ||
|
|
@@ -705,6 +715,8 @@ async def _stream_opencode_response_sse( | |
| 2. OpenCode creates assistant messages with parentID = our ascending ID | ||
| 3. Filter events to only process parts from our assistant messages | ||
| 4. Use control plane's message_id for events sent back | ||
| 5. Track child sessions (sub-tasks) and forward their non-text events | ||
| with isSubtask=True | ||
|
|
||
| The ascending ID ensures our user message ID is lexicographically greater | ||
| than any previous assistant message IDs, preventing the early exit condition | ||
|
|
@@ -728,6 +740,9 @@ async def _stream_opencode_response_sse( | |
| pending_parts_total = 0 | ||
| pending_drop_logged = False | ||
|
|
||
| # Child session tracking (sub-tasks) | ||
| tracked_child_session_ids: set[str] = set() | ||
|
|
||
| start_time = time.time() | ||
| loop = asyncio.get_running_loop() | ||
|
|
||
|
|
@@ -746,12 +761,19 @@ def buffer_part(oc_msg_id: str, part: dict[str, Any], delta: Any) -> None: | |
| pending_parts.setdefault(oc_msg_id, []).append((part, delta)) | ||
| pending_parts_total += 1 | ||
|
|
||
| def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | ||
| def handle_part( | ||
| part: dict[str, Any], | ||
| delta: Any, | ||
| *, | ||
| is_subtask: bool = False, | ||
| ) -> list[dict[str, Any]]: | ||
| part_type = part.get("type", "") | ||
| part_id = part.get("id", "") | ||
| events: list[dict[str, Any]] = [] | ||
|
|
||
| if part_type == "text": | ||
| if is_subtask: | ||
| return events # Don't forward child text tokens | ||
| text = part.get("text", "") | ||
| if delta: | ||
| cumulative_text[part_id] = cumulative_text.get(part_id, "") + delta | ||
|
|
@@ -773,7 +795,8 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | |
| state = part.get("state", {}) | ||
| status = state.get("status", "") | ||
| call_id = part.get("callID", "") | ||
| tool_key = f"tool:{call_id}:{status}" | ||
| part_sid = part.get("sessionID", "") | ||
| tool_key = f"tool:{part_sid}:{call_id}:{status}" | ||
|
|
||
| if tool_key not in emitted_tool_states: | ||
| emitted_tool_states.add(tool_key) | ||
|
|
@@ -798,6 +821,9 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | |
| } | ||
| ) | ||
|
|
||
| if is_subtask: | ||
| for ev in events: | ||
| ev["isSubtask"] = True | ||
| return events | ||
|
|
||
| try: | ||
|
|
@@ -837,10 +863,31 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | |
| if event_type == "server.connected": | ||
| pass | ||
| elif event_type != "server.heartbeat": | ||
| # Track direct child sessions before filtering | ||
| if event_type == "session.created": | ||
| info = props.get("info", {}) | ||
| child_id = info.get("id") | ||
| child_parent = info.get("parentID") | ||
| if child_id and child_parent == self.opencode_session_id: | ||
| tracked_child_session_ids.add(child_id) | ||
| self.log.info( | ||
| "bridge.child_session_detected", | ||
| child_session_id=child_id, | ||
| source="session.created", | ||
| ) | ||
| # Always continue: no downstream handler processes session.created, | ||
| # and non-matching events would just fall through to no-op. | ||
| continue | ||
|
|
||
| event_session_id = props.get("sessionID") or props.get("part", {}).get( | ||
| "sessionID" | ||
| ) | ||
| if not event_session_id or event_session_id == self.opencode_session_id: | ||
| is_child = event_session_id in tracked_child_session_ids | ||
| if ( | ||
| not event_session_id | ||
| or event_session_id == self.opencode_session_id | ||
| or is_child | ||
| ): | ||
| if event_type == "message.updated": | ||
| info = props.get("info", {}) | ||
| msg_session_id = info.get("sessionID") | ||
|
|
@@ -876,18 +923,59 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | |
| finish=finish, | ||
| ) | ||
|
|
||
| elif msg_session_id in tracked_child_session_ids: | ||
| # Child session: authorize all assistant messages | ||
| oc_msg_id = info.get("id", "") | ||
| role = info.get("role", "") | ||
| if role == "assistant" and oc_msg_id: | ||
| allowed_assistant_msg_ids.add(oc_msg_id) | ||
| pending = pending_parts.pop(oc_msg_id, []) | ||
| if pending: | ||
| pending_parts_total -= len(pending) | ||
| for part, delta in pending: | ||
| for ev in handle_part( | ||
| part, delta, is_subtask=True | ||
| ): | ||
| yield ev | ||
|
|
||
| elif event_type == "message.part.updated": | ||
| part = props.get("part", {}) | ||
| delta = props.get("delta") | ||
| oc_msg_id = part.get("messageID", "") | ||
| part_session_id = part.get("sessionID", "") | ||
|
|
||
| # Discover child sessions from task tool metadata (covers task_id resume) | ||
| if ( | ||
| part.get("tool") == "task" | ||
| and part_session_id == self.opencode_session_id | ||
| ): | ||
| metadata = part.get("metadata") | ||
| child_sid = ( | ||
| metadata.get("sessionId") | ||
| if isinstance(metadata, dict) | ||
| else None | ||
| ) | ||
| if child_sid and child_sid not in tracked_child_session_ids: | ||
| tracked_child_session_ids.add(child_sid) | ||
| self.log.info( | ||
| "bridge.child_session_detected", | ||
| child_session_id=child_sid, | ||
| source="task_metadata", | ||
| ) | ||
|
|
||
| if oc_msg_id in allowed_assistant_msg_ids: | ||
| for part_event in handle_part(part, delta): | ||
| yield part_event | ||
| if part_session_id in tracked_child_session_ids: | ||
| for ev in handle_part(part, delta, is_subtask=True): | ||
| yield ev | ||
| else: | ||
| for part_event in handle_part(part, delta): | ||
| yield part_event | ||
| elif oc_msg_id: | ||
| buffer_part(oc_msg_id, part, delta) | ||
|
|
||
| elif event_type == "session.idle": | ||
| idle_session_id = props.get("sessionID") | ||
| # Only parent idle terminates the stream | ||
| if idle_session_id == self.opencode_session_id: | ||
| elapsed = time.time() - start_time | ||
| self.log.debug( | ||
|
|
@@ -907,6 +995,7 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | |
| elif event_type == "session.status": | ||
| status_session_id = props.get("sessionID") | ||
| status = props.get("status", {}) | ||
| # Only parent status=idle terminates the stream | ||
| if ( | ||
| status_session_id == self.opencode_session_id | ||
| and status.get("type") == "idle" | ||
|
|
@@ -929,25 +1018,32 @@ def handle_part(part: dict[str, Any], delta: Any) -> list[dict[str, Any]]: | |
| elif event_type == "session.error": | ||
| error_session_id = props.get("sessionID") | ||
| if error_session_id == self.opencode_session_id: | ||
| error = props.get("error", {}) | ||
| # OpenCode NamedError structure: { "name": "...", "data": { "message": "..." } } | ||
| if isinstance(error, dict): | ||
| data = error.get("data") | ||
| if isinstance(data, dict) and "message" in data: | ||
| error_msg = data["message"] | ||
| else: | ||
| error_msg = error.get("message") or error.get( | ||
| "name" | ||
| ) | ||
| else: | ||
| error_msg = str(error) if error else None | ||
| error_msg = self._extract_error_message( | ||
| props.get("error", {}) | ||
| ) | ||
| self.log.error("bridge.session_error", error_msg=error_msg) | ||
| yield { | ||
| "type": "error", | ||
| "error": error_msg or "Unknown error", | ||
| "messageId": message_id, | ||
| } | ||
| return | ||
| elif error_session_id in tracked_child_session_ids: | ||
| error_msg = self._extract_error_message( | ||
| props.get("error", {}) | ||
| ) | ||
| self.log.error( | ||
| "bridge.child_session_error", | ||
| error_msg=error_msg, | ||
| child_session_id=error_session_id, | ||
| ) | ||
| yield { | ||
| "type": "error", | ||
| "error": error_msg or "Sub-task error", | ||
| "messageId": message_id, | ||
| "isSubtask": True, | ||
| } | ||
|
Comment on lines
1031
to
1045
|
||
| # No return — parent stream continues | ||
|
|
||
| if loop.time() > prompt_start + self.PROMPT_MAX_DURATION: | ||
| elapsed = time.time() - start_time | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.