Skip to content

Commit 43d066c

Browse files
speedstorm1copybara-github
authored andcommitted
chore: Update user ID extraction logic in ADK templates.
PiperOrigin-RevId: 908285375
1 parent 68f053e commit 43d066c

2 files changed

Lines changed: 54 additions & 5 deletions

File tree

tests/unit/vertex_adk/test_agent_engine_templates_adk.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,43 @@ async def test_streaming_agent_run_with_events(
506506
events.append(event)
507507
assert len(events) == 1
508508

509+
@pytest.mark.asyncio
510+
async def test_streaming_agent_run_with_events_extracts_user_id_from_headers(
511+
self,
512+
default_instrumentor_builder_mock: mock.Mock,
513+
get_project_id_mock: mock.Mock,
514+
):
515+
app = agent_engines.AdkApp(agent=_TEST_AGENT)
516+
app.set_up()
517+
app._tmpl_attrs["in_memory_runner"] = _MockRunner()
518+
519+
request_json = json.dumps(
520+
{
521+
"message": {
522+
"parts": [{"text": "Hello"}],
523+
"role": "user",
524+
},
525+
}
526+
)
527+
headers = {
528+
"X-Goog-Authenticated-User-Email": "test_user_from_header@google.com"
529+
}
530+
531+
with mock.patch.object(app, "_init_session") as mock_init_session:
532+
mock_session = mock.Mock()
533+
mock_session.id = "mock_session_id"
534+
mock_init_session.return_value = mock_session
535+
536+
async for _ in app.streaming_agent_run_with_events(
537+
request_json=request_json, headers=headers
538+
):
539+
pass
540+
541+
mock_init_session.assert_called_once()
542+
# Assert that the extracted request object correctly pulled the user_id from headers
543+
request_obj = mock_init_session.call_args.kwargs["request"]
544+
assert request_obj.user_id == "test_user_from_header@google.com"
545+
509546
@pytest.mark.asyncio
510547
@mock.patch.dict(
511548
os.environ,

vertexai/agent_engines/templates/adk.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,14 @@ def __init__(self, **kwargs):
195195
)
196196
# The authorizations of the user, keyed by authorization ID.
197197

198-
self.user_id: Optional[str] = kwargs.get("user_id") or kwargs.get(
199-
"userId", _DEFAULT_USER_ID
200-
)
198+
extracted_user_id = kwargs.get("user_id") or kwargs.get("userId")
199+
if not extracted_user_id:
200+
headers = kwargs.get("headers", {})
201+
extracted_user_id = headers.get(
202+
"X-Goog-Authenticated-User-Email"
203+
) or headers.get("X-Endpoint-API-UserInfo")
204+
205+
self.user_id: Optional[str] = extracted_user_id or _DEFAULT_USER_ID
201206
# The user ID.
202207

203208
self.session_id: Optional[str] = kwargs.get("session_id") or kwargs.get(
@@ -1195,7 +1200,9 @@ def stream_query(
11951200
):
11961201
yield _utils.dump_event_for_json(event)
11971202

1198-
async def streaming_agent_run_with_events(self, request_json: str):
1203+
async def streaming_agent_run_with_events(
1204+
self, request_json: str, headers: Optional[Dict[str, str]] = None
1205+
):
11991206
"""Streams responses asynchronously from the ADK application.
12001207
12011208
In general, you should use `async_stream_query` instead, as it has a
@@ -1206,13 +1213,18 @@ async def streaming_agent_run_with_events(self, request_json: str):
12061213
Args:
12071214
request_json (str):
12081215
Required. The request to stream responses for.
1216+
headers (Dict[str, str]):
1217+
Optional. The HTTP request headers containing IAM metadata.
12091218
"""
12101219

12111220
import json
12121221
from google.genai import types
12131222
from google.genai.errors import ClientError
12141223

1215-
request = _StreamRunRequest(**json.loads(request_json))
1224+
payload = json.loads(request_json)
1225+
if headers:
1226+
payload["headers"] = headers
1227+
request = _StreamRunRequest(**payload)
12161228
if not any(
12171229
self._tmpl_attrs.get(service)
12181230
for service in (

0 commit comments

Comments
 (0)