66under test.
77"""
88
9- from collections .abc import AsyncIterator
10- from contextlib import asynccontextmanager
9+ from collections .abc import AsyncIterator , Mapping
10+ from contextlib import AbstractAsyncContextManager , asynccontextmanager
1111from typing import Any , cast
1212
1313import anyio
@@ -76,6 +76,8 @@ async def connected_runner(
7676 initialized : bool = True ,
7777 stateless : bool = False ,
7878 has_standalone_channel : bool = True ,
79+ session_id : str | None = None ,
80+ headers : Mapping [str , str ] | None = None ,
7981 dispatch_middleware : list [DispatchMiddleware ] | None = None ,
8082) -> AsyncIterator [tuple [DirectDispatcher , ServerRunner [dict [str , Any ]]]]:
8183 """Yield ``(client, runner)`` running over an in-memory dispatcher pair.
@@ -85,12 +87,13 @@ async def connected_runner(
8587 ``initialized`` is true the helper performs the real ``initialize`` request
8688 before yielding, so tests start past the init-gate via the public path.
8789 """
88- client , server_d = create_direct_dispatcher_pair ()
90+ client , server_d = create_direct_dispatcher_pair (headers = headers )
8991 runner = ServerRunner (
9092 server = server ,
9193 dispatcher = server_d ,
9294 lifespan_state = {},
9395 has_standalone_channel = has_standalone_channel ,
96+ session_id = session_id ,
9497 stateless = stateless ,
9598 dispatch_middleware = dispatch_middleware or [],
9699 )
@@ -380,3 +383,100 @@ async def failing(ctx: Any, params: Any) -> Any:
380383 [event ] = [e for e in span .events if e .name == "exception" ]
381384 assert event .attributes is not None
382385 assert event .attributes ["exception.type" ] == "ValueError"
386+
387+
388+ @pytest .mark .anyio
389+ async def test_connection_state_persists_across_requests_on_same_connection (server : SrvT ) -> None :
390+ async def count (ctx : Any , params : PaginatedRequestParams | None ) -> ListToolsResult :
391+ ctx .connection .state ["n" ] = ctx .connection .state .get ("n" , 0 ) + 1
392+ return ListToolsResult (tools = [])
393+
394+ server .add_request_handler ("tools/list" , PaginatedRequestParams , count )
395+ async with connected_runner (server ) as (client , runner ):
396+ await client .send_raw_request ("tools/list" , None )
397+ await client .send_raw_request ("tools/list" , None )
398+ assert runner .connection .state == {"n" : 2 }
399+
400+
401+ @pytest .mark .anyio
402+ async def test_connection_exit_stack_runs_pushed_callback_after_close (server : SrvT ) -> None :
403+ cleaned : list [str ] = []
404+
405+ async def push (ctx : Any , params : PaginatedRequestParams | None ) -> ListToolsResult :
406+ async def _cleanup () -> None :
407+ cleaned .append ("done" )
408+
409+ ctx .connection .exit_stack .push_async_callback (_cleanup )
410+ return ListToolsResult (tools = [])
411+
412+ server .add_request_handler ("tools/list" , PaginatedRequestParams , push )
413+ async with connected_runner (server ) as (client , _runner ):
414+ await client .send_raw_request ("tools/list" , None )
415+ assert cleaned == []
416+ assert cleaned == ["done" ]
417+
418+
419+ @pytest .mark .anyio
420+ async def test_connection_exit_stack_unwinds_entered_context_manager_after_close (server : SrvT ) -> None :
421+ events : list [str ] = []
422+
423+ class _Tracker (AbstractAsyncContextManager [str ]):
424+ async def __aenter__ (self ) -> str :
425+ events .append ("enter" )
426+ return "resource"
427+
428+ async def __aexit__ (self , * exc : object ) -> None :
429+ events .append ("exit" )
430+
431+ async def acquire (ctx : Any , params : PaginatedRequestParams | None ) -> ListToolsResult :
432+ res = await ctx .connection .exit_stack .enter_async_context (_Tracker ())
433+ ctx .connection .state ["res" ] = res
434+ return ListToolsResult (tools = [])
435+
436+ server .add_request_handler ("tools/list" , PaginatedRequestParams , acquire )
437+ async with connected_runner (server ) as (client , runner ):
438+ await client .send_raw_request ("tools/list" , None )
439+ assert events == ["enter" ]
440+ assert runner .connection .state ["res" ] == "resource"
441+ assert events == ["enter" , "exit" ]
442+
443+
444+ @pytest .mark .anyio
445+ async def test_connection_exit_stack_runs_callbacks_lifo_after_handler_error (server : SrvT ) -> None :
446+ cleaned : list [int ] = []
447+
448+ async def push_then_fail (ctx : Any , params : PaginatedRequestParams | None ) -> ListToolsResult :
449+ for i in (1 , 2 , 3 ):
450+ ctx .connection .exit_stack .push_async_callback (_append , i )
451+ raise RuntimeError ("boom" )
452+
453+ async def _append (i : int ) -> None :
454+ cleaned .append (i )
455+
456+ server .add_request_handler ("tools/list" , PaginatedRequestParams , push_then_fail )
457+ async with connected_runner (server ) as (client , _runner ):
458+ with pytest .raises (MCPError ) as ei :
459+ await client .send_raw_request ("tools/list" , None )
460+ assert ei .value .error .code == INTERNAL_ERROR
461+ assert cleaned == []
462+ assert cleaned == [3 , 2 , 1 ]
463+
464+
465+ @pytest .mark .anyio
466+ async def test_context_session_id_and_headers_expose_connection_and_transport (server : SrvT ) -> None :
467+ async with connected_runner (server , session_id = "sess-abc" , headers = {"authorization" : "Bearer t" }) as (client , _r ):
468+ await client .send_raw_request ("tools/list" , None )
469+ [ctx ] = _seen_ctx
470+ assert ctx .session_id == "sess-abc"
471+ assert ctx .session_id == ctx .connection .session_id
472+ assert ctx .headers == {"authorization" : "Bearer t" }
473+ assert ctx .headers is ctx .transport .headers
474+
475+
476+ @pytest .mark .anyio
477+ async def test_context_session_id_and_headers_default_none (server : SrvT ) -> None :
478+ async with connected_runner (server ) as (client , _r ):
479+ await client .send_raw_request ("tools/list" , None )
480+ [ctx ] = _seen_ctx
481+ assert ctx .session_id is None
482+ assert ctx .headers is None
0 commit comments