Skip to content

Commit 9fba13a

Browse files
committed
fix oauth auth flow user agent forwarding
1 parent 3d7b311 commit 9fba13a

2 files changed

Lines changed: 108 additions & 2 deletions

File tree

src/mcp/client/auth/oauth2.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@ def __init__(
273273
self._validate_resource_url_callback = validate_resource_url
274274
self._initialized = False
275275

276+
def _copy_user_agent_header(self, request: httpx.Request, source_request: httpx.Request) -> httpx.Request:
277+
user_agent = source_request.headers.get("User-Agent")
278+
if user_agent and "User-Agent" not in request.headers:
279+
request.headers["User-Agent"] = user_agent
280+
return request
281+
276282
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
277283
"""Handle protected resource metadata discovery response.
278284
@@ -515,6 +521,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
515521
if not self.context.is_token_valid() and self.context.can_refresh_token():
516522
# Try to refresh token
517523
refresh_request = await self._refresh_token() # pragma: no cover
524+
self._copy_user_agent_header(refresh_request, request) # pragma: no cover
518525
refresh_response = yield refresh_request # pragma: no cover
519526

520527
if not await self._handle_refresh_response(refresh_response): # pragma: no cover
@@ -539,6 +546,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
539546

540547
for url in prm_discovery_urls: # pragma: no branch
541548
discovery_request = create_oauth_metadata_request(url)
549+
self._copy_user_agent_header(discovery_request, request)
542550

543551
discovery_response = yield discovery_request # sending request
544552

@@ -565,6 +573,7 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
565573
# Step 2: Discover OAuth Authorization Server Metadata (OASM) (with fallback for legacy servers)
566574
for url in asm_discovery_urls: # pragma: no branch
567575
oauth_metadata_request = create_oauth_metadata_request(url)
576+
self._copy_user_agent_header(oauth_metadata_request, request)
568577
oauth_metadata_response = yield oauth_metadata_request
569578

570579
ok, asm = await handle_auth_metadata_response(oauth_metadata_response)
@@ -604,13 +613,16 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
604613
self.context.client_metadata,
605614
self.context.get_authorization_base_url(self.context.server_url),
606615
)
616+
self._copy_user_agent_header(registration_request, request)
607617
registration_response = yield registration_request
608618
client_information = await handle_registration_response(registration_response)
609619
self.context.client_info = client_information
610620
await self.context.storage.set_client_info(client_information)
611621

612622
# Step 5: Perform authorization and complete token exchange
613-
token_response = yield await self._perform_authorization()
623+
token_request = await self._perform_authorization()
624+
self._copy_user_agent_header(token_request, request)
625+
token_response = yield token_request
614626
await self._handle_token_response(token_response)
615627
except Exception: # pragma: no cover
616628
logger.exception("OAuth flow error")
@@ -635,7 +647,9 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
635647
)
636648

637649
# Step 2b: Perform (re-)authorization and token exchange
638-
token_response = yield await self._perform_authorization()
650+
token_request = await self._perform_authorization()
651+
self._copy_user_agent_header(token_request, request)
652+
token_response = yield token_request
639653
await self._handle_token_response(token_response)
640654
except Exception: # pragma: no cover
641655
logger.exception("OAuth flow error")

tests/client/test_auth.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,98 @@ async def callback_handler() -> tuple[str, str | None]:
15891589
except StopAsyncIteration:
15901590
pass
15911591

1592+
@pytest.mark.anyio
1593+
async def test_oauth_flow_forwards_user_agent_to_generated_auth_requests(
1594+
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage
1595+
):
1596+
"""OAuth discovery, registration, and token requests should preserve the transport User-Agent."""
1597+
1598+
async def redirect_handler(url: str) -> None:
1599+
pass # pragma: no cover
1600+
1601+
async def callback_handler() -> tuple[str, str | None]:
1602+
return "test_auth_code", "test_state" # pragma: no cover
1603+
1604+
provider = OAuthClientProvider(
1605+
server_url="https://api.example.com/v1/mcp",
1606+
client_metadata=client_metadata,
1607+
storage=mock_storage,
1608+
redirect_handler=redirect_handler,
1609+
callback_handler=callback_handler,
1610+
)
1611+
provider._initialized = True
1612+
provider._perform_authorization_code_grant = mock.AsyncMock(
1613+
return_value=("test_auth_code", "test_code_verifier")
1614+
)
1615+
1616+
test_request = httpx.Request(
1617+
"POST",
1618+
"https://api.example.com/v1/mcp",
1619+
headers={"User-Agent": "custom-mcp-client/1.0"},
1620+
)
1621+
auth_flow = provider.async_auth_flow(test_request)
1622+
1623+
try:
1624+
first_request = await auth_flow.__anext__()
1625+
assert first_request.headers["User-Agent"] == "custom-mcp-client/1.0"
1626+
1627+
response = httpx.Response(
1628+
401,
1629+
headers={
1630+
"WWW-Authenticate": (
1631+
'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
1632+
)
1633+
},
1634+
request=test_request,
1635+
)
1636+
1637+
discovery_request = await auth_flow.asend(response)
1638+
assert discovery_request.headers["User-Agent"] == "custom-mcp-client/1.0"
1639+
1640+
discovery_response = httpx.Response(
1641+
200,
1642+
content=(
1643+
b'{"resource": "https://api.example.com/v1/mcp", '
1644+
b'"authorization_servers": ["https://auth.example.com"]}'
1645+
),
1646+
request=discovery_request,
1647+
)
1648+
1649+
oauth_metadata_request = await auth_flow.asend(discovery_response)
1650+
assert oauth_metadata_request.headers["User-Agent"] == "custom-mcp-client/1.0"
1651+
1652+
oauth_metadata_response = httpx.Response(
1653+
200,
1654+
content=(
1655+
b'{"issuer": "https://auth.example.com", '
1656+
b'"authorization_endpoint": "https://auth.example.com/authorize", '
1657+
b'"token_endpoint": "https://auth.example.com/token", '
1658+
b'"registration_endpoint": "https://auth.example.com/register"}'
1659+
),
1660+
request=oauth_metadata_request,
1661+
)
1662+
1663+
registration_request = await auth_flow.asend(oauth_metadata_response)
1664+
assert registration_request.headers["User-Agent"] == "custom-mcp-client/1.0"
1665+
1666+
registration_response = httpx.Response(
1667+
201,
1668+
content=(
1669+
b'{"client_id": "test_client", '
1670+
b'"client_secret": "test_secret", '
1671+
b'"redirect_uris": ["http://localhost:3030/callback"], '
1672+
b'"token_endpoint_auth_method": "client_secret_post", '
1673+
b'"grant_types": ["authorization_code"], '
1674+
b'"response_types": ["code"]}'
1675+
),
1676+
request=registration_request,
1677+
)
1678+
1679+
token_request = await auth_flow.asend(registration_response)
1680+
assert token_request.headers["User-Agent"] == "custom-mcp-client/1.0"
1681+
finally:
1682+
await auth_flow.aclose()
1683+
15921684
@pytest.mark.anyio
15931685
async def test_legacy_server_with_different_prm_and_root_urls(
15941686
self, client_metadata: OAuthClientMetadata, mock_storage: MockTokenStorage

0 commit comments

Comments
 (0)