diff --git a/src/agents/extensions/models/litellm_model.py b/src/agents/extensions/models/litellm_model.py
index d3b25a19..ffb2c3c1 100644
--- a/src/agents/extensions/models/litellm_model.py
+++ b/src/agents/extensions/models/litellm_model.py
@@ -6,6 +6,7 @@ from typing import Any, Literal, cast, overload

 import litellm.types

+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
 from agents.exceptions import ModelBehaviorError
@@ -107,6 +108,16 @@ async def get_response(
                 input_tokens=response_usage.prompt_tokens,
                 output_tokens=response_usage.completion_tokens,
                 total_tokens=response_usage.total_tokens,
+                input_tokens_details=InputTokensDetails(
+                    cached_tokens=getattr(
+                        response_usage.prompt_tokens_details, "cached_tokens", 0
+                    )
+                ),
+                output_tokens_details=OutputTokensDetails(
+                    reasoning_tokens=getattr(
+                        response_usage.completion_tokens_details, "reasoning_tokens", 0
+                    )
+                ),
             )
             if response.usage
             else Usage()
diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 89619f83..4465ff2f 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -9,6 +9,7 @@
 from openai.types import ChatModel
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from openai.types.responses import Response
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from .. import _debug
 from ..agent_output import AgentOutputSchemaBase
@@ -83,6 +84,18 @@ async def get_response(
                 input_tokens=response.usage.prompt_tokens,
                 output_tokens=response.usage.completion_tokens,
                 total_tokens=response.usage.total_tokens,
+                input_tokens_details=InputTokensDetails(
+                    cached_tokens=getattr(
+                        response.usage.prompt_tokens_details, "cached_tokens", 0
+                    )
+                    or 0,
+                ),
+                output_tokens_details=OutputTokensDetails(
+                    reasoning_tokens=getattr(
+                        response.usage.completion_tokens_details, "reasoning_tokens", 0
+                    )
+                    or 0,
+                ),
             )
             if response.usage
             else Usage()
@@ -252,7 +265,7 @@ async def _fetch_response(
             stream_options=self._non_null_or_not_given(stream_options),
             store=self._non_null_or_not_given(store),
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
-            extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
+            extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             metadata=self._non_null_or_not_given(model_settings.metadata),
diff --git a/src/agents/models/openai_responses.py b/src/agents/models/openai_responses.py
index c1ff85b9..6ec8f8f7 100644
--- a/src/agents/models/openai_responses.py
+++ b/src/agents/models/openai_responses.py
@@ -98,6 +98,8 @@ async def get_response(
                 input_tokens=response.usage.input_tokens,
                 output_tokens=response.usage.output_tokens,
                 total_tokens=response.usage.total_tokens,
+                input_tokens_details=response.usage.input_tokens_details,
+                output_tokens_details=response.usage.output_tokens_details,
             )
             if response.usage
             else Usage()
diff --git a/src/agents/run.py b/src/agents/run.py
index 849da7bf..b196c3bf 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -689,6 +689,8 @@ async def _run_single_turn_streamed(
                             input_tokens=event.response.usage.input_tokens,
                             output_tokens=event.response.usage.output_tokens,
                             total_tokens=event.response.usage.total_tokens,
+                            input_tokens_details=event.response.usage.input_tokens_details,
+                            output_tokens_details=event.response.usage.output_tokens_details,
                         )
                         if event.response.usage
                         else Usage()
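Note that both chat-completions conversions read the detail sub-objects defensively: depending on the provider and model (especially behind LiteLLM), `prompt_tokens_details` / `completion_tokens_details` may be missing or `None`, and a present details object may still carry a `None` field. A minimal sketch of the two cases the `getattr(..., 0)` / `or 0` combination guards against, using `SimpleNamespace` stand-ins rather than real API payloads:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for provider usage payloads; names are illustrative only.
usage_missing_details = SimpleNamespace(prompt_tokens_details=None)
usage_null_field = SimpleNamespace(
    prompt_tokens_details=SimpleNamespace(cached_tokens=None)
)

# getattr(None, ..., 0) covers an absent details object...
assert getattr(usage_missing_details.prompt_tokens_details, "cached_tokens", 0) == 0

# ...while the extra `or 0` (used in openai_chatcompletions.py) also covers a
# details object whose field exists but is None.
assert (getattr(usage_null_field.prompt_tokens_details, "cached_tokens", 0) or 0) == 0
```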
diff --git a/src/agents/usage.py b/src/agents/usage.py
index 23d989b4..843f6293 100644
--- a/src/agents/usage.py
+++ b/src/agents/usage.py
@@ -1,4 +1,6 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
+
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails


 @dataclass
@@ -9,9 +11,18 @@ class Usage:
     input_tokens: int = 0
     """Total input tokens sent, across all requests."""

+    input_tokens_details: InputTokensDetails = field(
+        default_factory=lambda: InputTokensDetails(cached_tokens=0)
+    )
+    """Details about the input tokens, matching responses API usage details."""
+
     output_tokens: int = 0
     """Total output tokens received, across all requests."""

+    output_tokens_details: OutputTokensDetails = field(
+        default_factory=lambda: OutputTokensDetails(reasoning_tokens=0)
+    )
+    """Details about the output tokens, matching responses API usage details."""
+
     total_tokens: int = 0
     """Total tokens sent and received, across all requests."""

@@ -20,3 +31,12 @@ def add(self, other: "Usage") -> None:
         self.input_tokens += other.input_tokens if other.input_tokens else 0
         self.output_tokens += other.output_tokens if other.output_tokens else 0
         self.total_tokens += other.total_tokens if other.total_tokens else 0
+        self.input_tokens_details = InputTokensDetails(
+            cached_tokens=self.input_tokens_details.cached_tokens
+            + other.input_tokens_details.cached_tokens
+        )
+
+        self.output_tokens_details = OutputTokensDetails(
+            reasoning_tokens=self.output_tokens_details.reasoning_tokens
+            + other.output_tokens_details.reasoning_tokens
+        )
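With the new fields in place, `Usage.add` aggregates the detail counters the same way it aggregates the flat ones, building fresh `InputTokensDetails` / `OutputTokensDetails` objects rather than mutating shared instances; the `default_factory` defaults likewise keep separate `Usage` values from sharing a single details object. A quick sketch of the resulting behavior (values are arbitrary):

```python
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

from agents.usage import Usage

# Two turns' worth of usage, e.g. as reported by consecutive model calls.
turn1 = Usage(
    requests=1,
    input_tokens=10,
    input_tokens_details=InputTokensDetails(cached_tokens=3),
    output_tokens=20,
    output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
    total_tokens=30,
)
turn2 = Usage(
    requests=1,
    input_tokens=4,
    input_tokens_details=InputTokensDetails(cached_tokens=1),
    output_tokens=6,
    output_tokens_details=OutputTokensDetails(reasoning_tokens=2),
    total_tokens=10,
)

turn1.add(turn2)

assert turn1.input_tokens_details.cached_tokens == 4  # 3 + 1
assert turn1.output_tokens_details.reasoning_tokens == 7  # 5 + 2
```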
diff --git a/tests/models/test_litellm_chatcompletions_stream.py b/tests/models/test_litellm_chatcompletions_stream.py
index 80bd8ea2..06e46b39 100644
--- a/tests/models/test_litellm_chatcompletions_stream.py
+++ b/tests/models/test_litellm_chatcompletions_stream.py
@@ -8,7 +8,11 @@
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
 )
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import (
+    CompletionTokensDetails,
+    CompletionUsage,
+    PromptTokensDetails,
+)
 from openai.types.responses import (
     Response,
     ResponseFunctionToolCall,
@@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> None:
         model="fake",
         object="chat.completion.chunk",
         choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
-        usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=2),
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=6),
+        ),
     )

     async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
@@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert completed_resp.usage.input_tokens == 7
     assert completed_resp.usage.output_tokens == 5
     assert completed_resp.usage.total_tokens == 12
+    assert completed_resp.usage.input_tokens_details.cached_tokens == 6
+    assert completed_resp.usage.output_tokens_details.reasoning_tokens == 2


 @pytest.mark.allow_call_model_methods
diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py
index f29c2540..a6af3007 100644
--- a/tests/test_extra_headers.py
+++ b/tests/test_extra_headers.py
@@ -1,6 +1,7 @@
 import pytest
 from openai.types.chat.chat_completion import ChatCompletion, Choice
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from agents import ModelSettings, ModelTracing, OpenAIChatCompletionsModel, OpenAIResponsesModel

@@ -17,21 +18,29 @@ class DummyResponses:
         async def create(self, **kwargs):
             nonlocal called_kwargs
             called_kwargs = kwargs
+
             class DummyResponse:
                 id = "dummy"
                 output = []
                 usage = type(
-                    "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
+                    "Usage",
+                    (),
+                    {
+                        "input_tokens": 0,
+                        "output_tokens": 0,
+                        "total_tokens": 0,
+                        "input_tokens_details": InputTokensDetails(cached_tokens=0),
+                        "output_tokens_details": OutputTokensDetails(reasoning_tokens=0),
+                    },
                 )()
+
             return DummyResponse()

     class DummyClient:
         def __init__(self):
             self.responses = DummyResponses()

-
-
-    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
     extra_headers = {"X-Test-Header": "test-value"}
     await model.get_response(
         system_instructions=None,
@@ -47,7 +56,6 @@ def __init__(self):

     assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"

-
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_extra_headers_passed_to_openai_client():
@@ -76,7 +84,7 @@ def __init__(self):
         self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
         self.base_url = "https://api.openai.com"

-    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
     extra_headers = {"X-Test-Header": "test-value"}
     await model.get_response(
         system_instructions=None,
diff --git a/tests/test_openai_chatcompletions.py b/tests/test_openai_chatcompletions.py
index ba3ec68d..ba4605d0 100644
--- a/tests/test_openai_chatcompletions.py
+++ b/tests/test_openai_chatcompletions.py
@@ -13,7 +13,10 @@
     ChatCompletionMessageToolCall,
     Function,
 )
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import (
+    CompletionUsage,
+    PromptTokensDetails,
+)
 from openai.types.responses import (
     Response,
     ResponseFunctionToolCall,
@@ -51,7 +54,13 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
         model="fake",
         object="chat.completion",
         choices=[choice],
-        usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            # completion_tokens_details left blank to test default
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=3),
+        ),
     )

     async def patched_fetch_response(self, *args, **kwargs):
@@ -81,6 +90,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.usage.input_tokens == 7
     assert resp.usage.output_tokens == 5
     assert resp.usage.total_tokens == 12
+    assert resp.usage.input_tokens_details.cached_tokens == 3
+    assert resp.usage.output_tokens_details.reasoning_tokens == 0
     assert resp.response_id is None
@@ -127,6 +138,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert resp.usage.requests == 0
     assert resp.usage.input_tokens == 0
     assert resp.usage.output_tokens == 0
+    assert resp.usage.input_tokens_details.cached_tokens == 0
+    assert resp.usage.output_tokens_details.reasoning_tokens == 0
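The tests above also exercise the naming translation the models perform: the Chat Completions API reports `prompt_tokens` / `completion_tokens` with `PromptTokensDetails` / `CompletionTokensDetails`, while `Usage` stores everything in the Responses-API shape (`input_tokens` / `output_tokens` with `InputTokensDetails` / `OutputTokensDetails`). A sketch of that mapping in isolation (values are arbitrary):

```python
from openai.types.completion_usage import CompletionUsage, PromptTokensDetails
from openai.types.responses.response_usage import InputTokensDetails

# Chat Completions shape: prompt/completion naming.
cc_usage = CompletionUsage(
    prompt_tokens=7,
    completion_tokens=5,
    total_tokens=12,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=3),
)

# Responses-API shape stored on `Usage`: input/output naming, same data,
# normalized with the same defensive fallback as the model classes.
input_details = InputTokensDetails(
    cached_tokens=getattr(cc_usage.prompt_tokens_details, "cached_tokens", 0) or 0
)
assert input_details.cached_tokens == 3
```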
diff --git a/tests/test_openai_chatcompletions_stream.py b/tests/test_openai_chatcompletions_stream.py
index b82f2430..5c8bb9e3 100644
--- a/tests/test_openai_chatcompletions_stream.py
+++ b/tests/test_openai_chatcompletions_stream.py
@@ -8,7 +8,11 @@
     ChoiceDeltaToolCall,
     ChoiceDeltaToolCallFunction,
 )
-from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_usage import (
+    CompletionTokensDetails,
+    CompletionUsage,
+    PromptTokensDetails,
+)
 from openai.types.responses import (
     Response,
     ResponseFunctionToolCall,
@@ -46,7 +50,13 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> None:
         model="fake",
         object="chat.completion.chunk",
         choices=[Choice(index=0, delta=ChoiceDelta(content="llo"))],
-        usage=CompletionUsage(completion_tokens=5, prompt_tokens=7, total_tokens=12),
+        usage=CompletionUsage(
+            completion_tokens=5,
+            prompt_tokens=7,
+            total_tokens=12,
+            prompt_tokens_details=PromptTokensDetails(cached_tokens=2),
+            completion_tokens_details=CompletionTokensDetails(reasoning_tokens=3),
+        ),
     )

     async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
@@ -112,6 +122,8 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert completed_resp.usage.input_tokens == 7
     assert completed_resp.usage.output_tokens == 5
     assert completed_resp.usage.total_tokens == 12
+    assert completed_resp.usage.input_tokens_details.cached_tokens == 2
+    assert completed_resp.usage.output_tokens_details.reasoning_tokens == 3


 @pytest.mark.allow_call_model_methods
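As in the LiteLLM variant, this stream test attaches usage only to the final chunk, which mirrors the real Chat Completions API: usage arrives on the last chunk, and only when the caller opts in. A sketch of the request side for reference (not part of this diff; the tests fake the stream rather than calling the API, and the model name is illustrative):

```python
from openai import AsyncOpenAI


async def stream_and_collect_usage() -> None:
    client = AsyncOpenAI()
    stream = await client.chat.completions.create(
        model="gpt-4o",  # illustrative model choice
        messages=[{"role": "user", "content": "Say hello"}],
        stream=True,
        stream_options={"include_usage": True},  # ask for usage on the final chunk
    )
    async for chunk in stream:
        # Every chunk except the last has usage=None.
        if chunk.usage is not None:
            print(chunk.usage.prompt_tokens_details)
            print(chunk.usage.completion_tokens_details)
```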
diff --git a/tests/test_responses_tracing.py b/tests/test_responses_tracing.py
index 0bc97a95..dfac74bb 100644
--- a/tests/test_responses_tracing.py
+++ b/tests/test_responses_tracing.py
@@ -1,7 +1,10 @@
+from typing import Optional
+
 import pytest
 from inline_snapshot import snapshot
 from openai import AsyncOpenAI
 from openai.types.responses import ResponseCompletedEvent
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails

 from agents import ModelSettings, ModelTracing, OpenAIResponsesModel, trace
 from agents.tracing.span_data import ResponseSpanData
@@ -16,10 +19,25 @@ def is_disabled(self):


 class DummyUsage:
-    def __init__(self, input_tokens=1, output_tokens=1, total_tokens=2):
+    def __init__(
+        self,
+        input_tokens: int = 1,
+        input_tokens_details: Optional[InputTokensDetails] = None,
+        output_tokens: int = 1,
+        output_tokens_details: Optional[OutputTokensDetails] = None,
+        total_tokens: int = 2,
+    ):
         self.input_tokens = input_tokens
         self.output_tokens = output_tokens
         self.total_tokens = total_tokens
+        self.input_tokens_details = (
+            input_tokens_details if input_tokens_details else InputTokensDetails(cached_tokens=0)
+        )
+        self.output_tokens_details = (
+            output_tokens_details
+            if output_tokens_details
+            else OutputTokensDetails(reasoning_tokens=0)
+        )


 class DummyResponse:
diff --git a/tests/test_usage.py b/tests/test_usage.py
new file mode 100644
index 00000000..405f99dd
--- /dev/null
+++ b/tests/test_usage.py
@@ -0,0 +1,52 @@
+from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
+
+from agents.usage import Usage
+
+
+def test_usage_add_aggregates_all_fields():
+    u1 = Usage(
+        requests=1,
+        input_tokens=10,
+        input_tokens_details=InputTokensDetails(cached_tokens=3),
+        output_tokens=20,
+        output_tokens_details=OutputTokensDetails(reasoning_tokens=5),
+        total_tokens=30,
+    )
+    u2 = Usage(
+        requests=2,
+        input_tokens=7,
+        input_tokens_details=InputTokensDetails(cached_tokens=4),
+        output_tokens=8,
+        output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
+        total_tokens=15,
+    )
+
+    u1.add(u2)
+
+    assert u1.requests == 3
+    assert u1.input_tokens == 17
+    assert u1.output_tokens == 28
+    assert u1.total_tokens == 45
+    assert u1.input_tokens_details.cached_tokens == 7
+    assert u1.output_tokens_details.reasoning_tokens == 11
+
+
+def test_usage_add_aggregates_with_none_values():
+    u1 = Usage()
+    u2 = Usage(
+        requests=2,
+        input_tokens=7,
+        input_tokens_details=InputTokensDetails(cached_tokens=4),
+        output_tokens=8,
+        output_tokens_details=OutputTokensDetails(reasoning_tokens=6),
+        total_tokens=15,
+    )
+
+    u1.add(u2)
+
+    assert u1.requests == 2
+    assert u1.input_tokens == 7
+    assert u1.output_tokens == 8
+    assert u1.total_tokens == 15
+    assert u1.input_tokens_details.cached_tokens == 4
+    assert u1.output_tokens_details.reasoning_tokens == 6
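End to end, the aggregated details surface wherever a `Usage` object does. A short sketch of reading them after a run, assuming the SDK's `result.context_wrapper.usage` accessor and a configured API key:

```python
import asyncio

from agents import Agent, Runner


async def main() -> None:
    agent = Agent(name="Assistant", instructions="Reply concisely.")
    result = await Runner.run(agent, "What is 2 + 2?")

    usage = result.context_wrapper.usage  # aggregated across all model calls in the run
    print("cached input tokens:", usage.input_tokens_details.cached_tokens)
    print("reasoning output tokens:", usage.output_tokens_details.reasoning_tokens)


asyncio.run(main())
```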