From 8d1257fb5a1543fe2506b098927ae2eda2abb1c5 Mon Sep 17 00:00:00 2001 From: Bryce Watson Date: Wed, 11 Feb 2026 21:19:44 -0800 Subject: [PATCH 1/2] feat: expose progress_callback in ServerSession.create_message() and elicit_form() Add progress_callback parameter to ServerSession methods that send requests to clients, matching the pattern already used by ClientSession.call_tool(). This brings the Python SDK into parity with the TypeScript SDK's RequestOptions.onprogress support. The MCP spec supports bidirectional progress notifications, and BaseSession.send_request() already accepts progress_callback, but the ServerSession convenience methods were not exposing it. Fixes #1671 Github-Issue: #1671 --- src/mcp/server/elicitation.py | 5 + src/mcp/server/mcpserver/server.py | 7 ++ src/mcp/server/session.py | 17 ++- tests/shared/test_progress_notifications.py | 115 ++++++++++++++++++++ 4 files changed, 143 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 58e9fe448..da4907e6e 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -107,6 +108,7 @@ async def elicit_with_validation( message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user with schema validation (form mode). @@ -127,6 +129,7 @@ async def elicit_with_validation( message=message, requested_schema=json_schema, related_request_id=related_request_id, + progress_callback=progress_callback, ) if result.action == "accept" and result.content is not None: @@ -148,6 +151,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> UrlElicitationResult: """Elicit information from the user via out-of-band URL navigation (URL mode). @@ -177,6 +181,7 @@ async def elicit_url( url=url, elicitation_id=elicitation_id, related_request_id=related_request_id, + progress_callback=progress_callback, ) if result.action == "accept": diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index 8c1fc342b..7a46031e4 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -42,6 +42,7 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared.session import ProgressFnT from mcp.types import Annotations, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt from mcp.types import PromptArgument as MCPPromptArgument @@ -1070,6 +1071,7 @@ async def elicit( self, message: str, schema: type[ElicitSchemaModelT], + progress_callback: ProgressFnT | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user. @@ -1084,6 +1086,7 @@ async def elicit( only primitive types are allowed. message: Optional message to present to the user. If not provided, will use a default message based on the schema + progress_callback: Optional callback for receiving progress notifications. Returns: An ElicitationResult containing the action taken and the data if accepted @@ -1098,6 +1101,7 @@ async def elicit( message=message, schema=schema, related_request_id=self.request_id, + progress_callback=progress_callback, ) async def elicit_url( @@ -1105,6 +1109,7 @@ async def elicit_url( message: str, url: str, elicitation_id: str, + progress_callback: ProgressFnT | None = None, ) -> UrlElicitationResult: """Request URL mode elicitation from the client. @@ -1123,6 +1128,7 @@ async def elicit_url( message: Human-readable explanation of why the interaction is needed url: The URL the user should navigate to elicitation_id: Unique identifier for tracking this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: UrlElicitationResult indicating accept, decline, or cancel @@ -1133,6 +1139,7 @@ async def elicit_url( url=url, elicitation_id=elicitation_id, related_request_id=self.request_id, + progress_callback=progress_callback, ) async def log( diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index f496121a3..a842c9bf0 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -54,6 +54,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + ProgressFnT, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -253,6 +254,7 @@ async def create_message( tools: None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult: """Overload: Without tools, returns single content.""" ... @@ -272,6 +274,7 @@ async def create_message( tools: list[types.Tool], tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResultWithTools: """Overload: With tools, returns array-capable content.""" ... @@ -290,6 +293,7 @@ async def create_message( tools: list[types.Tool] | None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: """Send a sampling/create_message request. @@ -309,6 +313,7 @@ async def create_message( tool_choice: Optional control over tool usage behavior. Requires client to have sampling.tools capability. related_request_id: Optional ID of a related request. + progress_callback: Optional callback for receiving progress notifications. Returns: The sampling result from the client. @@ -346,11 +351,13 @@ async def create_message( request=request, result_type=types.CreateMessageResultWithTools, metadata=metadata_obj, + progress_callback=progress_callback, ) return await self.send_request( request=request, result_type=types.CreateMessageResult, metadata=metadata_obj, + progress_callback=progress_callback, ) async def list_roots(self) -> types.ListRootsResult: @@ -367,6 +374,7 @@ async def elicit( message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -374,6 +382,7 @@ async def elicit( message: The message to present to the user requested_schema: Schema defining the expected response structure related_request_id: Optional ID of the request that triggered this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response @@ -382,13 +391,14 @@ async def elicit( This method is deprecated in favor of elicit_form(). It remains for backward compatibility but new code should use elicit_form(). """ - return await self.elicit_form(message, requested_schema, related_request_id) + return await self.elicit_form(message, requested_schema, related_request_id, progress_callback) async def elicit_form( self, message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -396,6 +406,7 @@ async def elicit_form( message: The message to present to the user requested_schema: Schema defining the expected response structure related_request_id: Optional ID of the request that triggered this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response with form data @@ -414,6 +425,7 @@ async def elicit_form( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def elicit_url( @@ -422,6 +434,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a URL mode elicitation/create request. @@ -433,6 +446,7 @@ async def elicit_url( url: The URL the user should navigate to elicitation_id: Unique identifier for tracking this elicitation related_request_id: Optional ID of the request that triggered this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response indicating acceptance, decline, or cancellation @@ -452,6 +466,7 @@ async def elicit_url( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def send_ping(self) -> types.EmptyResult: # pragma: no cover diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index ab117f1f0..257020d7d 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -8,12 +8,21 @@ from mcp.client.session import ClientSession from mcp.server import Server from mcp.server.lowlevel import NotificationOptions +from mcp.server.mcpserver import MCPServer from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.progress import progress from mcp.shared.session import RequestResponder +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestParams, + ElicitResult, + SamplingMessage, + TextContent, +) @pytest.mark.anyio @@ -378,3 +387,109 @@ async def handle_list_tools() -> list[types.Tool]: # Check that a warning was logged for the progress callback exception assert len(logged_errors) > 0 assert any("Progress callback raised an exception" in warning for warning in logged_errors) + + +@pytest.mark.anyio +async def test_server_create_message_progress_callback(): + """Test that ServerSession.create_message() accepts and passes through progress_callback.""" + server = MCPServer("test") + + # Track progress updates received by the server's progress callback + progress_updates: list[dict[str, Any]] = [] + + async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append({"progress": progress, "total": total, "message": message}) + + @server.tool("trigger_sampling") + async def trigger_sampling_tool(text: str) -> str: + result = await server.get_context().session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=text))], + max_tokens=100, + progress_callback=my_progress_callback, + ) + assert isinstance(result.content, TextContent) + return result.content.text + + async def sampling_callback( + context: RequestContext[ClientSession], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + # Send progress notifications back to the server using the progress token + if context.meta and "progress_token" in context.meta: + token = context.meta["progress_token"] + await context.session.send_progress_notification( + progress_token=token, + progress=0.5, + total=1.0, + message="Halfway done", + ) + await context.session.send_progress_notification( + progress_token=token, + progress=1.0, + total=1.0, + message="Complete", + ) + + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="LLM response"), + model="test-model", + stop_reason="endTurn", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("trigger_sampling", {"text": "Hello"}) + assert result.is_error is False + + # Verify the progress callback was invoked with correct values + assert len(progress_updates) == 2 + assert progress_updates[0] == {"progress": 0.5, "total": 1.0, "message": "Halfway done"} + assert progress_updates[1] == {"progress": 1.0, "total": 1.0, "message": "Complete"} + + +@pytest.mark.anyio +async def test_server_elicit_form_progress_callback(): + """Test that ServerSession.elicit_form() accepts and passes through progress_callback.""" + server = MCPServer("test") + + # Track progress updates received by the server's progress callback + progress_updates: list[dict[str, Any]] = [] + + async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append({"progress": progress, "total": total, "message": message}) + + @server.tool("trigger_elicitation") + async def trigger_elicitation_tool(text: str) -> str: + result = await server.get_context().session.elicit_form( + message=text, + requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, + progress_callback=my_progress_callback, + ) + return result.action + + async def elicitation_callback( + context: RequestContext[ClientSession], + params: ElicitRequestParams, + ) -> ElicitResult: + # Send progress notifications back to the server using the progress token + if context.meta and "progress_token" in context.meta: + token = context.meta["progress_token"] + await context.session.send_progress_notification( + progress_token=token, + progress=1.0, + total=1.0, + message="User responded", + ) + + return ElicitResult( + action="accept", + content={"name": "test"}, + ) + + async with Client(server, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("trigger_elicitation", {"text": "Enter name"}) + assert result.is_error is False + + # Verify the progress callback was invoked + assert len(progress_updates) == 1 + assert progress_updates[0] == {"progress": 1.0, "total": 1.0, "message": "User responded"} From f4e5fe55addcfaaa0c9e32c707af212de881c8fb Mon Sep 17 00:00:00 2001 From: Bryce Watson Date: Wed, 11 Feb 2026 22:00:38 -0800 Subject: [PATCH 2/2] fix: add pragma no branch to test coverage guards The if-guards checking for progress_token in context.meta always evaluate to True when progress_callback is provided, so the false branches are never taken. Mark with pragma: no branch to satisfy 100% coverage requirement. --- tests/shared/test_progress_notifications.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 257020d7d..43e5663c7 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -415,7 +415,7 @@ async def sampling_callback( params: CreateMessageRequestParams, ) -> CreateMessageResult: # Send progress notifications back to the server using the progress token - if context.meta and "progress_token" in context.meta: + if context.meta and "progress_token" in context.meta: # pragma: no branch token = context.meta["progress_token"] await context.session.send_progress_notification( progress_token=token, @@ -472,7 +472,7 @@ async def elicitation_callback( params: ElicitRequestParams, ) -> ElicitResult: # Send progress notifications back to the server using the progress token - if context.meta and "progress_token" in context.meta: + if context.meta and "progress_token" in context.meta: # pragma: no branch token = context.meta["progress_token"] await context.session.send_progress_notification( progress_token=token,