diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9d45bec6e..eeb3acf2d 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -49,6 +49,26 @@ MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up +class SessionInfo: + """Base class for MCP session ID tracking. + + The transport delegates all session ID storage to this object. + Override ``get_session_id`` / ``set_session_id`` in subclasses to + add side-effects such as HTTP persistence. + """ + + def __init__(self, session_id: str | None = None) -> None: + self.session_id = session_id + + async def get_session_id(self) -> str | None: + """Return the current session ID (or None).""" + return self.session_id + + async def set_session_id(self, session_id: str | None) -> None: + """Store a new session ID assigned by the server, or None to clear.""" + self.session_id = session_id + + class StreamableHTTPError(Exception): """Base exception for StreamableHTTP transport errors.""" @@ -62,7 +82,6 @@ class RequestContext: """Context for a request operation.""" client: httpx.AsyncClient - session_id: str | None session_message: SessionMessage metadata: ClientMessageMetadata | None read_stream_writer: StreamWriter @@ -71,17 +90,18 @@ class RequestContext: class StreamableHTTPTransport: """StreamableHTTP client transport implementation.""" - def __init__(self, url: str) -> None: + def __init__(self, url: str, *, session_info: SessionInfo | None = None) -> None: """Initialize the StreamableHTTP transport. Args: url: The endpoint URL. + session_info: Optional SessionInfo for external session ID tracking. """ self.url = url - self.session_id: str | None = None + self._session_info = session_info or SessionInfo() self.protocol_version: str | None = None - def _prepare_headers(self) -> dict[str, str]: + async def _prepare_headers(self) -> dict[str, str]: """Build MCP-specific request headers. These headers will be merged with the httpx.AsyncClient's default headers, @@ -92,8 +112,9 @@ def _prepare_headers(self) -> dict[str, str]: "content-type": "application/json", } # Add session headers if available - if self.session_id: - headers[MCP_SESSION_ID] = self.session_id + session_id = await self._session_info.get_session_id() + if session_id: + headers[MCP_SESSION_ID] = session_id if self.protocol_version: headers[MCP_PROTOCOL_VERSION] = self.protocol_version return headers @@ -106,12 +127,12 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialized notification.""" return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized" - def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None: + async def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None: """Extract and store session ID from response headers.""" new_session_id = response.headers.get(MCP_SESSION_ID) if new_session_id: - self.session_id = new_session_id - logger.info(f"Received session ID: {self.session_id}") + await self._session_info.set_session_id(new_session_id) + logger.info(f"Received session ID: {new_session_id}") def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None: """Extract protocol version from initialization response message.""" @@ -185,10 +206,10 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: while attempt < MAX_RECONNECTION_ATTEMPTS: # pragma: no branch try: - if not self.session_id: + if not await self._session_info.get_session_id(): return - headers = self._prepare_headers() + headers = await self._prepare_headers() if last_event_id: headers[LAST_EVENT_ID] = last_event_id @@ -224,7 +245,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer: async def _handle_resumption_request(self, ctx: RequestContext) -> None: """Handle a resumption request using GET with SSE.""" - headers = self._prepare_headers() + headers = await self._prepare_headers() if ctx.metadata and ctx.metadata.resumption_token: headers[LAST_EVENT_ID] = ctx.metadata.resumption_token else: @@ -252,7 +273,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: async def _handle_post_request(self, ctx: RequestContext) -> None: """Handle a POST request with response processing.""" - headers = self._prepare_headers() + headers = await self._prepare_headers() message = ctx.session_message.message is_initialization = self._is_initialization_request(message) @@ -275,7 +296,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: response.raise_for_status() if is_initialization: - self._maybe_extract_session_id_from_response(response) + await self._maybe_extract_session_id_from_response(response) # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: # The server MUST NOT send a response to notifications. @@ -381,7 +402,7 @@ async def _handle_reconnection( delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS await anyio.sleep(delay_ms / 1000.0) - headers = self._prepare_headers() + headers = await self._prepare_headers() headers[LAST_EVENT_ID] = last_event_id # Extract original request ID to map responses @@ -453,7 +474,6 @@ async def post_writer( ctx = RequestContext( client=client, - session_id=self.session_id, session_message=session_message, metadata=metadata, read_stream_writer=read_stream_writer, @@ -479,11 +499,11 @@ async def handle_request_async(): async def terminate_session(self, client: httpx.AsyncClient) -> None: """Terminate the session by sending a DELETE request.""" - if not self.session_id: # pragma: lax no cover + if not await self._session_info.get_session_id(): # pragma: lax no cover return try: - headers = self._prepare_headers() + headers = await self._prepare_headers() response = await client.delete(self.url, headers=headers) if response.status_code == 405: # pragma: lax no cover @@ -493,21 +513,18 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None: except Exception as exc: # pragma: no cover logger.warning(f"Session termination failed: {exc}") - # TODO(Marcelo): Check the TODO below, and cover this with tests if necessary. - def get_session_id(self) -> str | None: + async def get_session_id(self) -> str | None: """Get the current session ID.""" - return self.session_id # pragma: no cover + return await self._session_info.get_session_id() -# TODO(Marcelo): I've dropped the `get_session_id` callback because it breaks the Transport protocol. Is that needed? -# It's a completely wrong abstraction, so removal is a good idea. But if we need the client to find the session ID, -# we should think about a better way to do it. I believe we can achieve it with other means. @asynccontextmanager async def streamable_http_client( url: str, *, http_client: httpx.AsyncClient | None = None, terminate_on_close: bool = True, + session_info: SessionInfo | None = None, ) -> AsyncGenerator[TransportStreams, None]: """Client transport for StreamableHTTP. @@ -517,6 +534,8 @@ async def streamable_http_client( client with recommended MCP timeouts will be created. To configure headers, authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here. terminate_on_close: If True, send a DELETE request to terminate the session when the context exits. + session_info: Optional SessionInfo for external session ID tracking. If None, a default + SessionInfo is created internally. Pass a custom subclass to observe or persist session IDs. Yields: Tuple containing: @@ -537,7 +556,7 @@ async def streamable_http_client( # Create default client with recommended MCP timeouts client = create_mcp_http_client() - transport = StreamableHTTPTransport(url) + transport = StreamableHTTPTransport(url, session_info=session_info) async with anyio.create_task_group() as tg: try: @@ -564,7 +583,7 @@ def start_get_stream() -> None: try: yield read_stream, write_stream finally: - if transport.session_id and terminate_on_close: + if await transport.get_session_id() and terminate_on_close: await transport.terminate_session(client) tg.cancel_scope.cancel() finally: