Feat: Add id and finish_reason to ModelResponse #1761

Closed · 21 commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -17,3 +17,5 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
/docs-site/.wrangler/
/CLAUDE.md
node_modules/
**.idea/
.coverage*
4 changes: 4 additions & 0 deletions docs/agents.md
@@ -155,6 +155,7 @@ async def main():
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
),
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -226,6 +227,7 @@ async def main():
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
),
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -829,6 +831,7 @@ with capture_run_messages() as messages: # (2)!
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
ModelRequest(
parts=[
Expand Down Expand Up @@ -862,6 +865,7 @@ with capture_run_messages() as messages: # (2)!
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
1 change: 1 addition & 0 deletions docs/direct.md
@@ -95,6 +95,7 @@ async def main():
model_name='gpt-4.1-nano',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
"""
```
6 changes: 6 additions & 0 deletions docs/message-history.md
@@ -74,6 +74,7 @@ print(result.all_messages())
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
@@ -159,6 +160,7 @@ async def main():
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
@@ -225,6 +227,7 @@ print(result2.all_messages())
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
ModelRequest(
parts=[
Expand Down Expand Up @@ -254,6 +257,7 @@ print(result2.all_messages())
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
@@ -367,6 +371,7 @@ print(result2.all_messages())
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
ModelRequest(
parts=[
Expand Down Expand Up @@ -396,6 +401,7 @@ print(result2.all_messages())
model_name='gemini-1.5-pro',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
1 change: 1 addition & 0 deletions docs/models/index.md
@@ -105,6 +105,7 @@ print(response.all_messages())
model_name='claude-3-5-sonnet-latest',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
3 changes: 3 additions & 0 deletions docs/tools.md
@@ -106,6 +106,7 @@ print(dice_result.all_messages())
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
ModelRequest(
parts=[
@@ -139,6 +140,7 @@ print(dice_result.all_messages())
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
ModelRequest(
parts=[
@@ -170,6 +172,7 @@ print(dice_result.all_messages())
model_name='gemini-1.5-flash',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
),
]
"""
7 changes: 3 additions & 4 deletions mkdocs.yml
@@ -105,25 +105,24 @@ theme:
custom_dir: docs/.overrides
palette:
- media: "(prefers-color-scheme)"
scheme: default
primary: pink
accent: pink
toggle:
icon: material/lightbulb
icon: material/brightness-auto
name: "Switch to light mode"
- media: "(prefers-color-scheme: light)"
scheme: default
primary: pink
accent: pink
toggle:
icon: material/lightbulb-outline
icon: material/brightness-7
name: "Switch to dark mode"
- media: "(prefers-color-scheme: dark)"
scheme: slate
primary: pink
accent: pink
toggle:
icon: material/lightbulb-auto-outline
icon: material/brightness-4
name: "Switch to system preference"
features:
- search.suggest
3 changes: 3 additions & 0 deletions pydantic_ai_slim/pydantic_ai/agent.py
@@ -585,6 +585,7 @@ async def main():
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
),
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -1854,6 +1855,7 @@ async def main():
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
),
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -1999,6 +2001,7 @@ async def main():
model_name='gpt-4o',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
),
End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/direct.py
@@ -52,6 +52,7 @@ async def main():
model_name='claude-3-5-haiku-latest',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
'''
```
@@ -108,6 +109,7 @@ def model_request_sync(
model_name='claude-3-5-haiku-latest',
timestamp=datetime.datetime(...),
kind='response',
vendor_id=None,
)
'''
```
13 changes: 13 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
@@ -553,6 +553,9 @@ class ModelResponse:
model_name: str | None = None
"""The name of the model that generated the response."""

finish_reasons: list[str] = field(default_factory=list)
"""The reasons why the model finished generating the response, one for each part of the response."""

timestamp: datetime = field(default_factory=_now_utc)
"""The timestamp of the response.

@@ -562,6 +565,16 @@
kind: Literal['response'] = 'response'
"""Message type identifier, this is available on all parts as a discriminator."""

vendor_details: dict[str, Any] | None = field(default=None, repr=False)
"""Additional vendor-specific details in a serializable format.

This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
"""

vendor_id: str | None = None
"""Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""

def otel_events(self) -> list[Event]:
"""Return OpenTelemetry events for the response."""
result: list[Event] = []
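The three new `ModelResponse` fields are ordinary dataclass fields, so they serialize with the rest of the message history. A minimal sketch of reading them after a run — the model string and prompt are illustrative, and at this point in the PR only some providers populate each field:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('openai:gpt-4o')
result = agent.run_sync('What is the capital of France?')

# The last message in the history is the final ModelResponse.
response = result.all_messages()[-1]
assert isinstance(response, ModelResponse)
print(response.vendor_id)       # provider request ID, e.g. 'chatcmpl-...'
print(response.finish_reasons)  # populated by providers that report them (Gemini in this PR)
print(response.vendor_details)  # None unless extra data such as logprobs was captured
```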
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -262,7 +262,7 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
)
)

return ModelResponse(items, usage=_map_usage(response), model_name=response.model)
return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)

async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
peekable_response = _utils.PeekableAsyncStream(response)
3 changes: 2 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -271,7 +271,8 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelRes
response_tokens=response['usage']['outputTokens'],
total_tokens=response['usage']['totalTokens'],
)
return ModelResponse(items, usage=u, model_name=self.model_name)
vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id)

@overload
async def _messages_create(
25 changes: 22 additions & 3 deletions pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -273,9 +273,21 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
'Content field missing from Gemini response', str(response)
)
parts = response['candidates'][0]['content']['parts']
finish_reasons = [
finish_reason
for finish_reason in [response['candidates'][0].get('finish_reason')]
if finish_reason is not None
]
vendor_id = response.get('vendor_id', None)
usage = _metadata_as_usage(response)
usage.requests = 1
return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
return _process_response_from_parts(
parts,
response.get('model_version', self._model_name),
usage,
vendor_id=vendor_id,
finish_reasons=finish_reasons,
)

async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
@@ -597,7 +609,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart


def _process_response_from_parts(
parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
parts: Sequence[_GeminiPartUnion],
model_name: GeminiModelName,
usage: usage.Usage,
vendor_id: str | None,
finish_reasons: list[str],
) -> ModelResponse:
items: list[ModelResponsePart] = []
for part in parts:
@@ -609,7 +625,9 @@ def _process_response_from_parts(
raise UnexpectedModelBehavior(
f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
)
return ModelResponse(parts=items, usage=usage, model_name=model_name)
return ModelResponse(
parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, finish_reasons=finish_reasons
)


class _GeminiFunctionCall(TypedDict):
@@ -721,6 +739,7 @@ class _GeminiResponse(TypedDict):
usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]


class _GeminiCandidates(TypedDict):
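As an aside, the `finish_reasons` list comprehension in `_process_response` above is just a None-filter over a single optional value; a plainer equivalent, assuming the same `response` dict, would be:

```python
# Gemini reports at most one finish reason per candidate; keep it only if present.
finish_reason = response['candidates'][0].get('finish_reason')
finish_reasons = [finish_reason] if finish_reason is not None else []
```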
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -239,7 +239,9 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
if choice.message.tool_calls is not None:
for c in choice.message.tool_calls:
items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
return ModelResponse(
items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
)

async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
4 changes: 3 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -325,7 +325,9 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelRes
tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
parts.append(tool)

return ModelResponse(parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
return ModelResponse(
parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
)

async def _process_streamed_response(
self,
36 changes: 35 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -104,6 +104,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
result in faster responses and fewer tokens used on reasoning in a response.
"""

openai_logprobs: bool
"""Include log probabilities in the response."""

openai_top_logprobs: int
"""Include log probabilities of the top n tokens in the response."""

openai_user: str
"""A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

@@ -287,6 +293,8 @@ async def _completions_create(
frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
user=model_settings.get('openai_user', NOT_GIVEN),
extra_headers=extra_headers,
extra_body=model_settings.get('extra_body'),
@@ -301,12 +309,38 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
choice = response.choices[0]
items: list[ModelResponsePart] = []
vendor_details: dict[str, Any] | None = None

# Add logprobs to vendor_details if available
if choice.logprobs is not None and choice.logprobs.content:
# Convert logprobs to a serializable format
vendor_details = {
'logprobs': [
{
'token': lp.token,
'bytes': lp.bytes,
'logprob': lp.logprob,
'top_logprobs': [
{'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
],
}
for lp in choice.logprobs.content
],
}

if choice.message.content is not None:
items.append(TextPart(choice.message.content))
if choice.message.tool_calls is not None:
for c in choice.message.tool_calls:
items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
return ModelResponse(
items,
usage=_map_usage(response),
model_name=response.model,
timestamp=timestamp,
vendor_details=vendor_details,
vendor_id=response.id,
)

async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
"""Process a streamed response, and prepare a streaming response to return."""
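With the two new settings wired through `_completions_create`, a caller can opt into logprobs and read them back from `vendor_details`. A hedged usage sketch — the model name and prompt are placeholders:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModelSettings

agent = Agent('openai:gpt-4o')
result = agent.run_sync(
    'Say hello',
    model_settings=OpenAIModelSettings(openai_logprobs=True, openai_top_logprobs=3),
)

# vendor_details is only set when the provider returned logprobs.
response = result.all_messages()[-1]
if response.vendor_details is not None:
    for lp in response.vendor_details['logprobs']:
        top = [t['token'] for t in lp['top_logprobs']]
        print(f'{lp["token"]!r}: {lp["logprob"]:.3f}, top alternatives: {top}')
```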
3 changes: 1 addition & 2 deletions pydantic_graph/pydantic_graph/graph.py
@@ -10,7 +10,6 @@

import logfire_api
import typing_extensions
from opentelemetry.trace import Span
from typing_extensions import deprecated
from typing_inspection import typing_objects

@@ -212,7 +211,7 @@ async def iter(
state: StateT = None,
deps: DepsT = None,
persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
span: AbstractContextManager[Span] | None = None,
span: AbstractContextManager[AbstractSpan] | None = None,
infer_name: bool = True,
) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
"""A contextmanager which can be used to iterate over the graph's nodes as they are executed.