diff --git a/pyproject.toml b/pyproject.toml index 65bde6966..dc028f621 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,6 +100,7 @@ Issues = "https://github.com/modelcontextprotocol/python-sdk/issues" packages = ["src/mcp"] [tool.pyright] +pythonVersion = "3.10" typeCheckingMode = "strict" include = [ "src/mcp", diff --git a/src/mcp/server/lowlevel/func_inspection.py b/src/mcp/server/lowlevel/func_inspection.py index d17697090..0a3c0e46c 100644 --- a/src/mcp/server/lowlevel/func_inspection.py +++ b/src/mcp/server/lowlevel/func_inspection.py @@ -31,7 +31,7 @@ def create_call_wrapper(func: Callable[..., R], request_type: type[T]) -> Callab if param.default is not inspect.Parameter.empty: # pragma: no cover return lambda _: func() # Found positional-only parameter with correct type and no default - return lambda req: func(req) + return func # Check for any positional/keyword parameter typed as request_type for param_name, param in sig.parameters.items(): diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..46598cd97 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -805,7 +805,6 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no await self._lowlevel_server.run( streams[0], streams[1], self._lowlevel_server.create_initialization_options() ) - return Response() # Create routes routes: list[Route | Mount] = [] @@ -869,15 +868,19 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send): # pragma: no ) else: # Auth is disabled, no need for RequireAuthMiddleware - # Since handle_sse is an ASGI app, we need to create a compatible endpoint - async def sse_endpoint(request: Request) -> Response: # pragma: no cover - # Convert the Starlette request to ASGI parameters - return await handle_sse(request.scope, request.receive, request._send) # type: ignore[reportPrivateUsage] + # Use an ASGI-compatible wrapper to avoid Starlette's + # BaseHTTPMiddleware wrapping the SSE handler as a regular + # endpoint. BaseHTTPMiddleware expects http.response.body + # messages, but the SSE handler sends raw ASGI events, + # which triggers "AssertionError: Unexpected message". + class HandleSseAsgi: # pragma: no cover + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + await handle_sse(scope, receive, send) routes.append( Route( sse_path, - endpoint=sse_endpoint, + endpoint=HandleSseAsgi(), methods=["GET"], ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a3..1f7e9372e 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -102,7 +102,7 @@ def __init__( self._incoming_message_stream_writer, self._incoming_message_stream_reader = anyio.create_memory_object_stream[ ServerRequestResponder ](0) - self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + self._exit_stack.push_async_callback(self._incoming_message_stream_reader.aclose) @property def _receive_request_adapter(self) -> TypeAdapter[types.ClientRequest]: diff --git a/tests/issues/test_883_middleware.py b/tests/issues/test_883_middleware.py new file mode 100644 index 000000000..b2a8730a5 --- /dev/null +++ b/tests/issues/test_883_middleware.py @@ -0,0 +1,96 @@ +"""Regression test for issue #883: AssertionError when using Starlette middleware. + +BaseHTTPMiddleware expects http.response.body messages, but the SSE handler +sends raw ASGI events, which triggers "AssertionError: Unexpected message" +when the SSE endpoint is wrapped as a regular Starlette endpoint. + +The fix uses an ASGI-compatible callable class (HandleSseAsgi) instead of a +Starlette endpoint wrapper, so the SSE handler bypasses middleware response +body wrapping. +""" + +import multiprocessing +import socket +from collections.abc import Generator + +import anyio +import httpx +import pytest +import uvicorn +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +from mcp.server.mcpserver import MCPServer +from mcp.server.transport_security import TransportSecuritySettings +from tests.test_helpers import wait_for_server + + +class PassthroughMiddleware(BaseHTTPMiddleware): # pragma: no cover + """A simple pass-through middleware that triggers BaseHTTPMiddleware wrapping.""" + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: + return await call_next(request) + + +@pytest.fixture +def server_port() -> int: + with socket.socket() as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def run_server_with_middleware(server_port: int) -> None: # pragma: no cover + """Create an MCP server wrapped in Starlette BaseHTTPMiddleware.""" + mcp_server = MCPServer("test-883") + transport_security = TransportSecuritySettings(enable_dns_rebinding_protection=False) + sse_app = mcp_server.sse_app(transport_security=transport_security, host="0.0.0.0") + + # This is the exact scenario that triggers #883: + # BaseHTTPMiddleware wrapping a Starlette app containing SSE endpoints + app = Starlette(middleware=[Middleware(PassthroughMiddleware)]) + app.mount("/", sse_app) + + server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) + server.run() + + +@pytest.fixture() +def middleware_server(server_port: int) -> Generator[None, None, None]: + proc = multiprocessing.Process( + target=run_server_with_middleware, + kwargs={"server_port": server_port}, + daemon=True, + ) + proc.start() + wait_for_server(server_port) + yield + proc.kill() + proc.join(timeout=2) + + +@pytest.mark.anyio +async def test_sse_with_middleware_no_assertion_error(middleware_server: None, server_port: int) -> None: + """Verify SSE endpoint works when Starlette BaseHTTPMiddleware is applied. + + Before the fix, this would raise: + AssertionError: Unexpected message type 'http.response.body' + """ + async with httpx.AsyncClient(base_url=f"http://127.0.0.1:{server_port}") as client: + with anyio.fail_after(5): + async with client.stream("GET", "/sse") as response: # pragma: no branch + assert response.status_code == 200 + assert "text/event-stream" in response.headers.get("content-type", "") + + # Read the first event to confirm SSE is streaming properly + line_number = 0 + async for line in response.aiter_lines(): # pragma: no branch + if line_number == 0: + assert line == "event: endpoint" + elif line_number == 1: + assert line.startswith("data: /messages/?session_id=") + else: + break + line_number += 1