71 changes: 45 additions & 26 deletions src/mcp/client/streamable_http.py
@@ -49,6 +49,26 @@
MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up


class SessionInfo:
"""Base class for MCP session ID tracking.

The transport delegates all session ID storage to this object.
Override ``get_session_id`` / ``set_session_id`` in subclasses to
add side-effects such as HTTP persistence.
"""

def __init__(self, session_id: str | None = None) -> None:
self.session_id = session_id

async def get_session_id(self) -> str | None:
"""Return the current session ID (or None)."""
return self.session_id

async def set_session_id(self, session_id: str | None) -> None:
"""Store a new session ID assigned by the server, or None to clear."""
self.session_id = session_id


class StreamableHTTPError(Exception):
"""Base exception for StreamableHTTP transport errors."""

@@ -62,7 +82,6 @@ class RequestContext:
"""Context for a request operation."""

client: httpx.AsyncClient
session_id: str | None
session_message: SessionMessage
metadata: ClientMessageMetadata | None
read_stream_writer: StreamWriter
Expand All @@ -71,17 +90,18 @@ class RequestContext:
class StreamableHTTPTransport:
"""StreamableHTTP client transport implementation."""

def __init__(self, url: str) -> None:
def __init__(self, url: str, *, session_info: SessionInfo | None = None) -> None:
"""Initialize the StreamableHTTP transport.

Args:
url: The endpoint URL.
session_info: Optional SessionInfo for external session ID tracking.
"""
self.url = url
self.session_id: str | None = None
self._session_info = session_info or SessionInfo()
self.protocol_version: str | None = None

def _prepare_headers(self) -> dict[str, str]:
async def _prepare_headers(self) -> dict[str, str]:
"""Build MCP-specific request headers.

These headers will be merged with the httpx.AsyncClient's default headers,
@@ -92,8 +112,9 @@ def _prepare_headers(self) -> dict[str, str]:
"content-type": "application/json",
}
# Add session headers if available
if self.session_id:
headers[MCP_SESSION_ID] = self.session_id
session_id = await self._session_info.get_session_id()
if session_id:
headers[MCP_SESSION_ID] = session_id
if self.protocol_version:
headers[MCP_PROTOCOL_VERSION] = self.protocol_version
return headers
@@ -106,12 +127,12 @@ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized"

def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None:
async def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None:
"""Extract and store session ID from response headers."""
new_session_id = response.headers.get(MCP_SESSION_ID)
if new_session_id:
self.session_id = new_session_id
logger.info(f"Received session ID: {self.session_id}")
await self._session_info.set_session_id(new_session_id)
logger.info(f"Received session ID: {new_session_id}")

def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None:
"""Extract protocol version from initialization response message."""
@@ -185,10 +206,10 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:

while attempt < MAX_RECONNECTION_ATTEMPTS: # pragma: no branch
try:
if not self.session_id:
if not await self._session_info.get_session_id():
return

headers = self._prepare_headers()
headers = await self._prepare_headers()
if last_event_id:
headers[LAST_EVENT_ID] = last_event_id

@@ -224,7 +245,7 @@ async def handle_get_stream(self, client: httpx.AsyncClient, read_stream_writer:

async def _handle_resumption_request(self, ctx: RequestContext) -> None:
"""Handle a resumption request using GET with SSE."""
headers = self._prepare_headers()
headers = await self._prepare_headers()
if ctx.metadata and ctx.metadata.resumption_token:
headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
else:
@@ -252,7 +273,7 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

async def _handle_post_request(self, ctx: RequestContext) -> None:
"""Handle a POST request with response processing."""
headers = self._prepare_headers()
headers = await self._prepare_headers()
message = ctx.session_message.message
is_initialization = self._is_initialization_request(message)

Expand All @@ -275,7 +296,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:

response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)
await self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
@@ -381,7 +402,7 @@ async def _handle_reconnection(
delay_ms = retry_interval_ms if retry_interval_ms is not None else DEFAULT_RECONNECTION_DELAY_MS
await anyio.sleep(delay_ms / 1000.0)

headers = self._prepare_headers()
headers = await self._prepare_headers()
headers[LAST_EVENT_ID] = last_event_id

# Extract original request ID to map responses
@@ -453,7 +474,6 @@ async def post_writer(

ctx = RequestContext(
client=client,
session_id=self.session_id,
session_message=session_message,
metadata=metadata,
read_stream_writer=read_stream_writer,
@@ -479,11 +499,11 @@ async def handle_request_async():

async def terminate_session(self, client: httpx.AsyncClient) -> None:
"""Terminate the session by sending a DELETE request."""
if not self.session_id: # pragma: lax no cover
if not await self._session_info.get_session_id(): # pragma: lax no cover
return

try:
headers = self._prepare_headers()
headers = await self._prepare_headers()
response = await client.delete(self.url, headers=headers)

if response.status_code == 405: # pragma: lax no cover
@@ -493,21 +513,18 @@ async def terminate_session(self, client: httpx.AsyncClient) -> None:
except Exception as exc: # pragma: no cover
logger.warning(f"Session termination failed: {exc}")

# TODO(Marcelo): Check the TODO below, and cover this with tests if necessary.
def get_session_id(self) -> str | None:
async def get_session_id(self) -> str | None:
"""Get the current session ID."""
return self.session_id # pragma: no cover
return await self._session_info.get_session_id()


# TODO(Marcelo): I've dropped the `get_session_id` callback because it breaks the Transport protocol. Is that needed?
# It's a completely wrong abstraction, so removal is a good idea. But if we need the client to find the session ID,
# we should think about a better way to do it. I believe we can achieve it with other means.
@asynccontextmanager
async def streamable_http_client(
url: str,
*,
http_client: httpx.AsyncClient | None = None,
terminate_on_close: bool = True,
session_info: SessionInfo | None = None,
) -> AsyncGenerator[TransportStreams, None]:
"""Client transport for StreamableHTTP.

@@ -517,6 +534,8 @@ async def streamable_http_client(
client with recommended MCP timeouts will be created. To configure headers,
authentication, or other HTTP settings, create an httpx.AsyncClient and pass it here.
terminate_on_close: If True, send a DELETE request to terminate the session when the context exits.
session_info: Optional SessionInfo for external session ID tracking. If None, a default
SessionInfo is created internally. Pass a custom subclass to observe or persist session IDs.

Yields:
Tuple containing:
@@ -537,7 +556,7 @@
# Create default client with recommended MCP timeouts
client = create_mcp_http_client()

transport = StreamableHTTPTransport(url)
transport = StreamableHTTPTransport(url, session_info=session_info)

async with anyio.create_task_group() as tg:
try:
Expand All @@ -564,7 +583,7 @@ def start_get_stream() -> None:
try:
yield read_stream, write_stream
finally:
if transport.session_id and terminate_on_close:
if await transport.get_session_id() and terminate_on_close:
await transport.terminate_session(client)
tg.cancel_scope.cancel()
finally:
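For reviewers, here is a minimal sketch of how the new hook could be consumed. It assumes only the public names added in this diff (`SessionInfo`, `streamable_http_client`, and the `session_info` keyword); the `FileSessionInfo` subclass, the file path, and the server URL are hypothetical illustrations, not part of the change.

```python
# Hypothetical subclass that persists the MCP session ID to disk so a client
# could resume the same session across process restarts (illustration only).
from pathlib import Path

from mcp.client.streamable_http import SessionInfo, streamable_http_client


class FileSessionInfo(SessionInfo):
    """Stores the session ID in a local file (hypothetical example)."""

    def __init__(self, path: Path) -> None:
        self._path = path
        # Seed the in-memory value from disk if a previous run saved one.
        saved = path.read_text().strip() if path.exists() else None
        super().__init__(session_id=saved or None)

    async def set_session_id(self, session_id: str | None) -> None:
        # Keep the base class's in-memory value in sync, then persist.
        await super().set_session_id(session_id)
        if session_id is None:
            self._path.unlink(missing_ok=True)
        else:
            self._path.write_text(session_id)


async def main() -> None:
    session_info = FileSessionInfo(Path(".mcp_session_id"))
    async with streamable_http_client(
        "http://localhost:8000/mcp",  # hypothetical server URL
        session_info=session_info,
    ) as (read_stream, write_stream):
        ...  # drive an MCP ClientSession over read_stream / write_stream
```

This keeps the Transport protocol unchanged (no extra callback on the transport itself, as the removed TODO discusses) while still letting callers observe or persist the session ID through the object they pass in.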