Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 61 additions & 2 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import parse_qs, urljoin, urlparse
from urllib.parse import parse_qs, urljoin, urlparse, urlunparse

import anyio
import httpx
Expand All @@ -27,6 +27,65 @@ def _extract_session_id_from_endpoint(endpoint_url: str) -> str | None:
return query_params.get("sessionId", [None])[0] or query_params.get("session_id", [None])[0]


def _resolve_endpoint_url(base_url: str, endpoint: str) -> str:
    """Resolve an endpoint URL, preserving any reverse proxy/API gateway path prefix.

    When an MCP server sits behind a reverse proxy or API gateway that adds a
    path prefix (e.g., ``/gateway``), the server's endpoint events contain paths
    without that prefix. Standard ``urljoin`` drops the base URL's path prefix
    for absolute paths (starting with ``/``). This function detects and
    preserves such prefixes.

    Resolution rules, in order:

    1. An endpoint with its own scheme is returned unchanged.
    2. An endpoint that names its own authority (``//host/path``) or that is
       a relative path (no leading ``/``) is resolved with ``urljoin``.
    3. For an absolute path, the first path segment of the endpoint is looked
       up inside the base URL's path. If it occurs at an offset greater than
       zero, everything before it is treated as the gateway prefix and is
       prepended to the endpoint path.
    4. Otherwise the result is whatever ``urljoin`` produces.

    Args:
        base_url: The URL the SSE connection was opened against.
        endpoint: The value of the server's ``endpoint`` event; may be a full
            URL, a network-path reference, an absolute path, or a relative path.

    Returns:
        The absolute URL to use for the endpoint.

    Example::

        >>> _resolve_endpoint_url(
        ...     "https://host/gateway/v1/sse",
        ...     "/v1/messages/?session_id=abc",
        ... )
        'https://host/gateway/v1/messages/?session_id=abc'
    """
    parsed_ep = urlparse(endpoint)

    # Full URL — use as-is
    if parsed_ep.scheme:
        return endpoint

    # A network-path reference ("//host/path") also starts with "/", but it
    # carries its own authority. Running the prefix heuristic on it would
    # silently discard that host and rewrite the endpoint onto the base host,
    # so it is resolved the standard way instead. Relative paths (no leading
    # "/") are likewise handled correctly by urljoin.
    if parsed_ep.netloc or not endpoint.startswith("/"):
        return urljoin(base_url, endpoint)

    # For absolute paths, detect and preserve any gateway prefix.
    parsed_base = urlparse(base_url)
    base_path = parsed_base.path
    ep_path = parsed_ep.path

    # First non-empty segment of the endpoint path; None for a bare "/".
    first_segment = next((seg for seg in ep_path.split("/") if seg), None)
    if first_segment is not None:
        marker = "/" + first_segment
        # Match on whole segments only: either "/seg/" somewhere inside the
        # base path, or "/seg" as the base path's final segment.
        idx = base_path.find(marker + "/")
        if idx < 0 and base_path.endswith(marker):
            idx = len(base_path) - len(marker)

        # idx == 0 means the segment already sits at the root of the base
        # path, so there is no prefix to preserve.
        if idx > 0:
            prefix = base_path[:idx]
            return urlunparse(
                (
                    parsed_base.scheme,
                    parsed_base.netloc,
                    prefix + ep_path,
                    # Keep params and fragment so this branch loses nothing
                    # that the urljoin fallback would have kept.
                    parsed_ep.params,
                    parsed_ep.query,
                    parsed_ep.fragment,
                )
            )

    # No prefix detected — fall back to standard resolution
    return urljoin(base_url, endpoint)


@asynccontextmanager
async def sse_client(
url: str,
Expand Down Expand Up @@ -80,7 +139,7 @@ async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
logger.debug(f"Received SSE event: {sse.event}")
match sse.event:
case "endpoint":
endpoint_url = urljoin(url, sse.data)
endpoint_url = _resolve_endpoint_url(url, sse.data)
logger.debug(f"Received endpoint URL: {endpoint_url}")

url_parsed = urlparse(url)
Expand Down
74 changes: 73 additions & 1 deletion tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import mcp.client.sse
from mcp import types
from mcp.client.session import ClientSession
from mcp.client.sse import _extract_session_id_from_endpoint, sse_client
from mcp.client.sse import _extract_session_id_from_endpoint, _resolve_endpoint_url, sse_client
from mcp.server import Server, ServerRequestContext
from mcp.server.sse import SseServerTransport
from mcp.server.transport_security import TransportSecuritySettings
Expand Down Expand Up @@ -229,6 +229,78 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
assert _extract_session_id_from_endpoint(endpoint_url) == expected


@pytest.mark.parametrize(
    "base_url,endpoint,expected",
    [
        # Endpoint paths that lack the gateway / reverse proxy prefix carried
        # by the base URL (the bug from issue #795).
        pytest.param(
            "https://example.com/gateway/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/gateway/v1/messages/?session_id=abc",
            id="gateway_prefix",
        ),
        pytest.param(
            "https://example.com/gateway_prefix/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/gateway_prefix/v1/messages/?session_id=abc",
            id="gateway_prefix_underscore",
        ),
        # Prefix made of more than one path segment.
        pytest.param(
            "https://example.com/org/team/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/org/team/v1/messages/?session_id=abc",
            id="deep_gateway_prefix",
        ),
        # No prefix in the base URL: the result matches plain urljoin.
        pytest.param(
            "https://example.com/v1/sse",
            "/v1/messages/?session_id=abc",
            "https://example.com/v1/messages/?session_id=abc",
            id="no_prefix_same_root",
        ),
        pytest.param(
            "https://example.com/sse",
            "/messages/?session_id=abc",
            "https://example.com/messages/?session_id=abc",
            id="no_prefix_root_level",
        ),
        # Relative endpoint (no leading slash) is left to urljoin.
        pytest.param(
            "https://example.com/gateway/v1/sse",
            "messages/?session_id=abc",
            "https://example.com/gateway/v1/messages/?session_id=abc",
            id="relative_path",
        ),
        # Endpoint that is already a full URL is returned untouched.
        pytest.param(
            "https://example.com/gateway/v1/sse",
            "https://example.com/v1/messages/?session_id=abc",
            "https://example.com/v1/messages/?session_id=abc",
            id="absolute_url",
        ),
        # The endpoint's first segment is the last segment of the base path,
        # so it matches without a trailing slash.
        pytest.param(
            "https://example.com/gw/api",
            "/api/messages/?session_id=abc",
            "https://example.com/gw/api/messages/?session_id=abc",
            id="endpoint_at_path_end",
        ),
        # Endpoint path is just "/": there is no segment to match on.
        pytest.param(
            "https://example.com/gw/v1/sse",
            "/?session_id=abc",
            "https://example.com/?session_id=abc",
            id="empty_path_root_slash",
        ),
    ],
)
def test_resolve_endpoint_url(base_url: str, endpoint: str, expected: str) -> None:
    """Check that each (base_url, endpoint) pair resolves to the expected URL."""
    resolved = _resolve_endpoint_url(base_url, endpoint)
    assert resolved == expected


@pytest.mark.anyio
async def test_sse_client_on_session_created_not_called_when_no_session_id(
server: None, server_url: str, monkeypatch: pytest.MonkeyPatch
Expand Down