
Commit 09606c0

kiqaps and Kludex authored

Enhance Gemini usage tracking to collect comprehensive token data (#1752)

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>

1 parent 45d0ff2 commit 09606c0

File tree

4 files changed: +143 −32 lines changed

pydantic_ai_slim/pydantic_ai/models/gemini.py

40 additions & 4 deletions

```diff
@@ -463,13 +463,12 @@ async def _get_gemini_responses(self) -> AsyncIterator[_GeminiResponse]:
             responses_to_yield = gemini_responses[:-1]
             for r in responses_to_yield[current_gemini_response_index:]:
                 current_gemini_response_index += 1
-                self._usage += _metadata_as_usage(r)
                 yield r
 
         # Now yield the final response, which should be complete
         if gemini_responses:  # pragma: no branch
             r = gemini_responses[-1]
-            self._usage += _metadata_as_usage(r)
+            self._usage = _metadata_as_usage(r)
             yield r
 
     @property
@@ -770,8 +769,17 @@ class _GeminiCandidates(TypedDict):
     safety_ratings: NotRequired[Annotated[list[_GeminiSafetyRating], pydantic.Field(alias='safetyRatings')]]
 
 
+class _GeminiModalityTokenCount(TypedDict):
+    """See <https://ai.google.dev/api/generate-content#modalitytokencount>."""
+
+    modality: Annotated[
+        Literal['MODALITY_UNSPECIFIED', 'TEXT', 'IMAGE', 'VIDEO', 'AUDIO', 'DOCUMENT'], pydantic.Field(alias='modality')
+    ]
+    token_count: Annotated[int, pydantic.Field(alias='tokenCount', default=0)]
+
+
 class _GeminiUsageMetaData(TypedDict, total=False):
-    """See <https://ai.google.dev/api/generate-content#FinishReason>.
+    """See <https://ai.google.dev/api/generate-content#UsageMetadata>.
 
     The docs suggest all fields are required, but some are actually not required, so we assume they are all optional.
     """
@@ -780,6 +788,20 @@ class _GeminiUsageMetaData(TypedDict, total=False):
     candidates_token_count: NotRequired[Annotated[int, pydantic.Field(alias='candidatesTokenCount')]]
     total_token_count: Annotated[int, pydantic.Field(alias='totalTokenCount')]
     cached_content_token_count: NotRequired[Annotated[int, pydantic.Field(alias='cachedContentTokenCount')]]
+    thoughts_token_count: NotRequired[Annotated[int, pydantic.Field(alias='thoughtsTokenCount')]]
+    tool_use_prompt_token_count: NotRequired[Annotated[int, pydantic.Field(alias='toolUsePromptTokenCount')]]
+    prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='promptTokensDetails')]
+    ]
+    cache_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='cacheTokensDetails')]
+    ]
+    candidates_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='candidatesTokensDetails')]
+    ]
+    tool_use_prompt_tokens_details: NotRequired[
+        Annotated[list[_GeminiModalityTokenCount], pydantic.Field(alias='toolUsePromptTokensDetails')]
+    ]
 
 
 def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
@@ -788,7 +810,21 @@ def _metadata_as_usage(response: _GeminiResponse) -> usage.Usage:
         return usage.Usage()  # pragma: no cover
     details: dict[str, int] = {}
     if cached_content_token_count := metadata.get('cached_content_token_count'):
-        details['cached_content_token_count'] = cached_content_token_count  # pragma: no cover
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            metadata_details = cast(list[_GeminiModalityTokenCount], metadata_details)
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
         request_tokens=metadata.get('prompt_token_count', 0),
         response_tokens=metadata.get('candidates_token_count', 0),
```
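For reference, here is a minimal standalone sketch of the detail-flattening logic added to `_metadata_as_usage` above, run against a hand-built metadata dict (the values mirror the updated test snapshots below; this is illustrative, not a real API response):

```python
# Sketch of the new detail-flattening, on an illustrative metadata dict
# whose keys are already validated into snake_case by the TypedDict above.
metadata = {
    'prompt_token_count': 38,
    'candidates_token_count': 28,
    'total_token_count': 427,
    'thoughts_token_count': 361,
    'prompt_tokens_details': [{'modality': 'TEXT', 'token_count': 38}],
}

details: dict[str, int] = {}
if thoughts_token_count := metadata.get('thoughts_token_count'):
    details['thoughts_tokens'] = thoughts_token_count

for key, metadata_details in metadata.items():
    if key.endswith('_details') and metadata_details:
        suffix = key.removesuffix('_details')  # 'prompt_tokens_details' -> 'prompt_tokens'
        for detail in metadata_details:
            # 'TEXT'.lower() + '_prompt_tokens' -> 'text_prompt_tokens'
            details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']

print(details)  # {'thoughts_tokens': 361, 'text_prompt_tokens': 38}
```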

pydantic_ai_slim/pydantic_ai/models/google.py

22 additions & 11 deletions

```diff
@@ -399,7 +399,7 @@ class GeminiStreamedResponse(StreamedResponse):
 
     async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
         async for chunk in self._response:
-            self._usage += _metadata_as_usage(chunk)
+            self._usage = _metadata_as_usage(chunk)
 
             assert chunk.candidates is not None
             candidate = chunk.candidates[0]
@@ -490,17 +490,28 @@ def _metadata_as_usage(response: GenerateContentResponse) -> usage.Usage:
     metadata = response.usage_metadata
     if metadata is None:
         return usage.Usage()  # pragma: no cover
-    # TODO(Marcelo): We exclude the `prompt_tokens_details` and `candidate_token_details` fields because on
-    # `usage.Usage.incr`, it will try to sum non-integer values with integers, which will fail. We should probably
-    # handle this in the `Usage` class.
-    details = metadata.model_dump(
-        exclude={'prompt_tokens_details', 'candidates_tokens_details', 'traffic_type'},
-        exclude_defaults=True,
-    )
+    metadata = metadata.model_dump(exclude_defaults=True)
+
+    details: dict[str, int] = {}
+    if cached_content_token_count := metadata.get('cached_content_token_count'):
+        details['cached_content_tokens'] = cached_content_token_count  # pragma: no cover
+
+    if thoughts_token_count := metadata.get('thoughts_token_count'):
+        details['thoughts_tokens'] = thoughts_token_count
+
+    if tool_use_prompt_token_count := metadata.get('tool_use_prompt_token_count'):
+        details['tool_use_prompt_tokens'] = tool_use_prompt_token_count  # pragma: no cover
+
+    for key, metadata_details in metadata.items():
+        if key.endswith('_details') and metadata_details:
+            suffix = key.removesuffix('_details')
+            for detail in metadata_details:
+                details[f'{detail["modality"].lower()}_{suffix}'] = detail['token_count']
+
     return usage.Usage(
-        request_tokens=details.pop('prompt_token_count', 0),
-        response_tokens=details.pop('candidates_token_count', 0),
-        total_tokens=details.pop('total_token_count', 0),
+        request_tokens=metadata.get('prompt_token_count', 0),
+        response_tokens=metadata.get('candidates_token_count', 0),
+        total_tokens=metadata.get('total_token_count', 0),
         details=details,
     )
 
```
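Note that in both model files the streamed-usage accumulation (`self._usage += ...`) becomes an assignment (`self._usage = ...`). This reads as a fix for Gemini reporting cumulative `usage_metadata` on each stream chunk, so that summing chunk metadata double-counts; that interpretation is consistent with the streaming test snapshots below dropping from (2, 4, 6) to (1, 2, 3). A toy illustration, with hypothetical chunk dicts:

```python
# Hypothetical usage_metadata for a two-chunk stream: each chunk reports
# cumulative counts for the whole response so far, not a per-chunk delta.
chunk_usages = [
    {'prompt_token_count': 1, 'candidates_token_count': 2, 'total_token_count': 3},
    {'prompt_token_count': 1, 'candidates_token_count': 2, 'total_token_count': 3},
]

# Old behaviour: += sums cumulative snapshots and double-counts.
summed = sum(u['total_token_count'] for u in chunk_usages)  # 6 (wrong)

# New behaviour: keep only the latest (final cumulative) snapshot.
final = chunk_usages[-1]['total_token_count']  # 3 (right)
```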

tests/models/test_gemini.py

27 additions & 9 deletions

```diff
@@ -739,12 +739,12 @@ async def test_stream_text(get_gemini_client: GetGeminiClient):
                 'Hello world',
             ]
         )
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
 
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream_text(delta=True, debounce_by=None)]
         assert chunks == snapshot(['Hello ', 'world'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
 
 
 async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
@@ -776,7 +776,7 @@ async def test_stream_invalid_unicode_text(get_gemini_client: GetGeminiClient):
     async with agent.run_stream('Hello') as result:
         chunks = [chunk async for chunk in result.stream(debounce_by=None)]
         assert chunks == snapshot(['abc', 'abc€def', 'abc€def'])
-        assert result.usage() == snapshot(Usage(requests=1, request_tokens=2, response_tokens=4, total_tokens=6))
+        assert result.usage() == snapshot(Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3))
 
 
 async def test_stream_text_no_data(get_gemini_client: GetGeminiClient):
@@ -847,7 +847,7 @@ async def bar(y: str) -> str:
     async with agent.run_stream('Hello') as result:
         response = await result.get_output()
         assert response == snapshot((1, 2))
-        assert result.usage() == snapshot(Usage(requests=2, request_tokens=3, response_tokens=6, total_tokens=9))
+        assert result.usage() == snapshot(Usage(requests=2, request_tokens=2, response_tokens=4, total_tokens=6))
         assert result.all_messages() == snapshot(
             [
                 ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]),
@@ -856,7 +856,7 @@ async def bar(y: str) -> str:
                         ToolCallPart(tool_name='foo', args={'x': 'a'}, tool_call_id=IsStr()),
                         ToolCallPart(tool_name='bar', args={'y': 'b'}, tool_call_id=IsStr()),
                     ],
-                    usage=Usage(request_tokens=2, response_tokens=4, total_tokens=6),
+                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                     model_name='gemini-1.5-flash',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -872,7 +872,7 @@ async def bar(y: str) -> str:
                 ),
                 ModelResponse(
                     parts=[ToolCallPart(tool_name='final_result', args={'response': [1, 2]}, tool_call_id=IsStr())],
-                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3),
+                    usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}),
                     model_name='gemini-1.5-flash',
                     timestamp=IsNow(tz=timezone.utc),
                 ),
@@ -1103,7 +1103,13 @@ async def get_image() -> BinaryContent:
                 ),
                 ToolCallPart(tool_name='get_image', args={}, tool_call_id=IsStr()),
             ],
-            usage=Usage(requests=1, request_tokens=38, response_tokens=28, total_tokens=427, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=38,
+                response_tokens=28,
+                total_tokens=427,
+                details={'thoughts_tokens': 361, 'text_prompt_tokens': 38},
+            ),
             model_name='gemini-2.5-pro-preview-03-25',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
@@ -1127,7 +1133,13 @@ async def get_image() -> BinaryContent:
         ),
         ModelResponse(
             parts=[TextPart(content='The image shows a kiwi fruit, sliced in half.')],
-            usage=Usage(requests=1, request_tokens=360, response_tokens=11, total_tokens=572, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=360,
+                response_tokens=11,
+                total_tokens=572,
+                details={'thoughts_tokens': 201, 'text_prompt_tokens': 102, 'image_prompt_tokens': 258},
+            ),
             model_name='gemini-2.5-pro-preview-03-25',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
@@ -1250,7 +1262,13 @@ async def test_gemini_model_instructions(allow_model_requests: None, gemini_api_
         ),
         ModelResponse(
             parts=[TextPart(content='The capital of France is Paris.\n')],
-            usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
+            usage=Usage(
+                requests=1,
+                request_tokens=13,
+                response_tokens=8,
+                total_tokens=21,
+                details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
+            ),
             model_name='gemini-1.5-flash',
             timestamp=IsDatetime(),
             vendor_details={'finish_reason': 'STOP'},
```

tests/models/test_google.py

54 additions & 8 deletions

```diff
@@ -65,7 +65,15 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
 
     result = await agent.run('Hello!')
     assert result.output == snapshot('Hello there! How can I help you today?\n')
-    assert result.usage() == snapshot(Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=1,
+            request_tokens=7,
+            response_tokens=11,
+            total_tokens=18,
+            details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -82,7 +90,13 @@ async def test_google_model(allow_model_requests: None, google_provider: GoogleP
             ),
             ModelResponse(
                 parts=[TextPart(content='Hello there! How can I help you today?\n')],
-                usage=Usage(requests=1, request_tokens=7, response_tokens=11, total_tokens=18, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=7,
+                    response_tokens=11,
+                    total_tokens=18,
+                    details={'text_prompt_tokens': 7, 'text_candidates_tokens': 11},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -115,7 +129,15 @@ async def temperature(city: str, date: datetime.date) -> str:
 
     result = await agent.run('What was the temperature in London 1st January 2022?', output_type=Response)
     assert result.output == snapshot({'temperature': '30°C', 'date': datetime.date(2022, 1, 1), 'city': 'London'})
-    assert result.usage() == snapshot(Usage(requests=2, request_tokens=224, response_tokens=35, total_tokens=259))
+    assert result.usage() == snapshot(
+        Usage(
+            requests=2,
+            request_tokens=224,
+            response_tokens=35,
+            total_tokens=259,
+            details={'text_prompt_tokens': 224, 'text_candidates_tokens': 35},
+        )
+    )
     assert result.all_messages() == snapshot(
         [
             ModelRequest(
@@ -136,7 +158,13 @@ async def temperature(city: str, date: datetime.date) -> str:
                         tool_name='temperature', args={'date': '2022-01-01', 'city': 'London'}, tool_call_id=IsStr()
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=101, response_tokens=14, total_tokens=115, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=101,
+                    response_tokens=14,
+                    total_tokens=115,
+                    details={'text_prompt_tokens': 101, 'text_candidates_tokens': 14},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -156,7 +184,13 @@ async def temperature(city: str, date: datetime.date) -> str:
                         tool_call_id=IsStr(),
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=123, response_tokens=21, total_tokens=144, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=123,
+                    response_tokens=21,
+                    total_tokens=144,
+                    details={'text_prompt_tokens': 123, 'text_candidates_tokens': 21},
+                ),
                 model_name='gemini-1.5-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -214,7 +248,7 @@ async def get_capital(country: str) -> str:
                     request_tokens=57,
                     response_tokens=15,
                     total_tokens=173,
-                    details={'thoughts_token_count': 101},
+                    details={'thoughts_tokens': 101, 'text_prompt_tokens': 57},
                 ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
@@ -236,7 +270,13 @@ async def get_capital(country: str) -> str:
                         content='I am sorry, I cannot fulfill this request. The country you provided is not supported.'
                     )
                 ],
-                usage=Usage(requests=1, request_tokens=104, response_tokens=18, total_tokens=122, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=104,
+                    response_tokens=18,
+                    total_tokens=122,
+                    details={'text_prompt_tokens': 104},
+                ),
                 model_name='models/gemini-2.5-pro-preview-05-06',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
@@ -493,7 +533,13 @@ def instructions() -> str:
             ),
             ModelResponse(
                 parts=[TextPart(content='The capital of France is Paris.\n')],
-                usage=Usage(requests=1, request_tokens=13, response_tokens=8, total_tokens=21, details={}),
+                usage=Usage(
+                    requests=1,
+                    request_tokens=13,
+                    response_tokens=8,
+                    total_tokens=21,
+                    details={'text_prompt_tokens': 13, 'text_candidates_tokens': 8},
+                ),
                 model_name='gemini-2.0-flash',
                 timestamp=IsDatetime(),
                 vendor_details={'finish_reason': 'STOP'},
```
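For reference, a sketch of how the richer details surface to callers (a configured Gemini API key is assumed; the model string and printed values are illustrative, with detail keys as in the tests above):

```python
from pydantic_ai import Agent

agent = Agent('google-gla:gemini-1.5-flash')
result = agent.run_sync('Hello!')

usage = result.usage()
print(usage.request_tokens, usage.response_tokens, usage.total_tokens)
# Per-modality and thinking-token breakdowns now land in usage.details, e.g.:
# {'text_prompt_tokens': 7, 'text_candidates_tokens': 11}
print(usage.details)
```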
