Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
Content,
Message,
ResponseStream,
UsageDetails,
)

else:
Expand Down Expand Up @@ -2152,6 +2153,8 @@ def get_response(
if not stream:

Comment thread
eavanvalkenburg marked this conversation as resolved.
async def _get_response() -> ChatResponse[Any]:
from ._types import add_usage_details

nonlocal mutable_options
nonlocal filtered_kwargs
errors_in_a_row: int = 0
Expand All @@ -2160,6 +2163,7 @@ async def _get_response() -> ChatResponse[Any]:
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse[Any] | None = None
aggregated_usage: UsageDetails | None = None

loop_enabled = self.function_invocation_configuration.get("enabled", True)
max_iterations = self.function_invocation_configuration.get("max_iterations", DEFAULT_MAX_ITERATIONS)
Expand Down Expand Up @@ -2191,6 +2195,7 @@ async def _get_response() -> ChatResponse[Any]:
client_kwargs=filtered_kwargs,
),
)
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)

if response.conversation_id is not None:
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
Expand All @@ -2207,6 +2212,7 @@ async def _get_response() -> ChatResponse[Any]:
execute_function_calls=execute_function_calls,
)
if result.get("action") == "return":
response.usage_details = aggregated_usage
return response
total_function_calls += result.get("function_call_count", 0)
if result.get("action") == "stop":
Expand Down Expand Up @@ -2262,6 +2268,8 @@ async def _get_response() -> ChatResponse[Any]:
client_kwargs=filtered_kwargs,
),
)
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
response.usage_details = aggregated_usage
if fcc_messages:
for msg in reversed(fcc_messages):
response.messages.insert(0, msg)
Expand Down
92 changes: 92 additions & 0 deletions python/packages/core/tests/core/test_observability.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ChatResponseUpdate,
Content,
Message,
RawAgent,
ResponseStream,
SupportsAgentRun,
UsageDetails,
Expand Down Expand Up @@ -2781,3 +2782,94 @@ def mock_get_meter(*args, **kwargs):
meter = get_meter(name="test", attributes={"key": "val"})
assert meter is not None
assert call_count == 2


# region Agent token usage aggregation


# Test fixture tool: the *city* argument is deliberately ignored so every call
# returns the same deterministic string, keeping token/span assertions stable.
@tool(name="get_weather", description="Get weather for a city", approval_mode="never_require")
def _get_weather(city: str) -> str:
    """Get weather for a city."""
    return "Sunny, 72°F"


@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True)
async def test_agent_invoke_span_aggregates_usage_across_tool_calls(span_exporter: InMemorySpanExporter):
    """The invoke_agent span should sum token usage from all chat completions in the function invocation loop."""
    from tests.core.conftest import MockBaseChatClient

    class _InstrumentedAgent(AgentTelemetryLayer, RawAgent):
        pass

    # First round-trip: the model requests a tool call; second: the final answer.
    tool_call_turn = ChatResponse(
        messages=Message(
            role="assistant",
            contents=[
                Content.from_function_call(call_id="call_1", name="get_weather", arguments='{"city": "Seattle"}')
            ],
        ),
        usage_details=UsageDetails(input_token_count=2239, output_token_count=192),
    )
    final_turn = ChatResponse(
        messages=Message(role="assistant", text="The weather in Seattle is sunny."),
        usage_details=UsageDetails(input_token_count=2569, output_token_count=99),
    )

    client = MockBaseChatClient()
    client.run_responses = [tool_call_turn, final_turn]

    agent = _InstrumentedAgent(client=client, name="test_agent", id="test_agent_id")

    span_exporter.clear()
    await agent.run(
        messages="What is the weather in Seattle?",
        options={"tools": [_get_weather], "tool_choice": "auto"},
    )

    spans = span_exporter.get_finished_spans()

    def _with_operation(operation):
        # Filter finished spans down to those tagged with the given operation name.
        return [s for s in spans if s.attributes.get(OtelAttr.OPERATION.value) == operation]

    invoke_spans = _with_operation(OtelAttr.AGENT_INVOKE_OPERATION)
    assert len(invoke_spans) == 1
    agent_span = invoke_spans[0]

    chat_spans = _with_operation(OtelAttr.CHAT_COMPLETION_OPERATION)
    assert len(chat_spans) == 2
    first_chat, second_chat = chat_spans

    # Individual chat spans retain their own usage
    assert first_chat.attributes.get(OtelAttr.INPUT_TOKENS) == 2239
    assert first_chat.attributes.get(OtelAttr.OUTPUT_TOKENS) == 192
    assert second_chat.attributes.get(OtelAttr.INPUT_TOKENS) == 2569
    assert second_chat.attributes.get(OtelAttr.OUTPUT_TOKENS) == 99

    # The invoke_agent span must report the aggregate across all LLM round-trips
    assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 2239 + 2569
    assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 192 + 99


@pytest.mark.parametrize("enable_sensitive_data", [False], indirect=True)
async def test_agent_invoke_span_usage_single_call(span_exporter: InMemorySpanExporter):
    """When only one chat completion occurs, the invoke_agent span usage equals that single call."""
    from tests.core.conftest import MockBaseChatClient

    class _InstrumentedAgent(AgentTelemetryLayer, RawAgent):
        pass

    # A single model turn with no tool calls: no aggregation loop is exercised.
    only_turn = ChatResponse(
        messages=Message(role="assistant", text="Hello!"),
        usage_details=UsageDetails(input_token_count=100, output_token_count=50),
    )

    client = MockBaseChatClient()
    client.run_responses = [only_turn]

    agent = _InstrumentedAgent(client=client, name="test_agent", id="test_agent_id")

    span_exporter.clear()
    await agent.run(messages="Hi")

    invoke_spans = [
        span
        for span in span_exporter.get_finished_spans()
        if span.attributes.get(OtelAttr.OPERATION.value) == OtelAttr.AGENT_INVOKE_OPERATION
    ]
    assert len(invoke_spans) == 1
    agent_span = invoke_spans[0]

    # The lone chat completion's usage is passed through unchanged.
    assert agent_span.attributes.get(OtelAttr.INPUT_TOKENS) == 100
    assert agent_span.attributes.get(OtelAttr.OUTPUT_TOKENS) == 50
Comment thread
eavanvalkenburg marked this conversation as resolved.
Loading