Skip to content

Commit a1c16c4

Browse files
committed
feat: Connection.state + exit_stack; ctx.session_id/headers; TransportContext.headers
Per-connection state without a connection_lifespan CM or a second Server generic. Stateless is the default deployment, where a per-connection lifespan would wrap a single request; the enter-late mechanics it would need (race init vs dispatcher-done, ready-gate) were more machinery than the use case warrants. - Connection.session_id: str | None — set by the mount via ServerRunner(session_id=...); per-connection, not per-message - Connection.state: dict[str, Any] — scratch that persists across requests; handlers/middleware read and write freely - Connection.exit_stack: AsyncExitStack — handlers/middleware push CMs or callbacks for per-connection teardown; ServerRunner.run() unwinds it (shielded) in a finally after dispatcher.run() returns - TransportContext.headers: Mapping[str, str] | None on the base — populated by HTTP transports, None on stdio - Context.session_id / Context.headers convenience properties - create_direct_dispatcher_pair(headers=...) and connected_runner(session_id=..., headers=...) for tests
1 parent cd77adb commit a1c16c4

6 files changed

Lines changed: 163 additions & 15 deletions

File tree

src/mcp/server/connection.py

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""`Connection` — per-client connection state and the standalone outbound channel.
22
33
Always present on `Context` (never ``None``), even in stateless deployments.
4-
Holds peer info populated at ``initialize`` time, the per-connection lifespan
5-
output, and an `Outbound` for the standalone stream (the SSE GET stream in
6-
streamable HTTP, or the single duplex stream in stdio).
4+
Holds peer info populated at ``initialize`` time, per-connection scratch
5+
``state`` and an ``exit_stack`` for teardown, and an `Outbound` for the
6+
standalone stream (the SSE GET stream in streamable HTTP, or the single duplex
7+
stream in stdio).
78
89
`notify` is best-effort: it never raises. If there's no standalone channel
910
(stateless HTTP) or the stream has been dropped, the notification is
@@ -14,6 +15,7 @@
1415

1516
import logging
1617
from collections.abc import Mapping
18+
from contextlib import AsyncExitStack
1719
from typing import Any
1820

1921
import anyio
@@ -44,17 +46,27 @@ class Connection(TypedServerRequestMixin):
4446
``None`` until ``initialize`` completes; ``initialized`` is set then.
4547
"""
4648

47-
def __init__(self, outbound: Outbound, *, has_standalone_channel: bool) -> None:
49+
def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None:
4850
self._outbound = outbound
4951
self.has_standalone_channel = has_standalone_channel
52+
self.session_id: str | None = session_id
5053

5154
self.client_info: Implementation | None = None
5255
self.client_capabilities: ClientCapabilities | None = None
5356
self.protocol_version: str | None = None
5457
self.initialized: anyio.Event = anyio.Event()
55-
# TODO: make this generic (Connection[StateT]) once connection_lifespan
56-
# wiring lands in ServerRunner.
57-
self.state: Any = None
58+
59+
self.state: dict[str, Any] = {}
60+
"""Per-connection scratch state. Handlers and middleware may read and
61+
write freely; persists across requests on this connection."""
62+
63+
self.exit_stack: AsyncExitStack = AsyncExitStack()
64+
"""Cleanup stack unwound by `ServerRunner` when the connection closes.
65+
66+
Push context managers (``await exit_stack.enter_async_context(...)``)
67+
or callbacks (``exit_stack.push_async_callback(...)``) from handlers or
68+
middleware to register per-connection teardown. Unwound LIFO after
69+
`dispatcher.run()` returns, shielded from cancellation."""
5870

5971
async def send_raw_request(
6072
self,

src/mcp/server/context.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from collections.abc import Awaitable, Callable
3+
from collections.abc import Awaitable, Callable, Mapping
44
from dataclasses import dataclass
55
from typing import Any, Generic, Protocol
66

@@ -69,6 +69,23 @@ def connection(self) -> Connection:
6969
"""The per-client `Connection` for this request's connection."""
7070
return self._connection
7171

72+
@property
73+
def session_id(self) -> str | None:
74+
"""The transport's session id for this connection, when one exists.
75+
76+
Convenience for ``ctx.connection.session_id``. ``None`` on stdio and
77+
stateless HTTP.
78+
"""
79+
return self._connection.session_id
80+
81+
@property
82+
def headers(self) -> Mapping[str, str] | None:
83+
"""Request headers carried by this message, when the transport has them.
84+
85+
Convenience for ``ctx.transport.headers``. ``None`` on stdio.
86+
"""
87+
return self.transport.headers
88+
7289
async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, *, meta: Meta | None = None) -> None:
7390
"""Send a request-scoped ``notifications/message`` log entry.
7491

src/mcp/server/runner.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class ServerRunner(Generic[LifespanT]):
116116
dispatcher: Dispatcher[TransportContext]
117117
lifespan_state: LifespanT
118118
has_standalone_channel: bool
119+
session_id: str | None = None
119120
stateless: bool = False
120121
dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware])
121122

@@ -124,17 +125,25 @@ class ServerRunner(Generic[LifespanT]):
124125

125126
def __post_init__(self) -> None:
126127
self._initialized = self.stateless
127-
self.connection = Connection(self.dispatcher, has_standalone_channel=self.has_standalone_channel)
128+
self.connection = Connection(
129+
self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id
130+
)
128131

129132
async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None:
130133
"""Drive the dispatcher until the underlying channel closes.
131134
132135
Composes `dispatch_middleware` over `_on_request` and hands the result
133136
to `dispatcher.run()`. ``task_status.started()`` is forwarded so callers
134137
can ``await tg.start(runner.run)`` and resume once the dispatcher is
135-
ready to accept requests.
138+
ready to accept requests. Once the dispatcher exits,
139+
`connection.exit_stack` is unwound (shielded) so any per-connection
140+
cleanup registered by handlers or middleware runs to completion.
136141
"""
137-
await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status)
142+
try:
143+
await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status)
144+
finally:
145+
with anyio.CancelScope(shield=True):
146+
await self.connection.exit_stack.aclose()
138147

139148
def _compose_on_request(self) -> OnRequest:
140149
"""Wrap `_on_request` in `dispatch_middleware`, outermost-first.

src/mcp/shared/direct_dispatcher.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,18 +162,20 @@ async def _dispatch_notify(self, method: str, params: Mapping[str, Any] | None)
162162
def create_direct_dispatcher_pair(
163163
*,
164164
can_send_request: bool = True,
165+
headers: Mapping[str, str] | None = None,
165166
) -> tuple[DirectDispatcher, DirectDispatcher]:
166167
"""Create two `DirectDispatcher` instances wired to each other.
167168
168169
Args:
169170
can_send_request: Sets `TransportContext.can_send_request` on both
170171
sides. Pass ``False`` to simulate a transport with no back-channel.
172+
headers: Sets `TransportContext.headers` on both sides.
171173
172174
Returns:
173175
A ``(left, right)`` pair. Conventionally ``left`` is the client side
174176
and ``right`` is the server side, but the wiring is symmetric.
175177
"""
176-
ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request)
178+
ctx = TransportContext(kind=DIRECT_TRANSPORT_KIND, can_send_request=can_send_request, headers=headers)
177179
left = DirectDispatcher(ctx)
178180
right = DirectDispatcher(ctx)
179181
left.connect_to(right)

src/mcp/shared/transport_context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
dispatcher (`ServerRunner`, `Context`, user handlers) read its concrete fields.
77
"""
88

9+
from collections.abc import Mapping
910
from dataclasses import dataclass
1011

1112
__all__ = ["TransportContext"]
@@ -28,3 +29,10 @@ class TransportContext:
2829
stdio, SSE, and stateful streamable HTTP. When ``False``,
2930
`DispatchContext.send_raw_request` raises `NoBackChannelError`.
3031
"""
32+
33+
headers: Mapping[str, str] | None = None
34+
"""Request headers carried by this message, when the transport has them.
35+
36+
Populated by HTTP-based transports; ``None`` on stdio. Handlers should
37+
None-check before use.
38+
"""

tests/server/test_runner.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
under test.
77
"""
88

9-
from collections.abc import AsyncIterator
10-
from contextlib import asynccontextmanager
9+
from collections.abc import AsyncIterator, Mapping
10+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
1111
from typing import Any, cast
1212

1313
import anyio
@@ -76,6 +76,8 @@ async def connected_runner(
7676
initialized: bool = True,
7777
stateless: bool = False,
7878
has_standalone_channel: bool = True,
79+
session_id: str | None = None,
80+
headers: Mapping[str, str] | None = None,
7981
dispatch_middleware: list[DispatchMiddleware] | None = None,
8082
) -> AsyncIterator[tuple[DirectDispatcher, ServerRunner[dict[str, Any]]]]:
8183
"""Yield ``(client, runner)`` running over an in-memory dispatcher pair.
@@ -85,12 +87,13 @@ async def connected_runner(
8587
``initialized`` is true the helper performs the real ``initialize`` request
8688
before yielding, so tests start past the init-gate via the public path.
8789
"""
88-
client, server_d = create_direct_dispatcher_pair()
90+
client, server_d = create_direct_dispatcher_pair(headers=headers)
8991
runner = ServerRunner(
9092
server=server,
9193
dispatcher=server_d,
9294
lifespan_state={},
9395
has_standalone_channel=has_standalone_channel,
96+
session_id=session_id,
9497
stateless=stateless,
9598
dispatch_middleware=dispatch_middleware or [],
9699
)
@@ -380,3 +383,100 @@ async def failing(ctx: Any, params: Any) -> Any:
380383
[event] = [e for e in span.events if e.name == "exception"]
381384
assert event.attributes is not None
382385
assert event.attributes["exception.type"] == "ValueError"
386+
387+
388+
@pytest.mark.anyio
389+
async def test_connection_state_persists_across_requests_on_same_connection(server: SrvT) -> None:
390+
async def count(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult:
391+
ctx.connection.state["n"] = ctx.connection.state.get("n", 0) + 1
392+
return ListToolsResult(tools=[])
393+
394+
server.add_request_handler("tools/list", PaginatedRequestParams, count)
395+
async with connected_runner(server) as (client, runner):
396+
await client.send_raw_request("tools/list", None)
397+
await client.send_raw_request("tools/list", None)
398+
assert runner.connection.state == {"n": 2}
399+
400+
401+
@pytest.mark.anyio
402+
async def test_connection_exit_stack_runs_pushed_callback_after_close(server: SrvT) -> None:
403+
cleaned: list[str] = []
404+
405+
async def push(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult:
406+
async def _cleanup() -> None:
407+
cleaned.append("done")
408+
409+
ctx.connection.exit_stack.push_async_callback(_cleanup)
410+
return ListToolsResult(tools=[])
411+
412+
server.add_request_handler("tools/list", PaginatedRequestParams, push)
413+
async with connected_runner(server) as (client, _runner):
414+
await client.send_raw_request("tools/list", None)
415+
assert cleaned == []
416+
assert cleaned == ["done"]
417+
418+
419+
@pytest.mark.anyio
420+
async def test_connection_exit_stack_unwinds_entered_context_manager_after_close(server: SrvT) -> None:
421+
events: list[str] = []
422+
423+
class _Tracker(AbstractAsyncContextManager[str]):
424+
async def __aenter__(self) -> str:
425+
events.append("enter")
426+
return "resource"
427+
428+
async def __aexit__(self, *exc: object) -> None:
429+
events.append("exit")
430+
431+
async def acquire(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult:
432+
res = await ctx.connection.exit_stack.enter_async_context(_Tracker())
433+
ctx.connection.state["res"] = res
434+
return ListToolsResult(tools=[])
435+
436+
server.add_request_handler("tools/list", PaginatedRequestParams, acquire)
437+
async with connected_runner(server) as (client, runner):
438+
await client.send_raw_request("tools/list", None)
439+
assert events == ["enter"]
440+
assert runner.connection.state["res"] == "resource"
441+
assert events == ["enter", "exit"]
442+
443+
444+
@pytest.mark.anyio
445+
async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error(server: SrvT) -> None:
446+
cleaned: list[int] = []
447+
448+
async def push_then_fail(ctx: Any, params: PaginatedRequestParams | None) -> ListToolsResult:
449+
for i in (1, 2, 3):
450+
ctx.connection.exit_stack.push_async_callback(_append, i)
451+
raise RuntimeError("boom")
452+
453+
async def _append(i: int) -> None:
454+
cleaned.append(i)
455+
456+
server.add_request_handler("tools/list", PaginatedRequestParams, push_then_fail)
457+
async with connected_runner(server) as (client, _runner):
458+
with pytest.raises(MCPError) as ei:
459+
await client.send_raw_request("tools/list", None)
460+
assert ei.value.error.code == INTERNAL_ERROR
461+
assert cleaned == []
462+
assert cleaned == [3, 2, 1]
463+
464+
465+
@pytest.mark.anyio
466+
async def test_context_session_id_and_headers_expose_connection_and_transport(server: SrvT) -> None:
467+
async with connected_runner(server, session_id="sess-abc", headers={"authorization": "Bearer t"}) as (client, _r):
468+
await client.send_raw_request("tools/list", None)
469+
[ctx] = _seen_ctx
470+
assert ctx.session_id == "sess-abc"
471+
assert ctx.session_id == ctx.connection.session_id
472+
assert ctx.headers == {"authorization": "Bearer t"}
473+
assert ctx.headers is ctx.transport.headers
474+
475+
476+
@pytest.mark.anyio
477+
async def test_context_session_id_and_headers_default_none(server: SrvT) -> None:
478+
async with connected_runner(server) as (client, _r):
479+
await client.send_raw_request("tools/list", None)
480+
[ctx] = _seen_ctx
481+
assert ctx.session_id is None
482+
assert ctx.headers is None

0 commit comments

Comments
 (0)