1111import java .util .ArrayList ;
1212import java .util .List ;
1313import java .util .Map ;
14- import java .util .concurrent .ConcurrentHashMap ;
1514import java .util .concurrent .locks .ReentrantLock ;
1615
1716import org .slf4j .Logger ;
2221import io .modelcontextprotocol .common .McpTransportContext ;
2322import io .modelcontextprotocol .server .McpTransportContextExtractor ;
2423import io .modelcontextprotocol .spec .HttpHeaders ;
24+ import io .modelcontextprotocol .spec .InMemoryMcpSessionStore ;
2525import io .modelcontextprotocol .spec .McpError ;
2626import io .modelcontextprotocol .spec .McpSchema ;
27+ import io .modelcontextprotocol .spec .McpSessionStore ;
2728import io .modelcontextprotocol .spec .McpStreamableServerSession ;
2829import io .modelcontextprotocol .spec .McpStreamableServerTransport ;
2930import io .modelcontextprotocol .spec .McpStreamableServerTransportProvider ;
@@ -104,9 +105,9 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
104105 private McpStreamableServerSession .Factory sessionFactory ;
105106
106107 /**
107- * Map of active client sessions, keyed by mcp-session-id.
108+ * Store for active client sessions, keyed by mcp-session-id.
108109 */
109- private final ConcurrentHashMap < String , McpStreamableServerSession > sessions = new ConcurrentHashMap <>() ;
110+ private final McpSessionStore sessionStore ;
110111
111112 private McpTransportContextExtractor <HttpServletRequest > contextExtractor ;
112113
@@ -141,22 +142,25 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
141142 */
142143 private HttpServletStreamableServerTransportProvider (McpJsonMapper jsonMapper , String mcpEndpoint ,
143144 boolean disallowDelete , McpTransportContextExtractor <HttpServletRequest > contextExtractor ,
144- Duration keepAliveInterval , ServerTransportSecurityValidator securityValidator ) {
145+ Duration keepAliveInterval , ServerTransportSecurityValidator securityValidator ,
146+ McpSessionStore sessionStore ) {
145147 Assert .notNull (jsonMapper , "JsonMapper must not be null" );
146148 Assert .notNull (mcpEndpoint , "MCP endpoint must not be null" );
147149 Assert .notNull (contextExtractor , "Context extractor must not be null" );
148150 Assert .notNull (securityValidator , "Security validator must not be null" );
151+ Assert .notNull (sessionStore , "Session store must not be null" );
149152
150153 this .jsonMapper = jsonMapper ;
151154 this .mcpEndpoint = mcpEndpoint ;
152155 this .disallowDelete = disallowDelete ;
153156 this .contextExtractor = contextExtractor ;
154157 this .securityValidator = securityValidator ;
158+ this .sessionStore = sessionStore ;
155159
156160 if (keepAliveInterval != null ) {
157161
158162 this .keepAliveScheduler = KeepAliveScheduler
159- .builder (() -> (isClosing ) ? Flux .empty () : Flux .fromIterable (sessions .values ()))
163+ .builder (() -> (isClosing ) ? Flux .empty () : Flux .fromIterable (sessionStore .values ()))
160164 .initialDelay (keepAliveInterval )
161165 .interval (keepAliveInterval )
162166 .build ();
@@ -187,15 +191,15 @@ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory)
187191 */
188192 @ Override
189193 public Mono <Void > notifyClients (String method , Object params ) {
190- if (this .sessions .isEmpty ()) {
194+ if (this .sessionStore .isEmpty ()) {
191195 logger .debug ("No active sessions to broadcast message to" );
192196 return Mono .empty ();
193197 }
194198
195- logger .debug ("Attempting to broadcast message to {} active sessions" , this .sessions .size ());
199+ logger .debug ("Attempting to broadcast message to {} active sessions" , this .sessionStore .size ());
196200
197201 return Mono .fromRunnable (() -> {
198- this .sessions .values ().parallelStream ().forEach (session -> {
202+ this .sessionStore .values ().parallelStream ().forEach (session -> {
199203 try {
200204 session .sendNotification (method , params ).block ();
201205 }
@@ -209,7 +213,7 @@ public Mono<Void> notifyClients(String method, Object params) {
209213 @ Override
210214 public Mono <Void > notifyClient (String sessionId , String method , Object params ) {
211215 return Mono .defer (() -> {
212- McpStreamableServerSession session = this .sessions .get (sessionId );
216+ McpStreamableServerSession session = this .sessionStore .get (sessionId );
213217 if (session == null ) {
214218 logger .debug ("Session {} not found" , sessionId );
215219 return Mono .empty ();
@@ -226,9 +230,9 @@ public Mono<Void> notifyClient(String sessionId, String method, Object params) {
226230 public Mono <Void > closeGracefully () {
227231 return Mono .fromRunnable (() -> {
228232 this .isClosing = true ;
229- logger .debug ("Initiating graceful shutdown with {} active sessions" , this .sessions .size ());
233+ logger .debug ("Initiating graceful shutdown with {} active sessions" , this .sessionStore .size ());
230234
231- this .sessions .values ().parallelStream ().forEach (session -> {
235+ this .sessionStore .values ().parallelStream ().forEach (session -> {
232236 try {
233237 session .closeGracefully ().block ();
234238 }
@@ -237,10 +241,10 @@ public Mono<Void> closeGracefully() {
237241 }
238242 });
239243
240- this .sessions .clear ();
244+ this .sessionStore .clear ();
241245 logger .debug ("Graceful shutdown completed" );
242246 }).then ().doOnSuccess (v -> {
243- sessions .clear ();
247+ sessionStore .clear ();
244248 logger .debug ("Graceful shutdown completed" );
245249 if (this .keepAliveScheduler != null ) {
246250 this .keepAliveScheduler .shutdown ();
@@ -299,7 +303,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
299303 return ;
300304 }
301305
302- McpStreamableServerSession session = this .sessions .get (sessionId );
306+ McpStreamableServerSession session = this .sessionStore .get (sessionId );
303307
304308 if (session == null ) {
305309 response .sendError (HttpServletResponse .SC_NOT_FOUND );
@@ -452,7 +456,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
452456 });
453457 McpStreamableServerSession .McpStreamableServerSessionInit init = this .sessionFactory
454458 .startSession (initializeRequest );
455- this .sessions . put (init .session ().getId (), init .session ());
459+ this .sessionStore . save (init .session ().getId (), init .session ());
456460
457461 try {
458462 McpSchema .InitializeResult initResult = init .initResult ().block ();
@@ -493,7 +497,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
493497 return ;
494498 }
495499
496- McpStreamableServerSession session = this .sessions .get (sessionId );
500+ McpStreamableServerSession session = this .sessionStore .get (sessionId );
497501
498502 if (session == null ) {
499503 this .responseError (response , HttpServletResponse .SC_NOT_FOUND ,
@@ -612,7 +616,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
612616 }
613617
614618 String sessionId = request .getHeader (HttpHeaders .MCP_SESSION_ID );
615- McpStreamableServerSession session = this .sessions .get (sessionId );
619+ McpStreamableServerSession session = this .sessionStore .get (sessionId );
616620
617621 if (session == null ) {
618622 response .sendError (HttpServletResponse .SC_NOT_FOUND );
@@ -621,7 +625,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
621625
622626 try {
623627 session .delete ().contextWrite (ctx -> ctx .put (McpTransportContext .KEY , transportContext )).block ();
624- this .sessions .remove (sessionId );
628+ this .sessionStore .remove (sessionId );
625629 response .setStatus (HttpServletResponse .SC_OK );
626630 }
627631 catch (Exception e ) {
@@ -755,7 +759,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
755759 }
756760 catch (Exception e ) {
757761 logger .error ("Failed to send message to session {}: {}" , this .sessionId , e .getMessage ());
758- HttpServletStreamableServerTransportProvider .this .sessions .remove (this .sessionId );
762+ HttpServletStreamableServerTransportProvider .this .sessionStore .remove (this .sessionId );
759763 this .asyncContext .complete ();
760764 }
761765 finally {
@@ -801,7 +805,7 @@ public void close() {
801805
802806 this .closed = true ;
803807
804- // HttpServletStreamableServerTransportProvider.this.sessions .remove(this.sessionId);
808+ // HttpServletStreamableServerTransportProvider.this.sessionStore .remove(this.sessionId);
805809 this .asyncContext .complete ();
806810 logger .debug ("Successfully completed async context for session {}" , sessionId );
807811 }
@@ -838,6 +842,8 @@ public static class Builder {
838842
839843 private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator .NOOP ;
840844
845+ private McpSessionStore sessionStore ;
846+
841847 /**
842848 * Sets the JsonMapper to use for JSON serialization/deserialization of MCP
843849 * messages.
@@ -909,6 +915,19 @@ public Builder securityValidator(ServerTransportSecurityValidator securityValida
909915 return this ;
910916 }
911917
918+ /**
919+ * Sets the session store for managing active client sessions. If not set, an
920+ * {@link InMemoryMcpSessionStore} will be used by default.
921+ * @param sessionStore The session store to use. Must not be null.
922+ * @return this builder instance
923+ * @throws IllegalArgumentException if sessionStore is null
924+ */
925+ public Builder sessionStore (McpSessionStore sessionStore ) {
926+ Assert .notNull (sessionStore , "Session store must not be null" );
927+ this .sessionStore = sessionStore ;
928+ return this ;
929+ }
930+
912931 /**
913932 * Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
914933 * with the configured settings.
@@ -919,7 +938,8 @@ public HttpServletStreamableServerTransportProvider build() {
919938 Assert .notNull (this .mcpEndpoint , "MCP endpoint must be set" );
920939 return new HttpServletStreamableServerTransportProvider (
921940 jsonMapper == null ? McpJsonDefaults .getMapper () : jsonMapper , mcpEndpoint , disallowDelete ,
922- contextExtractor , keepAliveInterval , securityValidator );
941+ contextExtractor , keepAliveInterval , securityValidator ,
942+ sessionStore == null ? new InMemoryMcpSessionStore () : sessionStore );
923943 }
924944
925945 }
0 commit comments