diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index 58e9fe448..da4907e6e 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -9,6 +9,7 @@ from pydantic import BaseModel from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT from mcp.types import RequestId ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) @@ -107,6 +108,7 @@ async def elicit_with_validation( message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user with schema validation (form mode). @@ -127,6 +129,7 @@ async def elicit_with_validation( message=message, requested_schema=json_schema, related_request_id=related_request_id, + progress_callback=progress_callback, ) if result.action == "accept" and result.content is not None: @@ -148,6 +151,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> UrlElicitationResult: """Elicit information from the user via out-of-band URL navigation (URL mode). @@ -177,6 +181,7 @@ async def elicit_url( url=url, elicitation_id=elicitation_id, related_request_id=related_request_id, + progress_callback=progress_callback, ) if result.action == "accept": diff --git a/src/mcp/server/mcpserver/server.py b/src/mcp/server/mcpserver/server.py index f26944a2d..14cb0975a 100644 --- a/src/mcp/server/mcpserver/server.py +++ b/src/mcp/server/mcpserver/server.py @@ -45,6 +45,7 @@ from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared.exceptions import MCPError +from mcp.shared.session import ProgressFnT from mcp.types import ( Annotations, BlobResourceContents, @@ -1181,6 +1182,7 @@ async def elicit( self, message: str, schema: type[ElicitSchemaModelT], + progress_callback: ProgressFnT | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user. @@ -1195,6 +1197,7 @@ async def elicit( only primitive types are allowed. message: Optional message to present to the user. If not provided, will use a default message based on the schema + progress_callback: Optional callback for receiving progress notifications. Returns: An ElicitationResult containing the action taken and the data if accepted @@ -1209,6 +1212,7 @@ async def elicit( message=message, schema=schema, related_request_id=self.request_id, + progress_callback=progress_callback, ) async def elicit_url( @@ -1216,6 +1220,7 @@ async def elicit_url( message: str, url: str, elicitation_id: str, + progress_callback: ProgressFnT | None = None, ) -> UrlElicitationResult: """Request URL mode elicitation from the client. @@ -1234,6 +1239,7 @@ async def elicit_url( message: Human-readable explanation of why the interaction is needed url: The URL the user should navigate to elicitation_id: Unique identifier for tracking this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: UrlElicitationResult indicating accept, decline, or cancel @@ -1244,6 +1250,7 @@ async def elicit_url( url=url, elicitation_id=elicitation_id, related_request_id=self.request_id, + progress_callback=progress_callback, ) async def log( diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6925aa556..cf392841f 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -46,6 +46,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult: from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import ( BaseSession, + ProgressFnT, RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS @@ -245,6 +246,7 @@ async def create_message( tools: None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult: """Overload: Without tools, returns single content.""" ... @@ -264,6 +266,7 @@ async def create_message( tools: list[types.Tool], tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResultWithTools: """Overload: With tools, returns array-capable content.""" ... @@ -282,6 +285,7 @@ async def create_message( tools: list[types.Tool] | None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: """Send a sampling/create_message request. @@ -301,6 +305,7 @@ async def create_message( tool_choice: Optional control over tool usage behavior. Requires client to have sampling.tools capability. related_request_id: Optional ID of a related request. + progress_callback: Optional callback for receiving progress notifications. Returns: The sampling result from the client. @@ -338,11 +343,13 @@ async def create_message( request=request, result_type=types.CreateMessageResultWithTools, metadata=metadata_obj, + progress_callback=progress_callback, ) return await self.send_request( request=request, result_type=types.CreateMessageResult, metadata=metadata_obj, + progress_callback=progress_callback, ) async def list_roots(self) -> types.ListRootsResult: @@ -359,6 +366,7 @@ async def elicit( message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -366,6 +374,7 @@ async def elicit( message: The message to present to the user requested_schema: Schema defining the expected response structure related_request_id: Optional ID of the request that triggered this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response @@ -374,13 +383,14 @@ async def elicit( This method is deprecated in favor of elicit_form(). It remains for backward compatibility but new code should use elicit_form(). """ - return await self.elicit_form(message, requested_schema, related_request_id) + return await self.elicit_form(message, requested_schema, related_request_id, progress_callback) async def elicit_form( self, message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -388,6 +398,7 @@ async def elicit_form( message: The message to present to the user requested_schema: Schema defining the expected response structure related_request_id: Optional ID of the request that triggered this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response with form data @@ -406,6 +417,7 @@ async def elicit_form( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def elicit_url( @@ -414,6 +426,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a URL mode elicitation/create request. @@ -425,6 +438,7 @@ async def elicit_url( url: The URL the user should navigate to elicitation_id: Unique identifier for tracking this elicitation related_request_id: Optional ID of the request that triggered this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response indicating acceptance, decline, or cancellation @@ -444,6 +458,7 @@ async def elicit_url( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def send_ping(self) -> types.EmptyResult: # pragma: no cover diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py index 6b87774c0..1667af379 100644 --- a/tests/shared/test_progress_notifications.py +++ b/tests/shared/test_progress_notifications.py @@ -8,12 +8,21 @@ from mcp.client.session import ClientSession from mcp.server import Server, ServerRequestContext from mcp.server.lowlevel import NotificationOptions +from mcp.server.mcpserver import MCPServer from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared._context import RequestContext from mcp.shared.message import SessionMessage from mcp.shared.progress import progress from mcp.shared.session import RequestResponder +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestParams, + ElicitResult, + SamplingMessage, + TextContent, +) @pytest.mark.anyio @@ -375,3 +384,109 @@ async def handle_list_tools( # Check that a warning was logged for the progress callback exception assert len(logged_errors) > 0 assert any("Progress callback raised an exception" in warning for warning in logged_errors) + + +@pytest.mark.anyio +async def test_server_create_message_progress_callback(): + """Test that ServerSession.create_message() accepts and passes through progress_callback.""" + server = MCPServer("test") + + # Track progress updates received by the server's progress callback + progress_updates: list[dict[str, Any]] = [] + + async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append({"progress": progress, "total": total, "message": message}) + + @server.tool("trigger_sampling") + async def trigger_sampling_tool(text: str) -> str: + result = await server.get_context().session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=text))], + max_tokens=100, + progress_callback=my_progress_callback, + ) + assert isinstance(result.content, TextContent) + return result.content.text + + async def sampling_callback( + context: RequestContext[ClientSession], + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + # Send progress notifications back to the server using the progress token + if context.meta and "progress_token" in context.meta: # pragma: no branch + token = context.meta["progress_token"] + await context.session.send_progress_notification( + progress_token=token, + progress=0.5, + total=1.0, + message="Halfway done", + ) + await context.session.send_progress_notification( + progress_token=token, + progress=1.0, + total=1.0, + message="Complete", + ) + + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="LLM response"), + model="test-model", + stop_reason="endTurn", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("trigger_sampling", {"text": "Hello"}) + assert result.is_error is False + + # Verify the progress callback was invoked with correct values + assert len(progress_updates) == 2 + assert progress_updates[0] == {"progress": 0.5, "total": 1.0, "message": "Halfway done"} + assert progress_updates[1] == {"progress": 1.0, "total": 1.0, "message": "Complete"} + + +@pytest.mark.anyio +async def test_server_elicit_form_progress_callback(): + """Test that ServerSession.elicit_form() accepts and passes through progress_callback.""" + server = MCPServer("test") + + # Track progress updates received by the server's progress callback + progress_updates: list[dict[str, Any]] = [] + + async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append({"progress": progress, "total": total, "message": message}) + + @server.tool("trigger_elicitation") + async def trigger_elicitation_tool(text: str) -> str: + result = await server.get_context().session.elicit_form( + message=text, + requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, + progress_callback=my_progress_callback, + ) + return result.action + + async def elicitation_callback( + context: RequestContext[ClientSession], + params: ElicitRequestParams, + ) -> ElicitResult: + # Send progress notifications back to the server using the progress token + if context.meta and "progress_token" in context.meta: # pragma: no branch + token = context.meta["progress_token"] + await context.session.send_progress_notification( + progress_token=token, + progress=1.0, + total=1.0, + message="User responded", + ) + + return ElicitResult( + action="accept", + content={"name": "test"}, + ) + + async with Client(server, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("trigger_elicitation", {"text": "Enter name"}) + assert result.is_error is False + + # Verify the progress callback was invoked + assert len(progress_updates) == 1 + assert progress_updates[0] == {"progress": 1.0, "total": 1.0, "message": "User responded"}