54 | 54 | ResponseReasoningSummaryTextDeltaEvent, |
55 | 55 | ResponseFunctionCallArgumentsDeltaEvent, |
56 | 56 | ) |
| 57 | +from openai.types.responses.response_prompt_param import ResponsePromptParam |
57 | 58 |
58 | 59 | # AgentEx SDK imports |
59 | 60 | from agentex.lib import adk |
@@ -481,12 +482,23 @@ async def get_response( |
481 | 482 | output_schema: Optional[AgentOutputSchemaBase], |
482 | 483 | handoffs: list[Handoff], |
483 | 484 | tracing: ModelTracing, # noqa: ARG002 |
484 | | - **kwargs, # noqa: ARG002 |
| 485 | + *, |
| 486 | + previous_response_id: Optional[str] = None, |
| 487 | + conversation_id: Optional[str] = None, |
| 488 | + prompt: Optional[ResponsePromptParam] = None, |
485 | 489 | ) -> ModelResponse: |
486 | 490 | """Get a non-streaming response from the model with streaming to Redis. |
487 | 491 |
488 | 492 | This method is used by Temporal activities and needs to return a complete |
489 | 493 | response, but we stream the response to Redis while generating it. |
| 494 | +
| 495 | + ``previous_response_id``, ``conversation_id``, and ``prompt`` are all |
| 496 | + Responses API server-state parameters threaded through by the OpenAI |
| 497 | + Agents SDK. Each is forwarded to ``responses.create`` only when |
| 498 | + explicitly set — defaults resolve to ``NOT_GIVEN`` and are omitted from |
| 499 | + the request body. Not all OpenAI-compatible backends recognize these |
| 500 | + fields, so callers on alternative providers see no wire-level change |
| 501 | + unless they opt in. |
490 | 502 | """ |
491 | 503 |
492 | 504 | task_id = streaming_task_id.get() |
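The "forwarded only when explicitly set" behavior described in the docstring comes from the model's `_non_null_or_not_given` helper used further down. A minimal sketch of that coercion, assuming the standard `NOT_GIVEN` sentinel from the OpenAI Python SDK (the real helper lives on the model class):

```python
from typing import Optional, TypeVar, Union

from openai import NOT_GIVEN, NotGiven

T = TypeVar("T")

def _non_null_or_not_given(value: Optional[T]) -> Union[T, NotGiven]:
    # None means the caller did not opt in: coerce to NOT_GIVEN so the
    # field is omitted from the serialized request body entirely.
    return value if value is not None else NOT_GIVEN
```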
@@ -575,6 +587,11 @@ async def get_response( |
575 | 587 | if model_settings.top_logprobs is not None: |
576 | 588 | extra_args["top_logprobs"] = model_settings.top_logprobs |
577 | 589 |
| 590 | + # Opt-in prompt_cache_key: forwarded only when the caller supplies it via |
| 591 | + # model_settings.extra_args["prompt_cache_key"]. Not all OpenAI-compatible |
| 592 | + # endpoints recognize this parameter, so we don't auto-inject a default. |
| 593 | + prompt_cache_key = extra_args.pop("prompt_cache_key", NOT_GIVEN) |
| 594 | + |
578 | 595 | # Create the response stream using Responses API |
579 | 596 | logger.debug(f"[TemporalStreamingModel] Creating response stream with Responses API") |
580 | 597 | stream = await self.client.responses.create( # type: ignore[call-overload] |
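Callers opt in to `prompt_cache_key` from the Agents SDK side via `ModelSettings.extra_args`; the `pop` above lifts it out so it travels as a typed kwarg instead of leaking through the `**extra_args` expansion. A hedged caller-side sketch (the cache key is a placeholder value):

```python
from agents import Agent, ModelSettings

# Hypothetical wiring: prompt_cache_key rides along in extra_args and is
# popped back out before the **extra_args expansion above.
agent = Agent(
    name="support",
    model_settings=ModelSettings(
        extra_args={"prompt_cache_key": "tenant-42:support-v3"},
    ),
)
```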
@@ -605,12 +622,20 @@ async def get_response( |
605 | 622 | extra_headers=model_settings.extra_headers, |
606 | 623 | extra_query=model_settings.extra_query, |
607 | 624 | extra_body=model_settings.extra_body, |
| 625 | + prompt_cache_key=prompt_cache_key, |
| 626 | + previous_response_id=self._non_null_or_not_given(previous_response_id), |
| 627 | + # SDK abstract names this conversation_id; the Responses API |
| 628 | + # endpoint kwarg is `conversation` (accepts a str id directly). |
| 629 | + conversation=self._non_null_or_not_given(conversation_id), |
| 630 | + prompt=self._non_null_or_not_given(prompt), |
608 | 631 | # Any additional parameters from extra_args |
609 | 632 | **extra_args, |
610 | 633 | ) |
611 | 634 |
612 | 635 | # Process the stream of events from Responses API |
613 | 636 | output_items = [] |
| 637 | + captured_usage = None |
| 638 | + captured_response_id = None |
614 | 639 | current_text = "" |
615 | 640 | streaming_context = None |
616 | 641 | reasoning_context = None |
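The `conversation_id` → `conversation` rename noted in the comment is visible when hitting the endpoint directly. A hedged sketch against the raw client, assuming `responses.create` accepts a conversation id string for `conversation` as that comment states (model and ids are placeholders):

```python
from openai import AsyncOpenAI

async def resume_conversation() -> None:
    client = AsyncOpenAI()
    # `conversation` (not conversation_id) is the endpoint kwarg.
    resp = await client.responses.create(
        model="gpt-4.1",
        input="What did we decide earlier?",
        conversation="conv_example",
    )
    print(resp.output_text)
```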
@@ -821,10 +846,13 @@ async def get_response( |
821 | 846 | # Response completed |
822 | 847 | logger.debug(f"[TemporalStreamingModel] Response completed") |
823 | 848 | response = getattr(event, 'response', None) |
824 | | - if response and hasattr(response, 'output'): |
825 | | - # Use the final output from the response |
826 | | - output_items = response.output |
827 | | - logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response") |
| 849 | + if response is not None: |
| 850 | + if hasattr(response, 'output'): |
| 851 | + # Use the final output from the response |
| 852 | + output_items = response.output |
| 853 | + logger.debug(f"[TemporalStreamingModel] Found {len(output_items)} output items in final response") |
| 854 | + captured_usage = getattr(response, 'usage', None) |
| 855 | + captured_response_id = getattr(response, 'id', None) |
828 | 856 |
829 | 857 | # End of event processing loop - close any open contexts |
830 | 858 | if reasoning_context: |
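`captured_usage` and `captured_response_id` are only populated by the terminal `response.completed` event. A stripped-down sketch of that capture pattern, assuming the stream yields typed events with a `.response` payload as in the OpenAI Python SDK:

```python
from typing import Any, Optional, Tuple

async def capture_completion(stream: Any) -> Tuple[Optional[Any], Optional[str]]:
    """Return (usage, response_id) from the terminal response.completed event."""
    captured_usage: Optional[Any] = None
    captured_response_id: Optional[str] = None
    async for event in stream:
        if event.type == "response.completed":
            final = event.response  # openai.types.responses.Response
            captured_usage = getattr(final, "usage", None)
            captured_response_id = getattr(final, "id", None)
    return captured_usage, captured_response_id
```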
@@ -863,14 +891,33 @@ async def get_response( |
863 | 891 | ) |
864 | 892 | response_output.append(message) |
865 | 893 |
866 | | - # Create usage object |
867 | | - usage = Usage( |
868 | | - input_tokens=0, |
869 | | - output_tokens=0, |
870 | | - total_tokens=0, |
871 | | - input_tokens_details=InputTokensDetails(cached_tokens=0), |
872 | | - output_tokens_details=OutputTokensDetails(reasoning_tokens=len(''.join(reasoning_contents)) // 4), # Approximate |
873 | | - ) |
| 894 | + # Use the real usage from the streaming Response if available; |
| 895 | + # fall back to zeros only when the stream ended without a |
| 896 | + # ResponseCompletedEvent (error paths). |
| 897 | + if captured_usage is not None: |
| 898 | + usage = Usage( |
| 899 | + input_tokens=captured_usage.input_tokens, |
| 900 | + output_tokens=captured_usage.output_tokens, |
| 901 | + total_tokens=captured_usage.total_tokens, |
| 902 | + input_tokens_details=InputTokensDetails( |
| 903 | + cached_tokens=getattr( |
| 904 | + captured_usage.input_tokens_details, "cached_tokens", 0 |
| 905 | + ), |
| 906 | + ), |
| 907 | + output_tokens_details=OutputTokensDetails( |
| 908 | + reasoning_tokens=getattr( |
| 909 | + captured_usage.output_tokens_details, "reasoning_tokens", 0 |
| 910 | + ), |
| 911 | + ), |
| 912 | + ) |
| 913 | + else: |
| 914 | + usage = Usage( |
| 915 | + input_tokens=0, |
| 916 | + output_tokens=0, |
| 917 | + total_tokens=0, |
| 918 | + input_tokens_details=InputTokensDetails(cached_tokens=0), |
| 919 | + output_tokens_details=OutputTokensDetails(reasoning_tokens=0), |
| 920 | + ) |
874 | 921 |
875 | 922 | # Serialize response output items for span tracing |
876 | 923 | new_items = [] |
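For reference, the happy-path branch composes the types like this (token counts are invented; import paths are assumptions based on the openai and openai-agents packages):

```python
from agents.usage import Usage
from openai.types.responses.response_usage import (
    InputTokensDetails,
    OutputTokensDetails,
)

# Invented counts mirroring the captured_usage branch above.
usage = Usage(
    input_tokens=1432,
    output_tokens=210,
    total_tokens=1642,
    input_tokens_details=InputTokensDetails(cached_tokens=1024),
    output_tokens_details=OutputTokensDetails(reasoning_tokens=96),
)
assert usage.total_tokens == usage.input_tokens + usage.output_tokens
```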
@@ -919,21 +966,34 @@ async def get_response( |
919 | 966 | output_data = { |
920 | 967 | "new_items": new_items, |
921 | 968 | "final_output": final_output, |
| 969 | + "usage": { |
| 970 | + "input_tokens": usage.input_tokens, |
| 971 | + "output_tokens": usage.output_tokens, |
| 972 | + "total_tokens": usage.total_tokens, |
| 973 | + "cached_input_tokens": usage.input_tokens_details.cached_tokens, |
| 974 | + "reasoning_tokens": usage.output_tokens_details.reasoning_tokens, |
| 975 | + }, |
922 | 976 | } |
923 | 977 | # Include tool calls if any were in the input |
924 | 978 | if tool_calls: |
925 | 979 | output_data["tool_calls"] = tool_calls |
926 | 980 | # Include tool outputs if any were processed |
927 | 981 | if tool_outputs: |
928 | 982 | output_data["tool_outputs"] = tool_outputs |
929 | | - |
| 983 | + |
930 | 984 | span.output = output_data |
931 | 985 |
932 | | - # Return the response |
| 986 | + # Return the response. response_id is the server-issued id from |
| 987 | + # ResponseCompletedEvent.response.id, or None when the stream ended |
| 988 | + # without a completed event (error path) — matching the documented |
| 989 | + # `str | None` contract on `ModelResponse.response_id`. Returning |
| 990 | + # None lets callers use it safely as `previous_response_id` for |
| 991 | + # multi-turn chaining; a fabricated UUID would 400 against any real |
| 992 | + # server. |
933 | 993 | return ModelResponse( |
934 | 994 | output=response_output, |
935 | 995 | usage=usage, |
936 | | - response_id=f"resp_{uuid.uuid4().hex[:8]}", |
| 996 | + response_id=captured_response_id, |
937 | 997 | ) |
938 | 998 |
939 | 999 | except Exception as e: |
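With a genuine `response_id` returned, multi-turn chaining works against any real Responses API server. A hedged end-to-end sketch using the raw client (the model name is a placeholder):

```python
from openai import AsyncOpenAI

async def two_turns() -> None:
    client = AsyncOpenAI()
    first = await client.responses.create(model="gpt-4.1", input="Hello!")
    # first.id is the server-issued resp_... id. A fabricated UUID here
    # would draw a 400; None would simply start an unchained turn.
    second = await client.responses.create(
        model="gpt-4.1",
        input="And a follow-up question.",
        previous_response_id=first.id,
    )
    print(second.output_text)
```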