
Commit ba5d64b

xsmoreinis and Stas Moreinis authored

feat(openai_agents): expose real usage, response_id, plumb previous_response_id, opt-in prompt_cache_key for stateful responses and prompt caching (#335)

Co-authored-by: Stas Moreinis <stas.moreinis@scale.com>

1 parent ed6fd5e · commit ba5d64b

2 files changed: 488 additions & 71 deletions

File tree

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py (76 additions & 16 deletions)
@@ -54,6 +54,7 @@
     ResponseReasoningSummaryTextDeltaEvent,
     ResponseFunctionCallArgumentsDeltaEvent,
 )
+from openai.types.responses.response_prompt_param import ResponsePromptParam

 # AgentEx SDK imports
 from agentex.lib import adk
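
The newly imported `ResponsePromptParam` is the typed dict for the Responses API's reusable-prompt parameter, which the new `prompt` argument below carries through. As a hedged illustration of what a caller eventually passes (the prompt id, version, and variables here are made up):

    from openai.types.responses.response_prompt_param import ResponsePromptParam

    # Illustrative values only: a reusable prompt stored server-side,
    # pinned to a version, with template variables filled per request.
    prompt: ResponsePromptParam = {
        "id": "pmpt_abc123",
        "version": "2",
        "variables": {"customer_name": "Ada"},
    }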
@@ -481,12 +482,23 @@ async def get_response(
         output_schema: Optional[AgentOutputSchemaBase],
         handoffs: list[Handoff],
         tracing: ModelTracing,  # noqa: ARG002
-        **kwargs,  # noqa: ARG002
+        *,
+        previous_response_id: Optional[str] = None,
+        conversation_id: Optional[str] = None,
+        prompt: Optional[ResponsePromptParam] = None,
     ) -> ModelResponse:
         """Get a non-streaming response from the model with streaming to Redis.

         This method is used by Temporal activities and needs to return a complete
         response, but we stream the response to Redis while generating it.
+
+        ``previous_response_id``, ``conversation_id``, and ``prompt`` are all
+        Responses API server-state parameters threaded through by the OpenAI
+        Agents SDK. Each is forwarded to ``responses.create`` only when
+        explicitly set — defaults resolve to ``NOT_GIVEN`` and are omitted from
+        the request body. Not all OpenAI-compatible backends recognize these
+        fields, so callers on alternative providers see no wire-level change
+        unless they opt in.
         """

         task_id = streaming_task_id.get()
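
The `_non_null_or_not_given` helper called further down is not part of this diff. A minimal sketch of the conventional implementation, mirroring the pattern used in the upstream OpenAI Agents SDK and shown here as an assumption rather than the project's actual code:

    from typing import Any, Union

    from openai import NOT_GIVEN, NotGiven

    def _non_null_or_not_given(value: Any) -> Union[Any, NotGiven]:
        # Forward a value only when the caller actually set it. The OpenAI
        # client strips NOT_GIVEN fields from the request body, so unset
        # parameters never appear on the wire.
        return value if value is not None else NOT_GIVEN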
@@ -575,6 +587,11 @@ async def get_response(
             if model_settings.top_logprobs is not None:
                 extra_args["top_logprobs"] = model_settings.top_logprobs

+            # Opt-in prompt_cache_key: forwarded only when the caller supplies it via
+            # model_settings.extra_args["prompt_cache_key"]. Not all OpenAI-compatible
+            # endpoints recognize this parameter, so we don't auto-inject a default.
+            prompt_cache_key = extra_args.pop("prompt_cache_key", NOT_GIVEN)
+
             # Create the response stream using Responses API
             logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API")
             stream = await self.client.responses.create(  # type: ignore[call-overload]
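
Because `prompt_cache_key` is popped out of `extra_args` rather than accepted as a first-class parameter, opting in happens entirely on the caller side. A hedged sketch at the Agents SDK level, assuming the standard `ModelSettings.extra_args` field (agent name, instructions, and cache key are illustrative):

    from agents import Agent, ModelSettings

    # A stable per-conversation key raises the odds of a prompt-cache hit
    # on backends that support the parameter; backends that don't recognize
    # it may reject or ignore it, which is why the model class never
    # injects a default on its own.
    agent = Agent(
        name="support_bot",
        instructions="You answer billing questions.",
        model_settings=ModelSettings(
            extra_args={"prompt_cache_key": "support-bot:user-1234"},
        ),
    )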
@@ -605,12 +622,20 @@
                 extra_headers=model_settings.extra_headers,
                 extra_query=model_settings.extra_query,
                 extra_body=model_settings.extra_body,
+                prompt_cache_key=prompt_cache_key,
+                previous_response_id=self._non_null_or_not_given(previous_response_id),
+                # SDK abstract names this conversation_id; the Responses API
+                # endpoint kwarg is `conversation` (accepts a str id directly).
+                conversation=self._non_null_or_not_given(conversation_id),
+                prompt=self._non_null_or_not_given(prompt),
                 # Any additional parameters from extra_args
                 **extra_args,
             )

             # Process the stream of events from Responses API
             output_items = []
+            captured_usage = None
+            captured_response_id = None
             current_text = ""
             streaming_context = None
             reasoning_context = None
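
Why `NOT_GIVEN` rather than `None` matters here: the OpenAI client omits `NOT_GIVEN` fields from the serialized request entirely, whereas an explicit `None` can be sent as a JSON null. A self-contained sketch of the same call pattern outside Temporal (model name and input are illustrative, and the `conversation` kwarg assumes a recent openai client):

    import asyncio

    from openai import NOT_GIVEN, AsyncOpenAI

    async def main() -> None:
        client = AsyncOpenAI()
        # Neither server-state field reaches the wire: both resolve to
        # NOT_GIVEN, so older OpenAI-compatible backends never see them.
        resp = await client.responses.create(
            model="gpt-4.1-mini",
            input="Hello",
            previous_response_id=NOT_GIVEN,
            conversation=NOT_GIVEN,
        )
        print(resp.id)

    asyncio.run(main())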
@@ -821,10 +846,13 @@ async def get_response(
                     # Response completed
                     logger.debug(f"[TemporalStreamingModel] Response completed")
                     response = getattr(event, 'response', None)
-                    if response and hasattr(response, 'output'):
-                        # Use the final output from the response
-                        output_items = response.output
-                        logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response")
+                    if response is not None:
+                        if hasattr(response, 'output'):
+                            # Use the final output from the response
+                            output_items = response.output
+                            logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response")
+                        captured_usage = getattr(response, 'usage', None)
+                        captured_response_id = getattr(response, 'id', None)

             # End of event processing loop - close any open contexts
             if reasoning_context:
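
For orientation, the terminal event matched above carries the complete final `Response` object. A small hedged sketch of the same capture in isolation, using the public event types from the openai package (function name is mine, not the project's):

    from openai.types.responses import ResponseCompletedEvent, ResponseStreamEvent

    def capture_final(event: ResponseStreamEvent) -> tuple[object | None, str | None]:
        # response.completed wraps the full final Response, so the
        # server-computed usage and the server-issued id both live on
        # event.response; every other event type yields nothing.
        if isinstance(event, ResponseCompletedEvent):
            return event.response.usage, event.response.id
        return None, None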
@@ -863,14 +891,33 @@ async def get_response(
                 )
                 response_output.append(message)

-            # Create usage object
-            usage = Usage(
-                input_tokens=0,
-                output_tokens=0,
-                total_tokens=0,
-                input_tokens_details=InputTokensDetails(cached_tokens=0),
-                output_tokens_details=OutputTokensDetails(reasoning_tokens=len(''.join(reasoning_contents)) // 4),  # Approximate
-            )
+            # Use the real usage from the streaming Response if available;
+            # fall back to zeros only when the stream ended without a
+            # ResponseCompletedEvent (error paths).
+            if captured_usage is not None:
+                usage = Usage(
+                    input_tokens=captured_usage.input_tokens,
+                    output_tokens=captured_usage.output_tokens,
+                    total_tokens=captured_usage.total_tokens,
+                    input_tokens_details=InputTokensDetails(
+                        cached_tokens=getattr(
+                            captured_usage.input_tokens_details, "cached_tokens", 0
+                        ),
+                    ),
+                    output_tokens_details=OutputTokensDetails(
+                        reasoning_tokens=getattr(
+                            captured_usage.output_tokens_details, "reasoning_tokens", 0
+                        ),
+                    ),
+                )
+            else:
+                usage = Usage(
+                    input_tokens=0,
+                    output_tokens=0,
+                    total_tokens=0,
+                    input_tokens_details=InputTokensDetails(cached_tokens=0),
+                    output_tokens_details=OutputTokensDetails(reasoning_tokens=0),
+                )

             # Serialize response output items for span tracing
             new_items = []
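
Downstream, this is the change that makes token accounting visible to callers: each `ModelResponse` on a run result now carries the server's real counters instead of zeros. A hedged sketch at the Agents SDK level (agent and input are illustrative):

    from agents import Agent, Runner

    agent = Agent(name="assistant", instructions="Be terse.")
    result = Runner.run_sync(agent, "Summarize our refund policy.")

    # One ModelResponse per model call made during the run.
    for raw in result.raw_responses:
        u = raw.usage
        print(
            u.input_tokens,
            u.output_tokens,
            u.total_tokens,
            u.input_tokens_details.cached_tokens,
            u.output_tokens_details.reasoning_tokens,
        )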
@@ -919,21 +966,34 @@ async def get_response(
                 output_data = {
                     "new_items": new_items,
                     "final_output": final_output,
+                    "usage": {
+                        "input_tokens": usage.input_tokens,
+                        "output_tokens": usage.output_tokens,
+                        "total_tokens": usage.total_tokens,
+                        "cached_input_tokens": usage.input_tokens_details.cached_tokens,
+                        "reasoning_tokens": usage.output_tokens_details.reasoning_tokens,
+                    },
                 }
                 # Include tool calls if any were in the input
                 if tool_calls:
                     output_data["tool_calls"] = tool_calls
                 # Include tool outputs if any were processed
                 if tool_outputs:
                     output_data["tool_outputs"] = tool_outputs
-
+
                 span.output = output_data

-            # Return the response
+            # Return the response. response_id is the server-issued id from
+            # ResponseCompletedEvent.response.id, or None when the stream ended
+            # without a completed event (error path) — matching the documented
+            # `str | None` contract on `ModelResponse.response_id`. Returning
+            # None lets callers use it safely as `previous_response_id` for
+            # multi-turn chaining; a fabricated UUID would 400 against any real
+            # server.
             return ModelResponse(
                 output=response_output,
                 usage=usage,
-                response_id=f"resp_{uuid.uuid4().hex[:8]}",
+                response_id=captured_response_id,
             )

         except Exception as e:
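
End to end, the real `response_id` is what makes server-side multi-turn chaining usable. A hedged sketch of the intended pattern (inputs are illustrative, and passing `previous_response_id` to `Runner.run_sync` assumes a recent Agents SDK):

    from agents import Agent, Runner

    agent = Agent(name="assistant", instructions="Answer briefly.")

    first = Runner.run_sync(agent, "My order number is 4137.")
    resp_id = first.raw_responses[-1].response_id  # None on error paths, so check it

    if resp_id is not None:
        # The server replays earlier turns from its own state, so only
        # the new user input needs to be sent.
        second = Runner.run_sync(
            agent,
            "What was my order number?",
            previous_response_id=resp_id,
        )
        print(second.final_output)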
