Skip to content

Commit e4bf506

Browse files
committed
feat: add McpSessionStore SPI for pluggable session storage
Introduce a `McpSessionStore` interface that abstracts session storage, enabling custom implementations (e.g., Redis, JDBC) for distributed and multi-instance MCP server deployments. Changes: - Add `McpSessionStore` interface with save/get/remove/values/clear ops - Add `InMemoryMcpSessionStore` as the default ConcurrentHashMap-backed implementation (preserving existing behavior) - Refactor `HttpServletStreamableServerTransportProvider` to use `McpSessionStore` instead of a hardcoded ConcurrentHashMap - Add `sessionStore()` method to the Builder for custom store injection - Default to `InMemoryMcpSessionStore` when no custom store is provided This is a non-breaking change: existing code continues to work without modification as the default in-memory store is used automatically. Closes #274 Relates to #107, #738, #376
1 parent fcdc0d4 commit e4bf506

3 files changed

Lines changed: 182 additions & 21 deletions

File tree

mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import java.util.ArrayList;
1212
import java.util.List;
1313
import java.util.Map;
14-
import java.util.concurrent.ConcurrentHashMap;
1514
import java.util.concurrent.locks.ReentrantLock;
1615

1716
import org.slf4j.Logger;
@@ -22,8 +21,10 @@
2221
import io.modelcontextprotocol.common.McpTransportContext;
2322
import io.modelcontextprotocol.server.McpTransportContextExtractor;
2423
import io.modelcontextprotocol.spec.HttpHeaders;
24+
import io.modelcontextprotocol.spec.InMemoryMcpSessionStore;
2525
import io.modelcontextprotocol.spec.McpError;
2626
import io.modelcontextprotocol.spec.McpSchema;
27+
import io.modelcontextprotocol.spec.McpSessionStore;
2728
import io.modelcontextprotocol.spec.McpStreamableServerSession;
2829
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
2930
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
@@ -104,9 +105,9 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
104105
private McpStreamableServerSession.Factory sessionFactory;
105106

106107
/**
107-
* Map of active client sessions, keyed by mcp-session-id.
108+
* Store for active client sessions, keyed by mcp-session-id.
108109
*/
109-
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
110+
private final McpSessionStore sessionStore;
110111

111112
private McpTransportContextExtractor<HttpServletRequest> contextExtractor;
112113

@@ -141,22 +142,25 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
141142
*/
142143
private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint,
143144
boolean disallowDelete, McpTransportContextExtractor<HttpServletRequest> contextExtractor,
144-
Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) {
145+
Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator,
146+
McpSessionStore sessionStore) {
145147
Assert.notNull(jsonMapper, "JsonMapper must not be null");
146148
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
147149
Assert.notNull(contextExtractor, "Context extractor must not be null");
148150
Assert.notNull(securityValidator, "Security validator must not be null");
151+
Assert.notNull(sessionStore, "Session store must not be null");
149152

150153
this.jsonMapper = jsonMapper;
151154
this.mcpEndpoint = mcpEndpoint;
152155
this.disallowDelete = disallowDelete;
153156
this.contextExtractor = contextExtractor;
154157
this.securityValidator = securityValidator;
158+
this.sessionStore = sessionStore;
155159

156160
if (keepAliveInterval != null) {
157161

158162
this.keepAliveScheduler = KeepAliveScheduler
159-
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values()))
163+
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessionStore.values()))
160164
.initialDelay(keepAliveInterval)
161165
.interval(keepAliveInterval)
162166
.build();
@@ -187,15 +191,15 @@ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory)
187191
*/
188192
@Override
189193
public Mono<Void> notifyClients(String method, Object params) {
190-
if (this.sessions.isEmpty()) {
194+
if (this.sessionStore.isEmpty()) {
191195
logger.debug("No active sessions to broadcast message to");
192196
return Mono.empty();
193197
}
194198

195-
logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
199+
logger.debug("Attempting to broadcast message to {} active sessions", this.sessionStore.size());
196200

197201
return Mono.fromRunnable(() -> {
198-
this.sessions.values().parallelStream().forEach(session -> {
202+
this.sessionStore.values().parallelStream().forEach(session -> {
199203
try {
200204
session.sendNotification(method, params).block();
201205
}
@@ -209,7 +213,7 @@ public Mono<Void> notifyClients(String method, Object params) {
209213
@Override
210214
public Mono<Void> notifyClient(String sessionId, String method, Object params) {
211215
return Mono.defer(() -> {
212-
McpStreamableServerSession session = this.sessions.get(sessionId);
216+
McpStreamableServerSession session = this.sessionStore.get(sessionId);
213217
if (session == null) {
214218
logger.debug("Session {} not found", sessionId);
215219
return Mono.empty();
@@ -226,9 +230,9 @@ public Mono<Void> notifyClient(String sessionId, String method, Object params) {
226230
public Mono<Void> closeGracefully() {
227231
return Mono.fromRunnable(() -> {
228232
this.isClosing = true;
229-
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
233+
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessionStore.size());
230234

231-
this.sessions.values().parallelStream().forEach(session -> {
235+
this.sessionStore.values().parallelStream().forEach(session -> {
232236
try {
233237
session.closeGracefully().block();
234238
}
@@ -237,10 +241,10 @@ public Mono<Void> closeGracefully() {
237241
}
238242
});
239243

240-
this.sessions.clear();
244+
this.sessionStore.clear();
241245
logger.debug("Graceful shutdown completed");
242246
}).then().doOnSuccess(v -> {
243-
sessions.clear();
247+
sessionStore.clear();
244248
logger.debug("Graceful shutdown completed");
245249
if (this.keepAliveScheduler != null) {
246250
this.keepAliveScheduler.shutdown();
@@ -299,7 +303,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
299303
return;
300304
}
301305

302-
McpStreamableServerSession session = this.sessions.get(sessionId);
306+
McpStreamableServerSession session = this.sessionStore.get(sessionId);
303307

304308
if (session == null) {
305309
response.sendError(HttpServletResponse.SC_NOT_FOUND);
@@ -452,7 +456,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
452456
});
453457
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
454458
.startSession(initializeRequest);
455-
this.sessions.put(init.session().getId(), init.session());
459+
this.sessionStore.save(init.session().getId(), init.session());
456460

457461
try {
458462
McpSchema.InitializeResult initResult = init.initResult().block();
@@ -493,7 +497,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
493497
return;
494498
}
495499

496-
McpStreamableServerSession session = this.sessions.get(sessionId);
500+
McpStreamableServerSession session = this.sessionStore.get(sessionId);
497501

498502
if (session == null) {
499503
this.responseError(response, HttpServletResponse.SC_NOT_FOUND,
@@ -612,7 +616,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
612616
}
613617

614618
String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID);
615-
McpStreamableServerSession session = this.sessions.get(sessionId);
619+
McpStreamableServerSession session = this.sessionStore.get(sessionId);
616620

617621
if (session == null) {
618622
response.sendError(HttpServletResponse.SC_NOT_FOUND);
@@ -621,7 +625,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
621625

622626
try {
623627
session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
624-
this.sessions.remove(sessionId);
628+
this.sessionStore.remove(sessionId);
625629
response.setStatus(HttpServletResponse.SC_OK);
626630
}
627631
catch (Exception e) {
@@ -755,7 +759,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
755759
}
756760
catch (Exception e) {
757761
logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
758-
HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
762+
HttpServletStreamableServerTransportProvider.this.sessionStore.remove(this.sessionId);
759763
this.asyncContext.complete();
760764
}
761765
finally {
@@ -801,7 +805,7 @@ public void close() {
801805

802806
this.closed = true;
803807

804-
// HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
808+
// HttpServletStreamableServerTransportProvider.this.sessionStore.remove(this.sessionId);
805809
this.asyncContext.complete();
806810
logger.debug("Successfully completed async context for session {}", sessionId);
807811
}
@@ -838,6 +842,8 @@ public static class Builder {
838842

839843
private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;
840844

845+
private McpSessionStore sessionStore;
846+
841847
/**
842848
* Sets the JsonMapper to use for JSON serialization/deserialization of MCP
843849
* messages.
@@ -909,6 +915,19 @@ public Builder securityValidator(ServerTransportSecurityValidator securityValida
909915
return this;
910916
}
911917

918+
/**
919+
* Sets the session store for managing active client sessions. If not set, an
920+
* {@link InMemoryMcpSessionStore} will be used by default.
921+
* @param sessionStore The session store to use. Must not be null.
922+
* @return this builder instance
923+
* @throws IllegalArgumentException if sessionStore is null
924+
*/
925+
public Builder sessionStore(McpSessionStore sessionStore) {
926+
Assert.notNull(sessionStore, "Session store must not be null");
927+
this.sessionStore = sessionStore;
928+
return this;
929+
}
930+
912931
/**
913932
* Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
914933
* with the configured settings.
@@ -919,7 +938,8 @@ public HttpServletStreamableServerTransportProvider build() {
919938
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
920939
return new HttpServletStreamableServerTransportProvider(
921940
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, mcpEndpoint, disallowDelete,
922-
contextExtractor, keepAliveInterval, securityValidator);
941+
contextExtractor, keepAliveInterval, securityValidator,
942+
sessionStore == null ? new InMemoryMcpSessionStore() : sessionStore);
923943
}
924944

925945
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright 2024-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.spec;
6+
7+
import java.util.Collection;
8+
import java.util.concurrent.ConcurrentHashMap;
9+
10+
/**
11+
* Default in-memory implementation of {@link McpSessionStore} backed by a
12+
* {@link ConcurrentHashMap}. This implementation is suitable for single-instance
13+
* deployments where session state does not need to be shared across multiple server
14+
* instances.
15+
*
16+
* <p>
17+
* This is the default session store used by
18+
* {@link io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider}
19+
* when no custom {@link McpSessionStore} is provided.
20+
*
21+
* @author WeiLin Wang
22+
* @see McpSessionStore
23+
*/
24+
public class InMemoryMcpSessionStore implements McpSessionStore {
25+
26+
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
27+
28+
@Override
29+
public void save(String sessionId, McpStreamableServerSession session) {
30+
this.sessions.put(sessionId, session);
31+
}
32+
33+
@Override
34+
public McpStreamableServerSession get(String sessionId) {
35+
return this.sessions.get(sessionId);
36+
}
37+
38+
@Override
39+
public McpStreamableServerSession remove(String sessionId) {
40+
return this.sessions.remove(sessionId);
41+
}
42+
43+
@Override
44+
public Collection<McpStreamableServerSession> values() {
45+
return this.sessions.values();
46+
}
47+
48+
@Override
49+
public boolean isEmpty() {
50+
return this.sessions.isEmpty();
51+
}
52+
53+
@Override
54+
public int size() {
55+
return this.sessions.size();
56+
}
57+
58+
@Override
59+
public void clear() {
60+
this.sessions.clear();
61+
}
62+
63+
}
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright 2024-2026 the original author or authors.
3+
*/
4+
5+
package io.modelcontextprotocol.spec;
6+
7+
import java.util.Collection;
8+
9+
/**
10+
* Strategy interface for storing and retrieving MCP server sessions. This abstraction
11+
* allows the session storage mechanism to be customized, enabling implementations such as
12+
* in-memory (default), Redis-backed, JDBC-backed, or any distributed store.
13+
*
14+
* <p>
15+
* The default implementation {@link InMemoryMcpSessionStore} uses a
16+
* {@link java.util.concurrent.ConcurrentHashMap} which is suitable for single-instance
17+
* deployments. For distributed or multi-instance deployments, a custom implementation
18+
* backed by a distributed data store should be used.
19+
*
20+
* <p>
21+
* Note: {@link McpStreamableServerSession} objects contain active transport connections
22+
* (SSE streams) that are inherently tied to the JVM instance. A distributed session store
23+
* therefore stores the session reference per-node and coordinates session lifecycle
24+
* across nodes (e.g., detecting when a session was created on a different node).
25+
*
26+
* @author WeiLin Wang
27+
* @see InMemoryMcpSessionStore
28+
* @see McpStreamableServerSession
29+
*/
30+
public interface McpSessionStore {
31+
32+
/**
33+
* Stores a session with the given ID. If a session with the same ID already exists,
34+
* it will be replaced.
35+
* @param sessionId the unique session identifier
36+
* @param session the session to store
37+
*/
38+
void save(String sessionId, McpStreamableServerSession session);
39+
40+
/**
41+
* Retrieves a session by its ID.
42+
* @param sessionId the unique session identifier
43+
* @return the session associated with the given ID, or {@code null} if not found
44+
*/
45+
McpStreamableServerSession get(String sessionId);
46+
47+
/**
48+
* Removes a session by its ID.
49+
* @param sessionId the unique session identifier
50+
* @return the previously stored session, or {@code null} if no session was stored
51+
* with the given ID
52+
*/
53+
McpStreamableServerSession remove(String sessionId);
54+
55+
/**
56+
* Returns all currently stored sessions.
57+
* @return a collection of all stored sessions; never {@code null}
58+
*/
59+
Collection<McpStreamableServerSession> values();
60+
61+
/**
62+
* Returns whether there are any sessions stored.
63+
* @return {@code true} if no sessions are stored, {@code false} otherwise
64+
*/
65+
boolean isEmpty();
66+
67+
/**
68+
* Returns the number of stored sessions.
69+
* @return the session count
70+
*/
71+
int size();
72+
73+
/**
74+
* Removes all stored sessions.
75+
*/
76+
void clear();
77+
78+
}

0 commit comments

Comments
 (0)