diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 61026aa0c..5f47540ec 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -2,7 +2,7 @@ from collections.abc import Callable from contextlib import asynccontextmanager from typing import Any -from urllib.parse import parse_qs, urljoin, urlparse +from urllib.parse import parse_qs, urljoin, urlparse, urlunparse import anyio import httpx @@ -27,6 +27,65 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None: return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0] +def _resolve_endpoint_url(base_url: str, endpoint: str) -> str: + """Resolve an endpoint URL, preserving any reverse proxy/API gateway path prefix. + + When an MCP server sits behind a reverse proxy or API gateway that adds a + path prefix (e.g., ``/gateway``), the server's endpoint events contain paths + without that prefix. Standard ``urljoin`` drops the base URL's path prefix + for absolute paths (starting with ``/``). This function detects and + preserves such prefixes. + + Example:: + + >>> _resolve_endpoint_url( + ... "https://host/gateway/v1/sse", + ... "/v1/messages/?session_id=abc", + ... ) + 'https://host/gateway/v1/messages/?session_id=abc' + """ + parsed_ep = urlparse(endpoint) + + # Full URL — use as-is + if parsed_ep.scheme: + return endpoint + + # Relative path (no leading /) — urljoin handles correctly + if not endpoint.startswith("/"): + return urljoin(base_url, endpoint) + + # For absolute paths, detect and preserve any gateway prefix. + # Strategy: find the first path segment of the endpoint inside the base URL + # path. If it appears at a position > 0, everything before it is the + # gateway prefix that must be preserved. + parsed_base = urlparse(base_url) + base_path = parsed_base.path + ep_path = parsed_ep.path + + ep_segments = [s for s in ep_path.split("/") if s] + if ep_segments: + first_seg = "/" + ep_segments[0] + idx = base_path.find(first_seg + "/") + if idx < 0 and base_path.endswith(first_seg): + idx = len(base_path) - len(first_seg) + + if idx > 0: + prefix = base_path[:idx] + return urlunparse( + ( + parsed_base.scheme, + parsed_base.netloc, + prefix + ep_path, + "", + parsed_ep.query, + "", + ) + ) + + # No prefix detected — fall back to standard resolution + return urljoin(base_url, endpoint) + + @asynccontextmanager async def sse_client( url: str, @@ -80,7 +139,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): logger.debug(f"Received SSE event: {sse.event}") match sse.event: case "endpoint": - endpoint_url = urljoin(url, sse.data) + endpoint_url = _resolve_endpoint_url(url, sse.data) logger.debug(f"Received endpoint URL: {endpoint_url}") url_parsed = urlparse(url) diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 890e99733..8875495f8 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -20,7 +20,7 @@ import mcp.client.sse from mcp import types from mcp.client.session import ClientSession -from mcp.client.sse import _extract_session_id_from_endpoint, sse_client +from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client from mcp.server import Server, ServerRequestContext from mcp.server.sse import SseServerTransport from mcp.server.transport_security import TransportSecuritySettings @@ -229,6 +229,78 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non assert _extract_session_id_from_endpoint(endpoint_url) == expected +@pytest.mark.parametrize( + "base_url,endpoint,expected", + [ + # --- Gateway / reverse proxy prefix (the bug from issue #795) --- + ( + "https://example.com/gateway/v1/sse", + "/v1/messages/?session_id=abc", + "https://example.com/gateway/v1/messages/?session_id=abc", + ), + ( + "https://example.com/gateway_prefix/v1/sse", + "/v1/messages/?session_id=abc", + "https://example.com/gateway_prefix/v1/messages/?session_id=abc", + ), + # Deep gateway prefix + ( + "https://example.com/org/team/v1/sse", + "/v1/messages/?session_id=abc", + "https://example.com/org/team/v1/messages/?session_id=abc", + ), + # --- No gateway prefix (should behave like urljoin) --- + ( + "https://example.com/v1/sse", + "/v1/messages/?session_id=abc", + "https://example.com/v1/messages/?session_id=abc", + ), + ( + "https://example.com/sse", + "/messages/?session_id=abc", + "https://example.com/messages/?session_id=abc", + ), + # --- Relative path (urljoin handles correctly) --- + ( + "https://example.com/gateway/v1/sse", + "messages/?session_id=abc", + "https://example.com/gateway/v1/messages/?session_id=abc", + ), + # --- Absolute URL endpoint (use as-is) --- + ( + "https://example.com/gateway/v1/sse", + "https://example.com/v1/messages/?session_id=abc", + "https://example.com/v1/messages/?session_id=abc", + ), + # --- Endpoint at end of base path (no trailing slash match) --- + ( + "https://example.com/gw/api", + "/api/messages/?session_id=abc", + "https://example.com/gw/api/messages/?session_id=abc", + ), + # --- Empty path (just /) — no segments to match --- + ( + "https://example.com/gw/v1/sse", + "/?session_id=abc", + "https://example.com/?session_id=abc", + ), + ], + ids=[ + "gateway_prefix", + "gateway_prefix_underscore", + "deep_gateway_prefix", + "no_prefix_same_root", + "no_prefix_root_level", + "relative_path", + "absolute_url", + "endpoint_at_path_end", + "empty_path_root_slash", + ], +) +def test_resolve_endpoint_url(base_url: str, endpoint: str, expected: str) -> None: + assert _resolve_endpoint_url(base_url, endpoint) == expected + + @pytest.mark.anyio async def test_sse_client_on_session_created_not_called_when_no_session_id( server: None, server_url: str, monkeypatch: pytest.MonkeyPatch