Skip to content

Commit 3c39eac

Browse files
committed
Notify connected agents of server restarts and hot-reloads so they can reload
1 parent 6affbe0 commit 3c39eac

4 files changed

Lines changed: 372 additions & 8 deletions

File tree

dash/mcp/_server.py

Lines changed: 79 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import json
1010
import logging
11+
import uuid
1112
from typing import TYPE_CHECKING, Any
1213

1314
from flask import Response, request
@@ -51,6 +52,24 @@ def enable_mcp_server(app: Dash, mcp_path: str) -> None:
5152
app.mcp_decorated_functions = dict(MCP_DECORATED_FUNCTIONS)
5253
MCP_DECORATED_FUNCTIONS.clear()
5354

55+
_session_id: str | None = None
56+
57+
def _get_or_create_session_id() -> str:
58+
"""Read the hot-reload hash or generate a stable fallback."""
59+
# pylint: disable=protected-access
60+
reload_hash = app._hot_reload.hash
61+
return reload_hash if reload_hash is not None else uuid.uuid4().hex
62+
63+
def _is_session_stale(client_session_id: str | None) -> bool:
64+
"""True when the client's session doesn't match or the hash changed."""
65+
if client_session_id != _session_id:
66+
return True
67+
# pylint: disable=protected-access
68+
reload_hash = app._hot_reload.hash
69+
if reload_hash is None:
70+
return False
71+
return reload_hash != _session_id
72+
5473
# -- Streamable HTTP endpoint --------------------------------------------
5574

5675
def mcp_handler() -> Response:
@@ -75,6 +94,42 @@ def _handle_get() -> Response:
7594
status=405,
7695
)
7796

97+
def _check_session(method: str) -> bool:
98+
"""Validate the session header.
99+
100+
Raises ``ValueError`` when the header is missing.
101+
Returns ``True`` when the session was stale and transparently
102+
recovered, or ``False`` when the session is valid.
103+
"""
104+
nonlocal _session_id
105+
if method == "initialize":
106+
_session_id = _get_or_create_session_id()
107+
return False
108+
client_session_id = request.headers.get("Mcp-Session-Id")
109+
if _session_id is not None and not client_session_id:
110+
raise ValueError("Missing Mcp-Session-Id header")
111+
if _is_session_stale(client_session_id):
112+
_session_id = _get_or_create_session_id()
113+
logger.debug("MCP session recovered: %s", _session_id)
114+
return True
115+
return False
116+
117+
def _json_response(*messages: dict) -> Response:
118+
"""Wrap one or more JSON-RPC messages in a Flask Response.
119+
120+
A single message is serialised as a JSON object; multiple
121+
messages are serialised as a JSON array.
122+
"""
123+
body = messages[0] if len(messages) == 1 else list(messages)
124+
resp = Response(
125+
json.dumps(body),
126+
content_type="application/json",
127+
status=200,
128+
)
129+
if _session_id is not None:
130+
resp.headers["Mcp-Session-Id"] = _session_id
131+
return resp
132+
78133
def _handle_post() -> Response:
79134
content_type = request.content_type or ""
80135
if "application/json" not in content_type:
@@ -92,16 +147,33 @@ def _handle_post() -> Response:
92147
status=400,
93148
)
94149

150+
method = data.get("method", "")
151+
152+
try:
153+
is_stale_session = _check_session(method)
154+
except ValueError as err:
155+
return Response(
156+
json.dumps({"error": str(err)}),
157+
content_type="application/json",
158+
status=400,
159+
)
160+
95161
response_data = _process_mcp_message(data)
96162

97163
if response_data is None:
98164
return Response("", status=202)
99165

100-
return Response(
101-
json.dumps(response_data),
102-
content_type="application/json",
103-
status=200,
104-
)
166+
if is_stale_session:
167+
return _json_response(
168+
{"jsonrpc": "2.0", "method": "notifications/tools/list_changed"},
169+
{
170+
"jsonrpc": "2.0",
171+
"method": "notifications/resources/list_changed",
172+
},
173+
response_data,
174+
)
175+
176+
return _json_response(response_data)
105177

106178
def _handle_delete() -> Response:
107179
# No sessions to terminate — server is stateless.
@@ -129,8 +201,8 @@ def _handle_initialize() -> InitializeResult:
129201
return InitializeResult(
130202
protocolVersion=LATEST_PROTOCOL_VERSION,
131203
capabilities=ServerCapabilities(
132-
tools=ToolsCapability(listChanged=False),
133-
resources=ResourcesCapability(),
204+
tools=ToolsCapability(listChanged=True),
205+
resources=ResourcesCapability(listChanged=True),
134206
),
135207
serverInfo=Implementation(name="Plotly Dash", version=__version__),
136208
instructions=(

dash/mcp/primitives/tools/results/result_dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ class DataFrameResult(ResultFormatter):
4747
"""Produce a markdown table for tabular component output values."""
4848

4949
@classmethod
50-
def format(cls, output: MCPOutput, returned_output_value: Any) -> list[TextContent]:
50+
def format(cls, output: MCPOutput, returned_output_value: Any) -> list[TextContent]: # type: ignore[override]
5151
if not TABULAR.matches(output.get("component_type"), output["property"]):
5252
return []
5353
if (
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
"""MCP session lifecycle — end-to-end over a real Dash server.
2+
3+
Exercises the full MCP session flow (initialize → operate → hot-reload
4+
recovery) against a live ``dash_duo`` server using real HTTP requests.
5+
Unit-level checks (status codes, header mechanics) live in
6+
``tests/unit/mcp/test_mcp_session.py``; these tests verify the broader
7+
behavioral contract.
8+
"""
9+
10+
import requests
11+
12+
from dash import Dash, Input, Output, html
13+
14+
from tests.integration.mcp.conftest import _mcp_post
15+
16+
17+
def _mcp_post_with_session(
18+
server_url, method, params=None, request_id=1, session_id=None
19+
):
20+
"""Like ``_mcp_post`` but forwards an ``Mcp-Session-Id`` header."""
21+
headers = {"Content-Type": "application/json"}
22+
if session_id is not None:
23+
headers["Mcp-Session-Id"] = session_id
24+
return requests.post(
25+
f"{server_url}/_mcp",
26+
json={
27+
"jsonrpc": "2.0",
28+
"method": method,
29+
"id": request_id,
30+
"params": params or {},
31+
},
32+
headers=headers,
33+
timeout=5,
34+
)
35+
36+
37+
def test_mcpse_e2e001_full_session_lifecycle(dash_duo):
38+
"""Initialize → tools/list → tools/call with session headers throughout."""
39+
app = Dash(__name__)
40+
app.layout = html.Div([html.Div(id="inp"), html.Div(id="out")])
41+
42+
@app.callback(Output("out", "children"), Input("inp", "children"))
43+
def echo(v):
44+
return f"echo: {v}"
45+
46+
dash_duo.start_server(app)
47+
url = dash_duo.server.url
48+
49+
init = _mcp_post_with_session(url, "initialize")
50+
assert init.status_code == 200
51+
sid = init.headers.get("Mcp-Session-Id")
52+
assert sid
53+
54+
notif = _mcp_post_with_session(
55+
url, "notifications/initialized", session_id=sid, request_id=None
56+
)
57+
assert notif.status_code == 202
58+
59+
tools_resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=2)
60+
assert tools_resp.status_code == 200
61+
tools = tools_resp.json()["result"]["tools"]
62+
assert any("echo" in t["name"] for t in tools)
63+
64+
tool_name = next(t["name"] for t in tools if "echo" in t["name"])
65+
call_resp = _mcp_post_with_session(
66+
url,
67+
"tools/call",
68+
params={"name": tool_name, "arguments": {"v": "hello"}},
69+
session_id=sid,
70+
request_id=3,
71+
)
72+
assert call_resp.status_code == 200
73+
assert call_resp.headers.get("Mcp-Session-Id") == sid
74+
75+
76+
def test_mcpse_e2e002_stale_session_recovers_with_notifications(dash_duo):
77+
"""Simulate a hot-reload hash change and verify transparent recovery."""
78+
app = Dash(__name__)
79+
app.layout = html.Div([html.Div(id="inp"), html.Div(id="out")])
80+
81+
@app.callback(Output("out", "children"), Input("inp", "children"))
82+
def echo(v):
83+
return f"echo: {v}"
84+
85+
dash_duo.start_server(app)
86+
url = dash_duo.server.url
87+
88+
app._hot_reload.hash = "original_hash"
89+
90+
init = _mcp_post_with_session(url, "initialize")
91+
sid = init.headers["Mcp-Session-Id"]
92+
assert sid == "original_hash"
93+
94+
resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=2)
95+
assert resp.status_code == 200
96+
97+
app._hot_reload.hash = "new_hash"
98+
99+
resp = _mcp_post_with_session(url, "tools/list", session_id=sid, request_id=3)
100+
assert resp.status_code == 200
101+
new_sid = resp.headers["Mcp-Session-Id"]
102+
assert new_sid == "new_hash"
103+
104+
data = resp.json()
105+
assert isinstance(data, list)
106+
assert len(data) == 3
107+
assert data[0]["method"] == "notifications/tools/list_changed"
108+
assert data[1]["method"] == "notifications/resources/list_changed"
109+
assert "result" in data[2]
110+
assert "tools" in data[2]["result"]
111+
112+
resp = _mcp_post_with_session(url, "tools/list", session_id=new_sid, request_id=4)
113+
assert resp.status_code == 200
114+
data = resp.json()
115+
assert isinstance(data, dict)
116+
assert "result" in data
117+
118+
119+
def test_mcpse_e2e003_capabilities_advertise_list_changed(dash_duo):
120+
"""Server capabilities include listChanged for tools and resources."""
121+
app = Dash(__name__)
122+
app.layout = html.Div(id="root")
123+
dash_duo.start_server(app)
124+
125+
resp = _mcp_post(dash_duo.server.url, "initialize")
126+
caps = resp.json()["result"]["capabilities"]
127+
assert caps["tools"]["listChanged"] is True
128+
assert caps["resources"]["listChanged"] is True

0 commit comments

Comments
 (0)