Skip to content

Commit cd77adb

Browse files
committed
feat: Server registry stores HandlerEntry; ServerRunner consumes Server[L] directly
Server is generic in LifespanResultT only — no TransportContextT. Spike (scratch/spike-tt-on-server) found a third generic breaks bare-Server plumbing helpers via invariance and only buys one None-check; it remains additive later via PEP 696 default if demand materialises. TT stays on the transport layer (Dispatcher/DispatchContext/BaseContext in mcp.shared); the server layer (Server/Context/ServerRunner/ServerMiddleware) consumes base TransportContext. - HandlerEntry[L] frozen dataclass (params_type, handler) replaces bare callables in the registry; params type erased to Any in storage, correlated at add_request_handler[P] - Public add_request_handler/add_notification_handler; capabilities() zero-arg (notification_options/experimental_capabilities now ctor kwargs) - ServerRunner drops the ServerRegistry Protocol scaffold and reads Server[L] directly; _make_context no longer narrows dctx - ServerMiddleware[L] (one contravariant param) - Context[L] (BaseContext[TransportContext] fixed)
1 parent 918e20a commit cd77adb

5 files changed

Lines changed: 197 additions & 177 deletions

File tree

src/mcp/server/context.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ class ServerRequestContext(RequestContext[ServerSession], Generic[LifespanContex
3333

3434

3535
LifespanT = TypeVar("LifespanT", default=Any, covariant=True)
36-
TransportT = TypeVar("TransportT", bound=TransportContext, default=TransportContext, covariant=True)
3736

3837

39-
class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Generic[LifespanT, TransportT]):
38+
class Context(BaseContext[TransportContext], PeerMixin, TypedServerRequestMixin, Generic[LifespanT]):
4039
"""Server-side per-request context.
4140
4241
Composes `BaseContext` (forwards to `DispatchContext`, satisfies `Outbound`),
@@ -50,7 +49,7 @@ class Context(BaseContext[TransportT], PeerMixin, TypedServerRequestMixin, Gener
5049

5150
def __init__(
5251
self,
53-
dctx: DispatchContext[TransportT],
52+
dctx: DispatchContext[TransportContext],
5453
*,
5554
lifespan: LifespanT,
5655
connection: Connection,
@@ -94,23 +93,23 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *
9493
_MwLifespanT = TypeVar("_MwLifespanT", contravariant=True)
9594

9695

97-
class ContextMiddleware(Protocol[_MwLifespanT]):
96+
class ServerMiddleware(Protocol[_MwLifespanT]):
9897
"""Context-tier middleware: ``(ctx, method, typed_params, call_next) -> result``.
9998
10099
Runs *inside* `ServerRunner._on_request` after params validation and
101100
`Context` construction. Wraps registered handlers (including ``ping``) but
102101
not ``initialize``, ``METHOD_NOT_FOUND``, or validation failures. Listed
103102
outermost-first on `Server.middleware`.
104103
105-
`Server[L].middleware` holds `ContextMiddleware[L]`, so an app-specific
106-
middleware sees `ctx.lifespan: L`. A reusable middleware (no app-specific
107-
types) can be typed `ContextMiddleware[object]` — `Context` is covariant in
108-
`LifespanT`, so it registers on any `Server[L]`.
104+
`Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific
105+
middleware sees `ctx.lifespan: L`. A reusable middleware can be typed
106+
`ServerMiddleware[object]` — `Context` is covariant in `LifespanT`, so it
107+
registers on any `Server[L]`.
109108
"""
110109

111110
async def __call__(
112111
self,
113-
ctx: Context[_MwLifespanT, TransportContext],
112+
ctx: Context[_MwLifespanT],
114113
method: str,
115114
params: BaseModel,
116115
call_next: CallNext,

src/mcp/server/lowlevel/server.py

Lines changed: 101 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,13 @@ async def main():
4141
import warnings
4242
from collections.abc import AsyncIterator, Awaitable, Callable
4343
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
44+
from dataclasses import dataclass
4445
from importlib.metadata import version as importlib_version
4546
from typing import Any, Generic, cast
4647

4748
import anyio
4849
from opentelemetry.trace import SpanKind, StatusCode
50+
from pydantic import BaseModel
4951
from starlette.applications import Starlette
5052
from starlette.middleware import Middleware
5153
from starlette.middleware.authentication import AuthenticationMiddleware
@@ -58,7 +60,7 @@ async def main():
5860
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, TokenVerifier
5961
from mcp.server.auth.routes import build_resource_metadata_url, create_auth_routes, create_protected_resource_routes
6062
from mcp.server.auth.settings import AuthSettings
61-
from mcp.server.context import ContextMiddleware, ServerRequestContext
63+
from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext
6264
from mcp.server.experimental.request_context import Experimental
6365
from mcp.server.lowlevel.experimental import ExperimentalHandlers
6466
from mcp.server.models import InitializationOptions
@@ -76,6 +78,30 @@ async def main():
7678

7779
LifespanResultT = TypeVar("LifespanResultT", default=Any)
7880

81+
_ParamsT = TypeVar("_ParamsT", bound=BaseModel, default=BaseModel)
82+
83+
RequestHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[HandlerResult]]
84+
"""A registered request handler: ``(ctx, params) -> result``."""
85+
86+
NotificationHandler = Callable[[ServerRequestContext[LifespanResultT], _ParamsT], Awaitable[None]]
87+
"""A registered notification handler: ``(ctx, params) -> None``."""
88+
89+
90+
@dataclass(frozen=True, slots=True)
91+
class HandlerEntry(Generic[LifespanResultT]):
92+
"""A registered handler and the params model to validate incoming params against.
93+
94+
Stored in `Server._request_handlers` / `_notification_handlers` and consumed
95+
by `ServerRunner` to validate, build `Context`, and invoke. The handler's
96+
second-argument type is erased to ``Any`` in storage (each entry has a
97+
different concrete params type and `Callable` parameters are contravariant);
98+
the precise type is recoverable via `params_type`. The correlation is
99+
enforced at registration time by `Server.add_request_handler`.
100+
"""
101+
102+
params_type: type[BaseModel]
103+
handler: RequestHandler[LifespanResultT, Any]
104+
79105

80106
class NotificationOptions:
81107
def __init__(self, prompts_changed: bool = False, resources_changed: bool = False, tools_changed: bool = False):
@@ -85,7 +111,7 @@ def __init__(self, prompts_changed: bool = False, resources_changed: bool = Fals
85111

86112

87113
@asynccontextmanager
88-
async def lifespan(_: Server[LifespanResultT]) -> AsyncIterator[dict[str, Any]]:
114+
async def lifespan(_: Server[Any]) -> AsyncIterator[dict[str, Any]]:
89115
"""Default lifespan context manager that does nothing.
90116
91117
Returns:
@@ -109,6 +135,8 @@ def __init__(
109135
instructions: str | None = None,
110136
website_url: str | None = None,
111137
icons: list[types.Icon] | None = None,
138+
notification_options: NotificationOptions | None = None,
139+
experimental_capabilities: dict[str, dict[str, Any]] | None = None,
112140
lifespan: Callable[
113141
[Server[LifespanResultT]],
114142
AbstractAsyncContextManager[LifespanResultT],
@@ -193,72 +221,96 @@ def __init__(
193221
self.website_url = website_url
194222
self.icons = icons
195223
self.lifespan = lifespan
196-
self._request_handlers: dict[str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]]] = {}
197-
self._notification_handlers: dict[
198-
str, Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[None]]
199-
] = {}
224+
self._notification_options = notification_options or NotificationOptions()
225+
self._experimental_capabilities = experimental_capabilities or {}
226+
self._request_handlers: dict[str, HandlerEntry[LifespanResultT]] = {}
227+
self._notification_handlers: dict[str, HandlerEntry[LifespanResultT]] = {}
200228
self._experimental_handlers: ExperimentalHandlers[LifespanResultT] | None = None
201229
self._session_manager: StreamableHTTPSessionManager | None = None
202230
# Context-tier middleware consumed by `ServerRunner`. Additive; the
203231
# existing `run()` path ignores it.
204-
self.middleware: list[ContextMiddleware[LifespanResultT]] = []
232+
self.middleware: list[ServerMiddleware[LifespanResultT]] = []
205233
logger.debug("Initializing server %r", name)
206234

207-
# Populate internal handler dicts from on_* kwargs
208-
self._request_handlers.update(
209-
{
210-
method: handler
211-
for method, handler in {
212-
"ping": on_ping,
213-
"prompts/list": on_list_prompts,
214-
"prompts/get": on_get_prompt,
215-
"resources/list": on_list_resources,
216-
"resources/templates/list": on_list_resource_templates,
217-
"resources/read": on_read_resource,
218-
"resources/subscribe": on_subscribe_resource,
219-
"resources/unsubscribe": on_unsubscribe_resource,
220-
"tools/list": on_list_tools,
221-
"tools/call": on_call_tool,
222-
"logging/setLevel": on_set_logging_level,
223-
"completion/complete": on_completion,
224-
}.items()
225-
if handler is not None
226-
}
227-
)
235+
_spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [
236+
("ping", types.RequestParams, on_ping),
237+
("prompts/list", types.PaginatedRequestParams, on_list_prompts),
238+
("prompts/get", types.GetPromptRequestParams, on_get_prompt),
239+
("resources/list", types.PaginatedRequestParams, on_list_resources),
240+
("resources/templates/list", types.PaginatedRequestParams, on_list_resource_templates),
241+
("resources/read", types.ReadResourceRequestParams, on_read_resource),
242+
("resources/subscribe", types.SubscribeRequestParams, on_subscribe_resource),
243+
("resources/unsubscribe", types.UnsubscribeRequestParams, on_unsubscribe_resource),
244+
("tools/list", types.PaginatedRequestParams, on_list_tools),
245+
("tools/call", types.CallToolRequestParams, on_call_tool),
246+
("logging/setLevel", types.SetLevelRequestParams, on_set_logging_level),
247+
("completion/complete", types.CompleteRequestParams, on_completion),
248+
]
249+
self._request_handlers.update({m: HandlerEntry(pt, h) for m, pt, h in _spec_requests if h is not None})
228250

251+
_spec_notifications: list[tuple[str, type[BaseModel], NotificationHandler[LifespanResultT, Any] | None]] = [
252+
("notifications/roots/list_changed", types.NotificationParams, on_roots_list_changed),
253+
("notifications/progress", types.ProgressNotificationParams, on_progress),
254+
]
229255
self._notification_handlers.update(
230-
{
231-
method: handler
232-
for method, handler in {
233-
"notifications/roots/list_changed": on_roots_list_changed,
234-
"notifications/progress": on_progress,
235-
}.items()
236-
if handler is not None
237-
}
256+
{m: HandlerEntry(pt, h) for m, pt, h in _spec_notifications if h is not None}
238257
)
239258

259+
def add_request_handler(
260+
self,
261+
method: str,
262+
params_type: type[_ParamsT],
263+
handler: RequestHandler[LifespanResultT, _ParamsT],
264+
) -> None:
265+
"""Register a request handler for ``method``.
266+
267+
``params_type`` is the model incoming params are validated against
268+
before the handler is invoked. It should subclass `RequestParams` so
269+
``_meta`` parses uniformly. Replaces any existing handler for the same
270+
method (no collision guard against spec methods).
271+
"""
272+
self._request_handlers[method] = HandlerEntry(params_type, handler)
273+
274+
def add_notification_handler(
275+
self,
276+
method: str,
277+
params_type: type[_ParamsT],
278+
handler: NotificationHandler[LifespanResultT, _ParamsT],
279+
) -> None:
280+
"""Register a notification handler for ``method``.
281+
282+
``params_type`` should subclass `NotificationParams` so ``_meta``
283+
parses uniformly. Replaces any existing handler.
284+
"""
285+
self._notification_handlers[method] = HandlerEntry(params_type, handler)
286+
240287
def _add_request_handler(
241288
self,
242289
method: str,
243-
handler: Callable[[ServerRequestContext[LifespanResultT], Any], Awaitable[Any]],
290+
handler: RequestHandler[LifespanResultT, Any],
244291
) -> None:
245-
"""Add a request handler, silently replacing any existing handler for the same method."""
246-
self._request_handlers[method] = handler
292+
# TODO: remove once experimental tasks plumbing and remaining callers
293+
# migrate to `add_request_handler` with an explicit params_type.
294+
self.add_request_handler(method, types.RequestParams, handler)
247295

248296
def _has_handler(self, method: str) -> bool:
249297
"""Check if a handler is registered for the given method."""
250298
return method in self._request_handlers or method in self._notification_handlers
251299

252300
# --- ServerRegistry protocol (consumed by ServerRunner) ------------------
253301

254-
def get_request_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
255-
"""Return the handler for a request method, or ``None``."""
302+
def get_request_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None:
303+
"""Return the registered entry for a request method, or ``None``."""
256304
return self._request_handlers.get(method)
257305

258-
def get_notification_handler(self, method: str) -> Callable[..., Awaitable[Any]] | None:
259-
"""Return the handler for a notification method, or ``None``."""
306+
def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] | None:
307+
"""Return the registered entry for a notification method, or ``None``."""
260308
return self._notification_handlers.get(method)
261309

310+
def capabilities(self) -> types.ServerCapabilities:
311+
"""Derive `ServerCapabilities` from registered handlers and constructor options."""
312+
return self.get_capabilities(self._notification_options, self._experimental_capabilities)
313+
262314
# TODO: Rethink capabilities API. Currently capabilities are derived from registered
263315
# handlers but require NotificationOptions to be passed externally for list_changed
264316
# flags, and experimental_capabilities as a separate dict. Consider deriving capabilities
@@ -474,7 +526,8 @@ async def _handle_request(
474526
attributes={"mcp.method.name": req.method, "jsonrpc.request.id": message.request_id},
475527
context=parent_context,
476528
) as span:
477-
if handler := self._request_handlers.get(req.method):
529+
if entry := self._request_handlers.get(req.method):
530+
handler = entry.handler
478531
logger.debug("Dispatching request of type %s", type(req).__name__)
479532

480533
try:
@@ -533,7 +586,8 @@ async def _handle_request(
533586
span.set_status(StatusCode.ERROR, response.message)
534587

535588
try:
536-
await message.respond(response)
589+
# TODO: cast goes away when `_handle_request` is deleted.
590+
await message.respond(cast(types.ServerResult | types.ErrorData, response))
537591
except (anyio.BrokenResourceError, anyio.ClosedResourceError):
538592
# Transport closed between handler unblocking and respond. Happens
539593
# when _receive_loop's finally wakes a handler blocked on
@@ -552,7 +606,8 @@ async def _handle_notification(
552606
session: ServerSession,
553607
lifespan_context: LifespanResultT,
554608
) -> None:
555-
if handler := self._notification_handlers.get(notify.method):
609+
if entry := self._notification_handlers.get(notify.method):
610+
handler = entry.handler
556611
logger.debug("Dispatching notification of type %s", type(notify).__name__)
557612

558613
try:

0 commit comments

Comments
 (0)