diff --git a/.github/workflows/publish-packages.yml b/.github/workflows/publish-packages.yml
new file mode 100644
index 000000000..3022d2a2d
--- /dev/null
+++ b/.github/workflows/publish-packages.yml
@@ -0,0 +1,101 @@
+name: Publish NuGet Packages
+
+on:
+ push:
+ branches: [main]
+ tags: ['v*']
+ workflow_dispatch:
+ inputs:
+ version_suffix:
+ description: 'Version suffix (e.g., preview.1, beta.2). Leave empty for stable.'
+ required: false
+ default: ''
+
+env:
+ DOTNET_SKIP_FIRST_TIME_EXPERIENCE: 1
+ DOTNET_CLI_TELEMETRY_OPTOUT: 1
+
+jobs:
+ build-and-publish:
+ runs-on: ubuntu-latest
+ permissions:
+ contents: read
+ packages: write
+
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Setup .NET
+ uses: actions/setup-dotnet@v4
+ with:
+ dotnet-version: |
+ 8.0.x
+ 9.0.x
+ 10.0.x
+
+ - name: Determine version
+ id: version
+ run: |
+ # Base version from tag or default
+ if [[ "${{ github.ref }}" == refs/tags/v* ]]; then
+ VERSION="${{ github.ref_name }}"
+ VERSION="${VERSION#v}" # Remove 'v' prefix
+ else
+ # Use commit-based version for non-tag builds
+ VERSION="0.5.1-dev.${{ github.run_number }}"
+ fi
+
+ # Add suffix if provided
+ if [[ -n "${{ github.event.inputs.version_suffix }}" ]]; then
+ VERSION="${VERSION}-${{ github.event.inputs.version_suffix }}"
+ fi
+
+ echo "version=$VERSION" >> $GITHUB_OUTPUT
+ echo "Building version: $VERSION"
+
+ - name: Restore dependencies
+ run: dotnet restore
+
+ - name: Build
+ run: dotnet build --configuration Release --no-restore /p:Version=${{ steps.version.outputs.version }}
+
+ - name: Pack NuGet packages
+ run: |
+ dotnet pack src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj \
+ --configuration Release \
+ --no-build \
+ --output ./packages \
+ /p:PackageVersion=${{ steps.version.outputs.version }}
+
+ dotnet pack src/ModelContextProtocol/ModelContextProtocol.csproj \
+ --configuration Release \
+ --no-build \
+ --output ./packages \
+ /p:PackageVersion=${{ steps.version.outputs.version }}
+
+ dotnet pack src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj \
+ --configuration Release \
+ --no-build \
+ --output ./packages \
+ /p:PackageVersion=${{ steps.version.outputs.version }}
+
+ - name: Push to GitHub Packages
+ run: |
+ dotnet nuget add source \
+ --username ${{ github.actor }} \
+ --password ${{ secrets.GITHUB_TOKEN }} \
+ --store-password-in-clear-text \
+ --name github \
+ "https://nuget.pkg.github.com/${{ github.repository_owner }}/index.json"
+
+ dotnet nuget push ./packages/*.nupkg \
+ --source github \
+ --skip-duplicate
+
+ - name: Upload packages as artifact
+ uses: actions/upload-artifact@v4
+ with:
+ name: nuget-packages
+ path: ./packages/*.nupkg
+ retention-days: 30
diff --git a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs
index 3f8043808..c479d8f7a 100644
--- a/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs
+++ b/src/ModelContextProtocol.AspNetCore/HttpMcpServerBuilderExtensions.cs
@@ -1,3 +1,4 @@
+using System.Diagnostics.CodeAnalysis;
using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
@@ -42,6 +43,69 @@ public static IMcpServerBuilder WithHttpTransport(this IMcpServerBuilder builder
return builder;
}
+ ///
+ /// Configures a custom session store for distributed session support.
+ ///
+ /// The type implementing .
+ /// The builder instance.
+ /// The builder provided in .
+ ///
+ /// When a session store is registered and
+ /// is set to , sessions can be recreated on any server instance.
+ /// This enables horizontal scaling without sticky sessions.
+ ///
+ /// is .
+ public static IMcpServerBuilder WithSessionStore<[DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TSessionStore>(this IMcpServerBuilder builder)
+ where TSessionStore : class, ISessionStore
+ {
+ ArgumentNullException.ThrowIfNull(builder);
+
+ builder.Services.TryAddSingleton();
+ return builder;
+ }
+
+ ///
+ /// Configures a custom session store instance for distributed session support.
+ ///
+ /// The builder instance.
+ /// The session store instance.
+ /// The builder provided in .
+ ///
+ /// When a session store is registered and
+ /// is set to , sessions can be recreated on any server instance.
+ /// This enables horizontal scaling without sticky sessions.
+ ///
+ /// or is .
+ public static IMcpServerBuilder WithSessionStore(this IMcpServerBuilder builder, ISessionStore sessionStore)
+ {
+ ArgumentNullException.ThrowIfNull(builder);
+ ArgumentNullException.ThrowIfNull(sessionStore);
+
+ builder.Services.TryAddSingleton(sessionStore);
+ return builder;
+ }
+
+ ///
+ /// Configures a session store using a factory for distributed session support.
+ ///
+ /// The builder instance.
+ /// A factory function to create the session store.
+ /// The builder provided in .
+ ///
+ /// When a session store is registered and
+ /// is set to , sessions can be recreated on any server instance.
+ /// This enables horizontal scaling without sticky sessions.
+ ///
+ /// or is .
+ public static IMcpServerBuilder WithSessionStore(this IMcpServerBuilder builder, Func factory)
+ {
+ ArgumentNullException.ThrowIfNull(builder);
+ ArgumentNullException.ThrowIfNull(factory);
+
+ builder.Services.TryAddSingleton(factory);
+ return builder;
+ }
+
///
/// Adds authorization filters to support
/// on MCP server tools, prompts, and resources. This method should always be called when using
diff --git a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
index 67f4f4e1d..ec8bbc542 100644
--- a/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
+++ b/src/ModelContextProtocol.AspNetCore/HttpServerTransportOptions.cs
@@ -88,4 +88,26 @@ public class HttpServerTransportOptions
/// Gets or sets the time provider that's used for testing the .
///
public TimeProvider TimeProvider { get; set; } = TimeProvider.System;
+
+ ///
+ /// Gets or sets a value indicating whether to enable distributed session support.
+ ///
+ ///
+ /// to enable distributed sessions; to use in-memory only. The default is .
+ ///
+ ///
+ /// When enabled and an is registered in DI, sessions can be
+ /// recreated on any server instance. This enables horizontal scaling without sticky sessions.
+ /// Session metadata is persisted to the store, and sessions not found in memory are
+ /// automatically recreated from stored metadata.
+ ///
+ public bool EnableDistributedSessions { get; set; }
+
+ ///
+ /// Gets or sets the interval at which keep-alive messages are sent over SSE connections.
+ ///
+ ///
+ /// The interval at which keep-alive messages are sent. The default is , which means keep-alive messages are disabled.
+ ///
+ public TimeSpan? KeepAliveInterval { get; set; }
}
diff --git a/src/ModelContextProtocol.AspNetCore/ISessionStore.cs b/src/ModelContextProtocol.AspNetCore/ISessionStore.cs
new file mode 100644
index 000000000..1025ff01a
--- /dev/null
+++ b/src/ModelContextProtocol.AspNetCore/ISessionStore.cs
@@ -0,0 +1,58 @@
+namespace ModelContextProtocol.AspNetCore;
+
+///
+/// Defines a contract for persisting MCP session metadata to enable distributed session management.
+///
+///
+/// Implementations of this interface allow MCP sessions to survive server restarts and
+/// enable horizontal scaling without sticky sessions. When a session is not found in
+/// the in-memory cache, the session manager can use this store to retrieve session
+/// metadata and recreate the session.
+///
+/// Note that only session metadata is persisted - the actual McpServer and transport
+/// instances are always in-memory. When a session is "recreated" from storage, a new
+/// McpServer instance is created with the stored capabilities and configuration.
+///
+public interface ISessionStore
+{
+ ///
+ /// Saves session metadata to the store.
+ ///
+ /// The session metadata to persist.
+ /// A cancellation token.
+ /// A task representing the asynchronous operation.
+ Task SaveAsync(SessionMetadata metadata, CancellationToken cancellationToken = default);
+
+ ///
+ /// Retrieves session metadata by session ID.
+ ///
+ /// The unique session identifier.
+ /// A cancellation token.
+ /// The session metadata if found; otherwise, null.
+ Task GetAsync(string sessionId, CancellationToken cancellationToken = default);
+
+ ///
+ /// Updates the last activity timestamp for a session.
+ ///
+ /// The unique session identifier.
+ /// The UTC timestamp of the last activity.
+ /// A cancellation token.
+ /// A task representing the asynchronous operation.
+ Task UpdateActivityAsync(string sessionId, DateTime lastActivityUtc, CancellationToken cancellationToken = default);
+
+ ///
+ /// Removes session metadata from the store.
+ ///
+ /// The unique session identifier.
+ /// A cancellation token.
+ /// True if the session was found and removed; otherwise, false.
+ Task RemoveAsync(string sessionId, CancellationToken cancellationToken = default);
+
+ ///
+ /// Removes all sessions that have been idle longer than the specified timeout.
+ ///
+ /// The maximum allowed idle time.
+ /// A cancellation token.
+ /// The number of sessions removed.
+ Task PruneIdleSessionsAsync(TimeSpan idleTimeout, CancellationToken cancellationToken = default);
+}
diff --git a/src/ModelContextProtocol.AspNetCore/InMemorySessionStore.cs b/src/ModelContextProtocol.AspNetCore/InMemorySessionStore.cs
new file mode 100644
index 000000000..4371dbf0b
--- /dev/null
+++ b/src/ModelContextProtocol.AspNetCore/InMemorySessionStore.cs
@@ -0,0 +1,87 @@
+using System.Collections.Concurrent;
+
+namespace ModelContextProtocol.AspNetCore;
+
+///
+/// A simple in-memory implementation of for testing and development.
+///
+///
+/// This implementation stores session metadata in a .
+/// It is NOT suitable for production use in distributed scenarios since the data is not shared
+/// across server instances. Use a database-backed implementation for production deployments.
+///
+public sealed class InMemorySessionStore : ISessionStore
+{
+ private readonly ConcurrentDictionary _sessions = new();
+ private readonly TimeProvider _timeProvider;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// Optional time provider for testing. Defaults to .
+ public InMemorySessionStore(TimeProvider? timeProvider = null)
+ {
+ _timeProvider = timeProvider ?? TimeProvider.System;
+ }
+
+ ///
+ /// Gets the current number of sessions in the store.
+ ///
+ public int Count => _sessions.Count;
+
+ ///
+ public Task SaveAsync(SessionMetadata metadata, CancellationToken cancellationToken = default)
+ {
+ ArgumentNullException.ThrowIfNull(metadata);
+ _sessions[metadata.SessionId] = metadata;
+ return Task.CompletedTask;
+ }
+
+ ///
+ public Task GetAsync(string sessionId, CancellationToken cancellationToken = default)
+ {
+ _sessions.TryGetValue(sessionId, out var metadata);
+ return Task.FromResult(metadata);
+ }
+
+ ///
+ public Task UpdateActivityAsync(string sessionId, DateTime lastActivityUtc, CancellationToken cancellationToken = default)
+ {
+ if (_sessions.TryGetValue(sessionId, out var metadata))
+ {
+ metadata.LastActivityUtc = lastActivityUtc;
+ }
+ return Task.CompletedTask;
+ }
+
+ ///
+ public Task RemoveAsync(string sessionId, CancellationToken cancellationToken = default)
+ {
+ return Task.FromResult(_sessions.TryRemove(sessionId, out _));
+ }
+
+ ///
+ public Task PruneIdleSessionsAsync(TimeSpan idleTimeout, CancellationToken cancellationToken = default)
+ {
+ var cutoff = _timeProvider.GetUtcNow().DateTime - idleTimeout;
+ var removed = 0;
+
+ foreach (var kvp in _sessions)
+ {
+ if (kvp.Value.LastActivityUtc < cutoff)
+ {
+ if (_sessions.TryRemove(kvp.Key, out _))
+ {
+ removed++;
+ }
+ }
+ }
+
+ return Task.FromResult(removed);
+ }
+
+ ///
+ /// Clears all sessions from the store. Useful for test cleanup.
+ ///
+ public void Clear() => _sessions.Clear();
+}
diff --git a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj
index a957bd969..1ec5127ac 100644
--- a/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj
+++ b/src/ModelContextProtocol.AspNetCore/ModelContextProtocol.AspNetCore.csproj
@@ -6,10 +6,13 @@
enable
true
true
- ModelContextProtocol.AspNetCore
- ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK.
+
+ Surfshack.ModelContextProtocol.AspNetCore
+ ASP.NET Core extensions for the C# Model Context Protocol (MCP) SDK. Fork with distributed session support via ISessionStore.
README.md
true
+ https://github.com/dota-devy/mcp-csharp-sdk
+ mcp;model-context-protocol;ai;llm;aspnetcore;distributed-sessions
diff --git a/src/ModelContextProtocol.AspNetCore/SessionMetadata.cs b/src/ModelContextProtocol.AspNetCore/SessionMetadata.cs
new file mode 100644
index 000000000..9d4f46388
--- /dev/null
+++ b/src/ModelContextProtocol.AspNetCore/SessionMetadata.cs
@@ -0,0 +1,53 @@
+namespace ModelContextProtocol.AspNetCore;
+
+///
+/// Contains the persistable metadata for an MCP session.
+///
+///
+/// This class contains all the information needed to recreate an MCP session
+/// after a server restart or when routing to a different server instance.
+/// Only serializable data is included - runtime objects like McpServer and
+/// transport connections cannot be persisted and must be recreated.
+///
+public sealed class SessionMetadata
+{
+ ///
+ /// Gets or sets the unique session identifier.
+ ///
+ public required string SessionId { get; set; }
+
+ ///
+ /// Gets or sets the user identifier claim type (e.g., "sub", "nameidentifier").
+ ///
+ public string? UserIdClaimType { get; set; }
+
+ ///
+ /// Gets or sets the user identifier claim value.
+ ///
+ public string? UserIdClaimValue { get; set; }
+
+ ///
+ /// Gets or sets the user identifier claim issuer.
+ ///
+ public string? UserIdClaimIssuer { get; set; }
+
+ ///
+ /// Gets or sets the UTC timestamp when the session was created.
+ ///
+ public DateTime CreatedAtUtc { get; set; }
+
+ ///
+ /// Gets or sets the UTC timestamp of the last activity.
+ ///
+ public DateTime LastActivityUtc { get; set; }
+
+ ///
+ /// Gets or sets optional JSON-serialized custom data that tools or middleware
+ /// may want to persist across session recreations.
+ ///
+ ///
+ /// The implementor of is responsible for serializing
+ /// and deserializing this data as needed.
+ ///
+ public string? CustomDataJson { get; set; }
+}
diff --git a/src/ModelContextProtocol.AspNetCore/SseHandler.cs b/src/ModelContextProtocol.AspNetCore/SseHandler.cs
index eefe0d29e..4c32f44c3 100644
--- a/src/ModelContextProtocol.AspNetCore/SseHandler.cs
+++ b/src/ModelContextProtocol.AspNetCore/SseHandler.cs
@@ -30,7 +30,11 @@ public async Task HandleSseRequestAsync(HttpContext context)
var requestPath = (context.Request.PathBase + context.Request.Path).ToString();
var endpointPattern = requestPath[..(requestPath.LastIndexOf('/') + 1)];
- await using var transport = new SseResponseStreamTransport(context.Response.Body, $"{endpointPattern}message?sessionId={sessionId}", sessionId);
+ await using var transport = new SseResponseStreamTransport(
+ context.Response.Body,
+ $"{endpointPattern}message?sessionId={sessionId}",
+ sessionId,
+ httpMcpServerOptions.Value.KeepAliveInterval);
var userIdClaim = StreamableHttpHandler.GetUserIdClaim(context.User);
var sseSession = new SseSession(transport, userIdClaim);
diff --git a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
index c0f59363a..3f0ffb583 100644
--- a/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
+++ b/src/ModelContextProtocol.AspNetCore/StreamableHttpHandler.cs
@@ -20,7 +20,8 @@ internal sealed class StreamableHttpHandler(
StatefulSessionManager sessionManager,
IHostApplicationLifetime hostApplicationLifetime,
IServiceProvider applicationServices,
- ILoggerFactory loggerFactory)
+ ILoggerFactory loggerFactory,
+ ISessionStore? sessionStore = null)
{
private const string McpSessionIdHeaderName = "Mcp-Session-Id";
@@ -127,6 +128,14 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
{
await session.DisposeAsync();
}
+
+ // Also remove from persistent store if enabled
+ if (HttpServerTransportOptions.EnableDistributedSessions &&
+ sessionStore is not null &&
+ !string.IsNullOrEmpty(sessionId))
+ {
+ await sessionStore.RemoveAsync(sessionId, context.RequestAborted);
+ }
}
private async ValueTask GetSessionAsync(HttpContext context, string sessionId)
@@ -140,6 +149,18 @@ public async Task HandleDeleteRequestAsync(HttpContext context)
}
else if (!sessionManager.TryGetValue(sessionId, out session))
{
+ // Session not in memory - try to recreate from persistent store if enabled
+ if (HttpServerTransportOptions.EnableDistributedSessions && sessionStore is not null)
+ {
+ session = await TryRecreateSessionFromStoreAsync(context, sessionId);
+ if (session is not null)
+ {
+ context.Response.Headers[McpSessionIdHeaderName] = session.Id;
+ context.Features.Set(session.Server);
+ return session;
+ }
+ }
+
// -32001 isn't part of the MCP standard, but this is what the typescript-sdk currently does.
// One of the few other usages I found was from some Ethereum JSON-RPC documentation and this
// JSON-RPC library from Microsoft called StreamJsonRpc where it's called JsonRpcErrorCode.NoMarshaledObjectFound
@@ -161,6 +182,59 @@ await WriteJsonRpcErrorAsync(context,
return session;
}
+ private async ValueTask TryRecreateSessionFromStoreAsync(HttpContext context, string sessionId)
+ {
+ var metadata = await sessionStore!.GetAsync(sessionId, context.RequestAborted);
+ if (metadata is null)
+ {
+ return null;
+ }
+
+ // Validate user identity matches stored session
+ var currentUserClaim = GetUserIdClaim(context.User);
+ if (!UserIdClaimMatches(currentUserClaim, metadata))
+ {
+ return null;
+ }
+
+ // Recreate the session with the existing session ID
+ var transport = new StreamableHttpServerTransport
+ {
+ SessionId = sessionId,
+ FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
+ HeartbeatInterval = HttpServerTransportOptions.KeepAliveInterval,
+ };
+
+ context.Response.Headers[McpSessionIdHeaderName] = sessionId;
+
+ // Create the session using the existing CreateSessionAsync logic
+ var session = await CreateSessionAsync(context, transport, sessionId);
+
+ // Update last activity in the store
+ await sessionStore.UpdateActivityAsync(sessionId, DateTime.UtcNow, context.RequestAborted);
+
+ return session;
+ }
+
+ private static bool UserIdClaimMatches(UserIdClaim? currentClaim, SessionMetadata metadata)
+ {
+ // If session had no user (anonymous), current request must also be anonymous
+ if (string.IsNullOrEmpty(metadata.UserIdClaimValue))
+ {
+ return currentClaim is null;
+ }
+
+ // If session had a user, current request must have the same user
+ if (currentClaim is null)
+ {
+ return false;
+ }
+
+ return currentClaim.Type == metadata.UserIdClaimType &&
+ currentClaim.Value == metadata.UserIdClaimValue &&
+ currentClaim.Issuer == metadata.UserIdClaimIssuer;
+ }
+
private async ValueTask GetOrCreateSessionAsync(HttpContext context)
{
var sessionId = context.Request.Headers[McpSessionIdHeaderName].ToString();
@@ -194,6 +268,7 @@ private async ValueTask StartNewSessionAsync(HttpContext
{
SessionId = sessionId,
FlowExecutionContextFromRequests = !HttpServerTransportOptions.PerSessionExecutionContext,
+ HeartbeatInterval = HttpServerTransportOptions.KeepAliveInterval,
};
context.Response.Headers[McpSessionIdHeaderName] = sessionId;
}
@@ -240,6 +315,24 @@ private async ValueTask CreateSessionAsync(
var userIdClaim = GetUserIdClaim(context.User);
var session = new StreamableHttpSession(sessionId, transport, server, userIdClaim, sessionManager);
+ // Persist session metadata if distributed sessions are enabled
+ if (!HttpServerTransportOptions.Stateless &&
+ HttpServerTransportOptions.EnableDistributedSessions &&
+ sessionStore is not null &&
+ !string.IsNullOrEmpty(sessionId))
+ {
+ var metadata = new SessionMetadata
+ {
+ SessionId = sessionId,
+ UserIdClaimType = userIdClaim?.Type,
+ UserIdClaimValue = userIdClaim?.Value,
+ UserIdClaimIssuer = userIdClaim?.Issuer,
+ CreatedAtUtc = DateTime.UtcNow,
+ LastActivityUtc = DateTime.UtcNow
+ };
+ await sessionStore.SaveAsync(metadata, context.RequestAborted);
+ }
+
var runSessionAsync = HttpServerTransportOptions.RunSessionHandler ?? RunSessionAsync;
session.ServerRunTask = runSessionAsync(context, server, session.SessionClosed);
diff --git a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj
index 5cd8339bf..b5a657428 100644
--- a/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj
+++ b/src/ModelContextProtocol.Core/ModelContextProtocol.Core.csproj
@@ -4,10 +4,12 @@
net10.0;net9.0;net8.0;netstandard2.0
true
true
- ModelContextProtocol.Core
- Core .NET SDK for the Model Context Protocol (MCP)
+
+ Surfshack.ModelContextProtocol.Core
+ Core .NET SDK for the Model Context Protocol (MCP). Fork with distributed session support.
README.md
True
+ https://github.com/dota-devy/mcp-csharp-sdk
diff --git a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs
index afdf29943..d8ea0b678 100644
--- a/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs
+++ b/src/ModelContextProtocol.Core/Server/SseResponseStreamTransport.cs
@@ -25,9 +25,10 @@ namespace ModelContextProtocol.Server;
/// Defaults to "/message".
///
/// The identifier corresponding to the current MCP session.
-public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message", string? sessionId = null) : ITransport
+/// The interval at which heartbeat messages are sent to keep the SSE connection alive.
+public sealed class SseResponseStreamTransport(Stream sseResponseStream, string? messageEndpoint = "/message", string? sessionId = null, TimeSpan? heartbeatInterval = null) : ITransport
{
- private readonly SseWriter _sseWriter = new(messageEndpoint);
+ private readonly SseWriter _sseWriter = new(messageEndpoint, heartbeatInterval: heartbeatInterval);
private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1)
{
SingleReader = true,
diff --git a/src/ModelContextProtocol.Core/Server/SseWriter.cs b/src/ModelContextProtocol.Core/Server/SseWriter.cs
index a2314e623..1920a142e 100644
--- a/src/ModelContextProtocol.Core/Server/SseWriter.cs
+++ b/src/ModelContextProtocol.Core/Server/SseWriter.cs
@@ -7,7 +7,7 @@
namespace ModelContextProtocol.Server;
-internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOptions? channelOptions = null) : IAsyncDisposable
+internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOptions? channelOptions = null, TimeSpan? heartbeatInterval = null) : IAsyncDisposable
{
private readonly Channel> _messages = Channel.CreateBounded>(channelOptions ?? new BoundedChannelOptions(1)
{
@@ -17,9 +17,12 @@ internal sealed class SseWriter(string? messageEndpoint = null, BoundedChannelOp
private Utf8JsonWriter? _jsonWriter;
private Task? _writeTask;
+ private Task? _heartbeatTask;
private CancellationToken? _writeCancellationToken;
+ private long _lastActivityTicks = DateTime.UtcNow.Ticks;
private readonly SemaphoreSlim _disposeLock = new(1, 1);
+ private readonly CancellationTokenSource _disposeCts = new();
private bool _disposed;
public Func>, CancellationToken, IAsyncEnumerable>>? MessageFilter { get; set; }
@@ -37,6 +40,14 @@ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellati
_writeCancellationToken = cancellationToken;
+ if (heartbeatInterval.HasValue)
+ {
+ // We need to link the passed token with our dispose token so the heartbeat loop stops
+ // either when the caller cancels OR when we dispose.
+ var linkedToken = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _disposeCts.Token);
+ _heartbeatTask = RunHeartbeatLoopAsync(heartbeatInterval.Value, linkedToken.Token);
+ }
+
var messages = _messages.Reader.ReadAllAsync(cancellationToken);
if (MessageFilter is not null)
{
@@ -47,6 +58,31 @@ public Task WriteAllAsync(Stream sseResponseStream, CancellationToken cancellati
return _writeTask;
}
+ private async Task RunHeartbeatLoopAsync(TimeSpan interval, CancellationToken cancellationToken)
+ {
+ try
+ {
+ while (!cancellationToken.IsCancellationRequested)
+ {
+ var now = DateTime.UtcNow.Ticks;
+ var lastActivity = Interlocked.Read(ref _lastActivityTicks);
+ var elapsed = TimeSpan.FromTicks(now - lastActivity);
+
+ if (elapsed >= interval)
+ {
+ // If the underlying writer has been completed, TryWrite will return false.
+ _messages.Writer.TryWrite(new SseItem(null, "ping"));
+ Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
+ }
+
+ await Task.Delay(interval, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ catch (OperationCanceledException)
+ {
+ }
+ }
+
public async Task SendMessageAsync(JsonRpcMessage message, CancellationToken cancellationToken = default)
{
Throw.IfNull(message);
@@ -62,6 +98,7 @@ public async Task SendMessageAsync(JsonRpcMessage message, CancellationTok
// Emit redundant "event: message" lines for better compatibility with other SDKs.
await _messages.Writer.WriteAsync(new SseItem(message, SseParser.EventTypeDefault), cancellationToken).ConfigureAwait(false);
+ Interlocked.Exchange(ref _lastActivityTicks, DateTime.UtcNow.Ticks);
return true;
}
@@ -74,21 +111,29 @@ public async ValueTask DisposeAsync()
return;
}
- _messages.Writer.Complete();
+ _messages.Writer.TryComplete();
+ _disposeCts.Cancel();
+
try
{
if (_writeTask is not null)
{
await _writeTask.ConfigureAwait(false);
}
+
+ if (_heartbeatTask is not null)
+ {
+ await _heartbeatTask.ConfigureAwait(false);
+ }
}
- catch (OperationCanceledException) when (_writeCancellationToken?.IsCancellationRequested == true)
+ catch (OperationCanceledException) when (_writeCancellationToken?.IsCancellationRequested == true || _disposeCts.IsCancellationRequested)
{
// Ignore exceptions caused by intentional cancellation during shutdown.
}
finally
{
_jsonWriter?.Dispose();
+ _disposeCts.Dispose();
_disposed = true;
}
}
@@ -101,6 +146,11 @@ private void WriteJsonRpcMessageToBuffer(SseItem item, IBufferW
return;
}
+ if (item.Data is null)
+ {
+ return;
+ }
+
JsonSerializer.Serialize(GetUtf8JsonWriter(writer), item.Data, McpJsonUtilities.JsonContext.Default.JsonRpcMessage!);
}
diff --git a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs
index c99b1fa39..1475accc3 100644
--- a/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs
+++ b/src/ModelContextProtocol.Core/Server/StreamableHttpServerTransport.cs
@@ -21,13 +21,35 @@ namespace ModelContextProtocol.Server;
///
public sealed class StreamableHttpServerTransport : ITransport
{
+ private TimeSpan? _heartbeatInterval;
+
// For JsonRpcMessages without a RelatedTransport, we don't want to block just because the client didn't make a GET request to handle unsolicited messages.
- private readonly SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1)
+ private SseWriter _sseWriter = new(channelOptions: new BoundedChannelOptions(1)
{
SingleReader = true,
SingleWriter = false,
FullMode = BoundedChannelFullMode.DropOldest,
});
+
+ ///
+ /// Gets or sets the interval at which heartbeat messages are sent to keep the SSE connection alive.
+ ///
+ public TimeSpan? HeartbeatInterval
+ {
+ get => _heartbeatInterval;
+ set
+ {
+ _heartbeatInterval = value;
+ _sseWriter = new SseWriter(
+ channelOptions: new BoundedChannelOptions(1)
+ {
+ SingleReader = true,
+ SingleWriter = false,
+ FullMode = BoundedChannelFullMode.DropOldest,
+ },
+ heartbeatInterval: value);
+ }
+ }
private readonly Channel _incomingChannel = Channel.CreateBounded(new BoundedChannelOptions(1)
{
SingleReader = true,
diff --git a/src/ModelContextProtocol/ModelContextProtocol.csproj b/src/ModelContextProtocol/ModelContextProtocol.csproj
index b69108ab2..90b693aab 100644
--- a/src/ModelContextProtocol/ModelContextProtocol.csproj
+++ b/src/ModelContextProtocol/ModelContextProtocol.csproj
@@ -4,9 +4,11 @@
net10.0;net9.0;net8.0;netstandard2.0
true
true
- ModelContextProtocol
- .NET SDK for the Model Context Protocol (MCP) with hosting and dependency injection extensions.
+
+ Surfshack.ModelContextProtocol
+ .NET SDK for the Model Context Protocol (MCP) with hosting and dependency injection extensions. Fork with distributed session support.
README.md
+ https://github.com/dota-devy/mcp-csharp-sdk
@@ -29,4 +31,4 @@
-
\ No newline at end of file
+
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/DistributedSessionTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/DistributedSessionTests.cs
new file mode 100644
index 000000000..11213213b
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/DistributedSessionTests.cs
@@ -0,0 +1,242 @@
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.DependencyInjection;
+using ModelContextProtocol.Client;
+using ModelContextProtocol.Protocol;
+using ModelContextProtocol.Server;
+using System.Net;
+using System.Net.Http.Headers;
+
+namespace ModelContextProtocol.AspNetCore.Tests;
+
+public class DistributedSessionTests(ITestOutputHelper outputHelper) : MapMcpTests(outputHelper)
+{
+ protected override bool UseStreamableHttp => true;
+ protected override bool Stateless => false;
+
+ [Fact]
+ public async Task DistributedSessions_SavesMetadataToStore()
+ {
+ var sessionStore = new InMemorySessionStore();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "DistributedSessionTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ }).WithSessionStore(sessionStore);
+
+ await using var app = Builder.Build();
+ app.MapMcp();
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ // Connect and establish a session
+ await using var client = await ConnectAsync("/");
+
+ // Wait a moment for the session to be persisted
+ await Task.Delay(100, TestContext.Current.CancellationToken);
+
+ // Verify session was saved to store
+ Assert.Equal(1, sessionStore.Count);
+ }
+
+ [Fact]
+ public async Task DistributedSessions_StoresUserIdentity()
+ {
+ var sessionStore = new InMemorySessionStore();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "DistributedSessionTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ options.ConfigureSessionOptions = async (context, serverOptions, ct) =>
+ {
+ // Capture the session ID from the response header after it's set
+ await Task.CompletedTask;
+ };
+ }).WithSessionStore(sessionStore);
+
+ await using var app = Builder.Build();
+ app.MapMcp();
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var client = await ConnectAsync("/");
+
+ // Wait for session to be persisted
+ await Task.Delay(100, TestContext.Current.CancellationToken);
+
+ // Get the stored session and verify it exists
+ Assert.Equal(1, sessionStore.Count);
+ }
+
+ [Fact]
+ public async Task DistributedSessions_RemovesFromStoreOnDelete()
+ {
+ var sessionStore = new InMemorySessionStore();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "DistributedSessionTestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ }).WithSessionStore(sessionStore);
+
+ await using var app = Builder.Build();
+ app.MapMcp();
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ // Connect and establish session
+ var client = await ConnectAsync("/");
+ await Task.Delay(100, TestContext.Current.CancellationToken);
+ Assert.Equal(1, sessionStore.Count);
+
+ // Disconnect (which should trigger session deletion)
+ await client.DisposeAsync();
+ await Task.Delay(100, TestContext.Current.CancellationToken);
+
+ // Session should be removed from store
+ // Note: This depends on the client properly sending DELETE on dispose
+ }
+
+ [Fact]
+ public async Task WithSessionStore_RegistersStoreInDI()
+ {
+ var sessionStore = new InMemorySessionStore();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "TestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ }).WithSessionStore(sessionStore);
+
+ await using var app = Builder.Build();
+
+ // Verify the store was registered
+ var resolvedStore = app.Services.GetService();
+ Assert.Same(sessionStore, resolvedStore);
+ }
+
+ [Fact]
+ public async Task WithSessionStore_Generic_RegistersStoreType()
+ {
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "TestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ }).WithSessionStore();
+
+ await using var app = Builder.Build();
+
+ // Verify the store was registered
+ var resolvedStore = app.Services.GetService();
+ Assert.NotNull(resolvedStore);
+ Assert.IsType(resolvedStore);
+ }
+
+ [Fact]
+ public async Task WithSessionStore_Factory_RegistersStoreFromFactory()
+ {
+ var customStore = new InMemorySessionStore();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "TestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ }).WithSessionStore(sp => customStore);
+
+ await using var app = Builder.Build();
+
+ var resolvedStore = app.Services.GetService();
+ Assert.Same(customStore, resolvedStore);
+ }
+
+ [Fact]
+ public async Task EnableDistributedSessions_False_DoesNotUseStore()
+ {
+ var sessionStore = new InMemorySessionStore();
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "TestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = false; // Disabled
+ }).WithSessionStore(sessionStore);
+
+ await using var app = Builder.Build();
+ app.MapMcp();
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var client = await ConnectAsync("/");
+ await Task.Delay(100, TestContext.Current.CancellationToken);
+
+ // Store should be empty since distributed sessions are disabled
+ Assert.Equal(0, sessionStore.Count);
+ }
+
+ [Fact]
+ public async Task DistributedSessions_ConfiguresIdleTimeout()
+ {
+ var customTimeout = TimeSpan.FromMinutes(30);
+
+ Builder.Services.AddMcpServer(options =>
+ {
+ options.ServerInfo = new()
+ {
+ Name = "TestServer",
+ Version = "1.0.0",
+ };
+ }).WithHttpTransport(options =>
+ {
+ options.EnableDistributedSessions = true;
+ options.IdleTimeout = customTimeout;
+ }).WithSessionStore();
+
+ await using var app = Builder.Build();
+
+ // Just verify this compiles and runs - the actual timeout behavior
+ // would require more complex time-based testing
+ app.MapMcp();
+ await app.StartAsync(TestContext.Current.CancellationToken);
+
+ await using var client = await ConnectAsync("/");
+ Assert.Equal("TestServer", client.ServerInfo.Name);
+ }
+}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/InMemorySessionStoreTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/InMemorySessionStoreTests.cs
new file mode 100644
index 000000000..d43c812a2
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/InMemorySessionStoreTests.cs
@@ -0,0 +1,215 @@
+using Microsoft.Extensions.Time.Testing;
+
+namespace ModelContextProtocol.AspNetCore.Tests;
+
+public class InMemorySessionStoreTests
+{
+ private readonly FakeTimeProvider _timeProvider = new();
+ private readonly InMemorySessionStore _store;
+
+ public InMemorySessionStoreTests()
+ {
+ _timeProvider.SetUtcNow(new DateTimeOffset(2025, 1, 1, 1, 0, 0, TimeSpan.Zero));
+ _store = new InMemorySessionStore(_timeProvider);
+ }
+
+ [Fact]
+ public async Task SaveAsync_StoresSession()
+ {
+ var metadata = CreateTestMetadata("session-1");
+
+ await _store.SaveAsync(metadata, TestContext.Current.CancellationToken);
+
+ Assert.Equal(1, _store.Count);
+ }
+
+ [Fact]
+ public async Task SaveAsync_OverwritesExistingSession()
+ {
+ var metadata1 = CreateTestMetadata("session-1", userValue: "user-1");
+ var metadata2 = CreateTestMetadata("session-1", userValue: "user-2");
+
+ await _store.SaveAsync(metadata1, TestContext.Current.CancellationToken);
+ await _store.SaveAsync(metadata2, TestContext.Current.CancellationToken);
+
+ var retrieved = await _store.GetAsync("session-1", TestContext.Current.CancellationToken);
+ Assert.Equal(1, _store.Count);
+ Assert.Equal("user-2", retrieved?.UserIdClaimValue);
+ }
+
+ [Fact]
+ public async Task GetAsync_ReturnsStoredSession()
+ {
+ var metadata = CreateTestMetadata("session-1", userValue: "test-user");
+ await _store.SaveAsync(metadata, TestContext.Current.CancellationToken);
+
+ var retrieved = await _store.GetAsync("session-1", TestContext.Current.CancellationToken);
+
+ Assert.NotNull(retrieved);
+ Assert.Equal("session-1", retrieved.SessionId);
+ Assert.Equal("test-user", retrieved.UserIdClaimValue);
+ }
+
+ [Fact]
+ public async Task GetAsync_ReturnsNullForNonExistent()
+ {
+ var result = await _store.GetAsync("non-existent", TestContext.Current.CancellationToken);
+
+ Assert.Null(result);
+ }
+
+ [Fact]
+ public async Task UpdateActivityAsync_UpdatesTimestamp()
+ {
+ var metadata = CreateTestMetadata("session-1");
+ await _store.SaveAsync(metadata, TestContext.Current.CancellationToken);
+
+ var newActivity = new DateTime(2025, 1, 1, 13, 0, 0, DateTimeKind.Utc);
+ await _store.UpdateActivityAsync("session-1", newActivity, TestContext.Current.CancellationToken);
+
+ var retrieved = await _store.GetAsync("session-1", TestContext.Current.CancellationToken);
+ Assert.Equal(newActivity, retrieved?.LastActivityUtc);
+ }
+
+ [Fact]
+ public async Task UpdateActivityAsync_DoesNothingForNonExistent()
+ {
+ // Should not throw
+ await _store.UpdateActivityAsync("non-existent", DateTime.UtcNow, TestContext.Current.CancellationToken);
+ Assert.Equal(0, _store.Count);
+ }
+
+ [Fact]
+ public async Task RemoveAsync_RemovesSession()
+ {
+ var metadata = CreateTestMetadata("session-1");
+ await _store.SaveAsync(metadata, TestContext.Current.CancellationToken);
+
+ var removed = await _store.RemoveAsync("session-1", TestContext.Current.CancellationToken);
+
+ Assert.True(removed);
+ Assert.Equal(0, _store.Count);
+ }
+
+ [Fact]
+ public async Task RemoveAsync_ReturnsFalseForNonExistent()
+ {
+ var removed = await _store.RemoveAsync("non-existent", TestContext.Current.CancellationToken);
+
+ Assert.False(removed);
+ }
+
+ [Fact]
+ public async Task PruneIdleSessionsAsync_RemovesIdleSessions()
+ {
+ // Create sessions at different times
+ _timeProvider.SetUtcNow(new DateTimeOffset(2025, 1, 1, 10, 0, 0, TimeSpan.Zero));
+ var oldSession = CreateTestMetadata("old-session");
+ oldSession.LastActivityUtc = _timeProvider.GetUtcNow().DateTime;
+ await _store.SaveAsync(oldSession, TestContext.Current.CancellationToken);
+
+ _timeProvider.SetUtcNow(new DateTimeOffset(2025, 1, 1, 12, 0, 0, TimeSpan.Zero));
+ var newSession = CreateTestMetadata("new-session");
+ newSession.LastActivityUtc = _timeProvider.GetUtcNow().DateTime;
+ await _store.SaveAsync(newSession, TestContext.Current.CancellationToken);
+
+ // Advance time and prune with 1 hour timeout
+ _timeProvider.SetUtcNow(new DateTimeOffset(2025, 1, 1, 12, 30, 0, TimeSpan.Zero));
+ var removed = await _store.PruneIdleSessionsAsync(TimeSpan.FromHours(1), TestContext.Current.CancellationToken);
+
+ Assert.Equal(1, removed);
+ Assert.Equal(1, _store.Count);
+ Assert.Null(await _store.GetAsync("old-session", TestContext.Current.CancellationToken));
+ Assert.NotNull(await _store.GetAsync("new-session", TestContext.Current.CancellationToken));
+ }
+
+ [Fact]
+ public async Task PruneIdleSessionsAsync_ReturnsZeroWhenNoIdleSessions()
+ {
+ var metadata = CreateTestMetadata("session-1");
+ metadata.LastActivityUtc = _timeProvider.GetUtcNow().DateTime;
+ await _store.SaveAsync(metadata, TestContext.Current.CancellationToken);
+
+ var removed = await _store.PruneIdleSessionsAsync(TimeSpan.FromHours(1), TestContext.Current.CancellationToken);
+
+ Assert.Equal(0, removed);
+ Assert.Equal(1, _store.Count);
+ }
+
+ [Fact]
+ public async Task Clear_RemovesAllSessions()
+ {
+ // Add multiple sessions
+ await _store.SaveAsync(CreateTestMetadata("session-1"), TestContext.Current.CancellationToken);
+ await _store.SaveAsync(CreateTestMetadata("session-2"), TestContext.Current.CancellationToken);
+ await _store.SaveAsync(CreateTestMetadata("session-3"), TestContext.Current.CancellationToken);
+
+ Assert.Equal(3, _store.Count);
+
+ _store.Clear();
+
+ Assert.Equal(0, _store.Count);
+ }
+
+ [Fact]
+ public async Task MultipleSessions_WorkCorrectly()
+ {
+ await _store.SaveAsync(CreateTestMetadata("session-1", userValue: "user-1"), TestContext.Current.CancellationToken);
+ await _store.SaveAsync(CreateTestMetadata("session-2", userValue: "user-2"), TestContext.Current.CancellationToken);
+ await _store.SaveAsync(CreateTestMetadata("session-3", userValue: "user-3"), TestContext.Current.CancellationToken);
+
+ Assert.Equal(3, _store.Count);
+
+ var session1 = await _store.GetAsync("session-1", TestContext.Current.CancellationToken);
+ var session2 = await _store.GetAsync("session-2", TestContext.Current.CancellationToken);
+ var session3 = await _store.GetAsync("session-3", TestContext.Current.CancellationToken);
+
+ Assert.Equal("user-1", session1?.UserIdClaimValue);
+ Assert.Equal("user-2", session2?.UserIdClaimValue);
+ Assert.Equal("user-3", session3?.UserIdClaimValue);
+ }
+
+ [Fact]
+ public async Task ConcurrentAccess_IsThreadSafe()
+ {
+ var tasks = new List();
+ var ct = TestContext.Current.CancellationToken;
+
+ // Add 100 sessions concurrently
+ for (int i = 0; i < 100; i++)
+ {
+ var sessionId = $"session-{i}";
+ tasks.Add(_store.SaveAsync(CreateTestMetadata(sessionId), ct));
+ }
+
+ await Task.WhenAll(tasks);
+
+ Assert.Equal(100, _store.Count);
+
+ // Read all sessions concurrently
+ var readTasks = Enumerable.Range(0, 100)
+ .Select(i => _store.GetAsync($"session-{i}", ct))
+ .ToList();
+
+ var results = await Task.WhenAll(readTasks);
+
+ Assert.All(results, r => Assert.NotNull(r));
+ }
+
+ private SessionMetadata CreateTestMetadata(
+ string sessionId,
+ string? userType = "sub",
+ string? userValue = null,
+ string? userIssuer = "test-issuer")
+ {
+ return new SessionMetadata
+ {
+ SessionId = sessionId,
+ UserIdClaimType = userType,
+ UserIdClaimValue = userValue ?? $"user-for-{sessionId}",
+ UserIdClaimIssuer = userIssuer,
+ CreatedAtUtc = _timeProvider.GetUtcNow().DateTime,
+ LastActivityUtc = _timeProvider.GetUtcNow().DateTime
+ };
+ }
+}
diff --git a/tests/ModelContextProtocol.AspNetCore.Tests/SessionMetadataTests.cs b/tests/ModelContextProtocol.AspNetCore.Tests/SessionMetadataTests.cs
new file mode 100644
index 000000000..3aad21d3a
--- /dev/null
+++ b/tests/ModelContextProtocol.AspNetCore.Tests/SessionMetadataTests.cs
@@ -0,0 +1,108 @@
+namespace ModelContextProtocol.AspNetCore.Tests;
+
+public class SessionMetadataTests
+{
+ [Fact]
+ public void SessionMetadata_RequiresSessionId()
+ {
+ var metadata = new SessionMetadata
+ {
+ SessionId = "test-session-id"
+ };
+
+ Assert.Equal("test-session-id", metadata.SessionId);
+ }
+
+ [Fact]
+ public void SessionMetadata_OptionalPropertiesAreNullByDefault()
+ {
+ var metadata = new SessionMetadata
+ {
+ SessionId = "test-session"
+ };
+
+ Assert.Null(metadata.UserIdClaimType);
+ Assert.Null(metadata.UserIdClaimValue);
+ Assert.Null(metadata.UserIdClaimIssuer);
+ Assert.Null(metadata.CustomDataJson);
+ }
+
+ [Fact]
+ public void SessionMetadata_StoresUserIdClaims()
+ {
+ var metadata = new SessionMetadata
+ {
+ SessionId = "test-session",
+ UserIdClaimType = "sub",
+ UserIdClaimValue = "user-123",
+ UserIdClaimIssuer = "https://issuer.example.com"
+ };
+
+ Assert.Equal("sub", metadata.UserIdClaimType);
+ Assert.Equal("user-123", metadata.UserIdClaimValue);
+ Assert.Equal("https://issuer.example.com", metadata.UserIdClaimIssuer);
+ }
+
+ [Fact]
+ public void SessionMetadata_StoresTimestamps()
+ {
+ var createdAt = new DateTime(2025, 1, 1, 12, 0, 0, DateTimeKind.Utc);
+ var lastActivity = new DateTime(2025, 1, 1, 12, 30, 0, DateTimeKind.Utc);
+
+ var metadata = new SessionMetadata
+ {
+ SessionId = "test-session",
+ CreatedAtUtc = createdAt,
+ LastActivityUtc = lastActivity
+ };
+
+ Assert.Equal(createdAt, metadata.CreatedAtUtc);
+ Assert.Equal(lastActivity, metadata.LastActivityUtc);
+ }
+
+ [Fact]
+ public void SessionMetadata_StoresCustomData()
+ {
+ var customJson = """{"key": "value", "count": 42}""";
+
+ var metadata = new SessionMetadata
+ {
+ SessionId = "test-session",
+ CustomDataJson = customJson
+ };
+
+ Assert.Equal(customJson, metadata.CustomDataJson);
+ }
+
+ [Fact]
+ public void SessionMetadata_CanRepresentAnonymousSession()
+ {
+ // Anonymous sessions have no user claims
+ var metadata = new SessionMetadata
+ {
+ SessionId = "anonymous-session",
+ CreatedAtUtc = DateTime.UtcNow,
+ LastActivityUtc = DateTime.UtcNow
+ };
+
+ Assert.Null(metadata.UserIdClaimValue);
+ Assert.NotEmpty(metadata.SessionId);
+ }
+
+ [Fact]
+ public void SessionMetadata_CanRepresentAuthenticatedSession()
+ {
+ var metadata = new SessionMetadata
+ {
+ SessionId = "authenticated-session",
+ UserIdClaimType = "http://schemas.xmlsoap.org/ws/2005/05/identity/claims/nameidentifier",
+ UserIdClaimValue = "user@example.com",
+ UserIdClaimIssuer = "local",
+ CreatedAtUtc = DateTime.UtcNow,
+ LastActivityUtc = DateTime.UtcNow
+ };
+
+ Assert.NotNull(metadata.UserIdClaimValue);
+ Assert.Equal("user@example.com", metadata.UserIdClaimValue);
+ }
+}
diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
index d9f3f1b48..d3658f34f 100644
--- a/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
+++ b/tests/ModelContextProtocol.Tests/Server/McpServerTests.cs
@@ -887,7 +887,7 @@ private static async Task InitializeServerAsync(TestServerTransport transport, C
await transport.SendClientMessageAsync(initializeRequest, cancellationToken);
// Wait for the initialize response to be sent
- await tcs.Task.WaitAsync(TimeSpan.FromSeconds(5), cancellationToken);
+ await tcs.Task.WaitAsync(TimeSpan.FromSeconds(30), cancellationToken);
}
private sealed class TestServerForIChatClient(bool supportsSampling) : McpServer
diff --git a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs
index b49542784..4aada7679 100644
--- a/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs
+++ b/tests/ModelContextProtocol.Tests/Transport/SseResponseStreamTransportTests.cs
@@ -32,4 +32,54 @@ public async Task Can_Customize_MessageEndpoint()
responsePipe.Reader.Complete();
responsePipe.Writer.Complete();
}
+
+ [Fact]
+ public async Task Sends_Heartbeats_When_Idle()
+ {
+ var responsePipe = new Pipe();
+ var heartbeatInterval = TimeSpan.FromMilliseconds(100);
+
+ await using var transport = new SseResponseStreamTransport(
+ responsePipe.Writer.AsStream(),
+ heartbeatInterval: heartbeatInterval);
+
+ var transportRunTask = transport.RunAsync(TestContext.Current.CancellationToken);
+
+ using var responseStreamReader = new StreamReader(responsePipe.Reader.AsStream());
+
+ // Skip endpoint event
+ await responseStreamReader.ReadLineAsync(
+#if NET
+ TestContext.Current.CancellationToken
+#endif
+ );
+ await responseStreamReader.ReadLineAsync(
+#if NET
+ TestContext.Current.CancellationToken
+#endif
+ );
+ await responseStreamReader.ReadLineAsync(
+#if NET
+ TestContext.Current.CancellationToken
+#endif
+ );
+
+ // Wait for first heartbeat
+ var eventLine = await responseStreamReader.ReadLineAsync(
+#if NET
+ TestContext.Current.CancellationToken
+#endif
+ );
+ Assert.Equal("event: ping", eventLine);
+
+ var dataLine = await responseStreamReader.ReadLineAsync(
+#if NET
+ TestContext.Current.CancellationToken
+#endif
+ );
+ Assert.Equal("data: ", dataLine);
+
+ responsePipe.Reader.Complete();
+ responsePipe.Writer.Complete();
+ }
}