Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/mcp/server/elicitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pydantic import BaseModel

from mcp.server.session import ServerSession
from mcp.shared.session import ProgressFnT
from mcp.types import RequestId

ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
Expand Down Expand Up @@ -107,6 +108,7 @@ async def elicit_with_validation(
message: str,
schema: type[ElicitSchemaModelT],
related_request_id: RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
"""Elicit information from the client/user with schema validation (form mode).

Expand All @@ -127,6 +129,7 @@ async def elicit_with_validation(
message=message,
requested_schema=json_schema,
related_request_id=related_request_id,
progress_callback=progress_callback,
)

if result.action == "accept" and result.content is not None:
Expand All @@ -148,6 +151,7 @@ async def elicit_url(
url: str,
elicitation_id: str,
related_request_id: RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> UrlElicitationResult:
"""Elicit information from the user via out-of-band URL navigation (URL mode).

Expand Down Expand Up @@ -177,6 +181,7 @@ async def elicit_url(
url=url,
elicitation_id=elicitation_id,
related_request_id=related_request_id,
progress_callback=progress_callback,
)

if result.action == "accept":
Expand Down
7 changes: 7 additions & 0 deletions src/mcp/server/mcpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from mcp.server.transport_security import TransportSecuritySettings
from mcp.shared.exceptions import MCPError
from mcp.shared.session import ProgressFnT
from mcp.types import (
Annotations,
BlobResourceContents,
Expand Down Expand Up @@ -1181,6 +1182,7 @@ async def elicit(
self,
message: str,
schema: type[ElicitSchemaModelT],
progress_callback: ProgressFnT | None = None,
) -> ElicitationResult[ElicitSchemaModelT]:
"""Elicit information from the client/user.

Expand All @@ -1195,6 +1197,7 @@ async def elicit(
only primitive types are allowed.
message: Optional message to present to the user. If not provided, will use
a default message based on the schema
progress_callback: Optional callback for receiving progress notifications.

Returns:
An ElicitationResult containing the action taken and the data if accepted
Expand All @@ -1209,13 +1212,15 @@ async def elicit(
message=message,
schema=schema,
related_request_id=self.request_id,
progress_callback=progress_callback,
)

async def elicit_url(
self,
message: str,
url: str,
elicitation_id: str,
progress_callback: ProgressFnT | None = None,
) -> UrlElicitationResult:
"""Request URL mode elicitation from the client.

Expand All @@ -1234,6 +1239,7 @@ async def elicit_url(
message: Human-readable explanation of why the interaction is needed
url: The URL the user should navigate to
elicitation_id: Unique identifier for tracking this elicitation
progress_callback: Optional callback for receiving progress notifications.

Returns:
UrlElicitationResult indicating accept, decline, or cancel
Expand All @@ -1244,6 +1250,7 @@ async def elicit_url(
url=url,
elicitation_id=elicitation_id,
related_request_id=self.request_id,
progress_callback=progress_callback,
)

async def log(
Expand Down
17 changes: 16 additions & 1 deletion src/mcp/server/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ async def handle_list_prompts(ctx: RequestContext, params) -> ListPromptsResult:
from mcp.shared.message import ServerMessageMetadata, SessionMessage
from mcp.shared.session import (
BaseSession,
ProgressFnT,
RequestResponder,
)
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
Expand Down Expand Up @@ -245,6 +246,7 @@ async def create_message(
tools: None = None,
tool_choice: types.ToolChoice | None = None,
related_request_id: types.RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.CreateMessageResult:
"""Overload: Without tools, returns single content."""
...
Expand All @@ -264,6 +266,7 @@ async def create_message(
tools: list[types.Tool],
tool_choice: types.ToolChoice | None = None,
related_request_id: types.RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.CreateMessageResultWithTools:
"""Overload: With tools, returns array-capable content."""
...
Expand All @@ -282,6 +285,7 @@ async def create_message(
tools: list[types.Tool] | None = None,
tool_choice: types.ToolChoice | None = None,
related_request_id: types.RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.CreateMessageResult | types.CreateMessageResultWithTools:
"""Send a sampling/create_message request.

Expand All @@ -301,6 +305,7 @@ async def create_message(
tool_choice: Optional control over tool usage behavior.
Requires client to have sampling.tools capability.
related_request_id: Optional ID of a related request.
progress_callback: Optional callback for receiving progress notifications.

Returns:
The sampling result from the client.
Expand Down Expand Up @@ -338,11 +343,13 @@ async def create_message(
request=request,
result_type=types.CreateMessageResultWithTools,
metadata=metadata_obj,
progress_callback=progress_callback,
)
return await self.send_request(
request=request,
result_type=types.CreateMessageResult,
metadata=metadata_obj,
progress_callback=progress_callback,
)

async def list_roots(self) -> types.ListRootsResult:
Expand All @@ -359,13 +366,15 @@ async def elicit(
message: str,
requested_schema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.ElicitResult:
"""Send a form mode elicitation/create request.

Args:
message: The message to present to the user
requested_schema: Schema defining the expected response structure
related_request_id: Optional ID of the request that triggered this elicitation
progress_callback: Optional callback for receiving progress notifications.

Returns:
The client's response
Expand All @@ -374,20 +383,22 @@ async def elicit(
This method is deprecated in favor of elicit_form(). It remains for
backward compatibility but new code should use elicit_form().
"""
return await self.elicit_form(message, requested_schema, related_request_id)
return await self.elicit_form(message, requested_schema, related_request_id, progress_callback)

async def elicit_form(
self,
message: str,
requested_schema: types.ElicitRequestedSchema,
related_request_id: types.RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.ElicitResult:
"""Send a form mode elicitation/create request.

Args:
message: The message to present to the user
requested_schema: Schema defining the expected response structure
related_request_id: Optional ID of the request that triggered this elicitation
progress_callback: Optional callback for receiving progress notifications.

Returns:
The client's response with form data
Expand All @@ -406,6 +417,7 @@ async def elicit_form(
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
progress_callback=progress_callback,
)

async def elicit_url(
Expand All @@ -414,6 +426,7 @@ async def elicit_url(
url: str,
elicitation_id: str,
related_request_id: types.RequestId | None = None,
progress_callback: ProgressFnT | None = None,
) -> types.ElicitResult:
"""Send a URL mode elicitation/create request.

Expand All @@ -425,6 +438,7 @@ async def elicit_url(
url: The URL the user should navigate to
elicitation_id: Unique identifier for tracking this elicitation
related_request_id: Optional ID of the request that triggered this elicitation
progress_callback: Optional callback for receiving progress notifications.

Returns:
The client's response indicating acceptance, decline, or cancellation
Expand All @@ -444,6 +458,7 @@ async def elicit_url(
),
types.ElicitResult,
metadata=ServerMessageMetadata(related_request_id=related_request_id),
progress_callback=progress_callback,
)

async def send_ping(self) -> types.EmptyResult: # pragma: no cover
Expand Down
115 changes: 115 additions & 0 deletions tests/shared/test_progress_notifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,21 @@
from mcp.client.session import ClientSession
from mcp.server import Server, ServerRequestContext
from mcp.server.lowlevel import NotificationOptions
from mcp.server.mcpserver import MCPServer
from mcp.server.models import InitializationOptions
from mcp.server.session import ServerSession
from mcp.shared._context import RequestContext
from mcp.shared.message import SessionMessage
from mcp.shared.progress import progress
from mcp.shared.session import RequestResponder
from mcp.types import (
CreateMessageRequestParams,
CreateMessageResult,
ElicitRequestParams,
ElicitResult,
SamplingMessage,
TextContent,
)


@pytest.mark.anyio
Expand Down Expand Up @@ -375,3 +384,109 @@ async def handle_list_tools(
# Check that a warning was logged for the progress callback exception
assert len(logged_errors) > 0
assert any("Progress callback raised an exception" in warning for warning in logged_errors)


@pytest.mark.anyio
async def test_server_create_message_progress_callback():
"""Test that ServerSession.create_message() accepts and passes through progress_callback."""
server = MCPServer("test")

# Track progress updates received by the server's progress callback
progress_updates: list[dict[str, Any]] = []

async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None:
progress_updates.append({"progress": progress, "total": total, "message": message})

@server.tool("trigger_sampling")
async def trigger_sampling_tool(text: str) -> str:
result = await server.get_context().session.create_message(
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=text))],
max_tokens=100,
progress_callback=my_progress_callback,
)
assert isinstance(result.content, TextContent)
return result.content.text

async def sampling_callback(
context: RequestContext[ClientSession],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
# Send progress notifications back to the server using the progress token
if context.meta and "progress_token" in context.meta: # pragma: no branch
token = context.meta["progress_token"]
await context.session.send_progress_notification(
progress_token=token,
progress=0.5,
total=1.0,
message="Halfway done",
)
await context.session.send_progress_notification(
progress_token=token,
progress=1.0,
total=1.0,
message="Complete",
)

return CreateMessageResult(
role="assistant",
content=TextContent(type="text", text="LLM response"),
model="test-model",
stop_reason="endTurn",
)

async with Client(server, sampling_callback=sampling_callback) as client:
result = await client.call_tool("trigger_sampling", {"text": "Hello"})
assert result.is_error is False

# Verify the progress callback was invoked with correct values
assert len(progress_updates) == 2
assert progress_updates[0] == {"progress": 0.5, "total": 1.0, "message": "Halfway done"}
assert progress_updates[1] == {"progress": 1.0, "total": 1.0, "message": "Complete"}


@pytest.mark.anyio
async def test_server_elicit_form_progress_callback():
"""Test that ServerSession.elicit_form() accepts and passes through progress_callback."""
server = MCPServer("test")

# Track progress updates received by the server's progress callback
progress_updates: list[dict[str, Any]] = []

async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None:
progress_updates.append({"progress": progress, "total": total, "message": message})

@server.tool("trigger_elicitation")
async def trigger_elicitation_tool(text: str) -> str:
result = await server.get_context().session.elicit_form(
message=text,
requested_schema={"type": "object", "properties": {"name": {"type": "string"}}},
progress_callback=my_progress_callback,
)
return result.action

async def elicitation_callback(
context: RequestContext[ClientSession],
params: ElicitRequestParams,
) -> ElicitResult:
# Send progress notifications back to the server using the progress token
if context.meta and "progress_token" in context.meta: # pragma: no branch
token = context.meta["progress_token"]
await context.session.send_progress_notification(
progress_token=token,
progress=1.0,
total=1.0,
message="User responded",
)

return ElicitResult(
action="accept",
content={"name": "test"},
)

async with Client(server, elicitation_callback=elicitation_callback) as client:
result = await client.call_tool("trigger_elicitation", {"text": "Enter name"})
assert result.is_error is False

# Verify the progress callback was invoked
assert len(progress_updates) == 1
assert progress_updates[0] == {"progress": 1.0, "total": 1.0, "message": "User responded"}
Loading