diff --git a/.gitignore b/.gitignore
index bba784959..acb015449 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,3 +17,5 @@ examples/pydantic_ai_examples/.chat_app_messages.sqlite
/docs-site/.wrangler/
/CLAUDE.md
node_modules/
+**.idea/
+.coverage*
diff --git a/docs/agents.md b/docs/agents.md
index 13b0641f4..4b6b19635 100644
--- a/docs/agents.md
+++ b/docs/agents.md
@@ -155,6 +155,7 @@ async def main():
                model_name='gpt-4o',
                timestamp=datetime.datetime(...),
                kind='response',
+                vendor_id=None,
            )
        ),
        End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -226,6 +227,7 @@ async def main():
                model_name='gpt-4o',
                timestamp=datetime.datetime(...),
                kind='response',
+                vendor_id=None,
            )
        ),
        End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -829,6 +831,7 @@ with capture_run_messages() as messages:  # (2)!
            model_name='gpt-4o',
            timestamp=datetime.datetime(...),
            kind='response',
+            vendor_id=None,
        ),
        ModelRequest(
            parts=[
@@ -862,6 +865,7 @@ with capture_run_messages() as messages:  # (2)!
            model_name='gpt-4o',
            timestamp=datetime.datetime(...),
            kind='response',
+            vendor_id=None,
        ),
    ]
    """
diff --git a/docs/direct.md b/docs/direct.md
index 6c9dc2688..e4c7f80fc 100644
--- a/docs/direct.md
+++ b/docs/direct.md
@@ -95,6 +95,7 @@ async def main():
    model_name='gpt-4.1-nano',
    timestamp=datetime.datetime(...),
    kind='response',
+    vendor_id=None,
)
"""
```
diff --git a/docs/message-history.md b/docs/message-history.md
index f87ddb595..5b00e5cba 100644
--- a/docs/message-history.md
+++ b/docs/message-history.md
@@ -74,6 +74,7 @@ print(result.all_messages())
        model_name='gpt-4o',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
]
"""
@@ -159,6 +160,7 @@ async def main():
        model_name='gpt-4o',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
]
"""
@@ -225,6 +227,7 @@ print(result2.all_messages())
        model_name='gpt-4o',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
    ModelRequest(
        parts=[
@@ -254,6 +257,7 @@ print(result2.all_messages())
        model_name='gpt-4o',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
]
"""
@@ -367,6 +371,7 @@ print(result2.all_messages())
        model_name='gpt-4o',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
    ModelRequest(
        parts=[
@@ -396,6 +401,7 @@ print(result2.all_messages())
        model_name='gemini-1.5-pro',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
]
"""
diff --git a/docs/models/index.md b/docs/models/index.md
index 424e20cac..3e4b985e5 100644
--- a/docs/models/index.md
+++ b/docs/models/index.md
@@ -105,6 +105,7 @@ print(response.all_messages())
        model_name='claude-3-5-sonnet-latest',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
]
"""
diff --git a/docs/tools.md b/docs/tools.md
index 9b45e86de..5fe60b119 100644
--- a/docs/tools.md
+++ b/docs/tools.md
@@ -106,6 +106,7 @@ print(dice_result.all_messages())
        model_name='gemini-1.5-flash',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
    ModelRequest(
        parts=[
@@ -139,6 +140,7 @@ print(dice_result.all_messages())
        model_name='gemini-1.5-flash',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
    ModelRequest(
        parts=[
@@ -170,6 +172,7 @@ print(dice_result.all_messages())
        model_name='gemini-1.5-flash',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    ),
]
"""
diff --git a/mkdocs.yml b/mkdocs.yml
index 410b13ce4..067c6e208 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -105,25 +105,24 @@ theme:
  custom_dir: docs/.overrides
  palette:
    - media: "(prefers-color-scheme)"
-      scheme: default
      primary: pink
      accent: pink
      toggle:
-        icon: material/lightbulb
+        icon: material/brightness-auto
        name: "Switch to light mode"
    - media: "(prefers-color-scheme: light)"
      scheme: default
      primary: pink
      accent: pink
      toggle:
-        icon: material/lightbulb-outline
+        icon: material/brightness-7
        name: "Switch to dark mode"
    - media: "(prefers-color-scheme: dark)"
      scheme: slate
      primary: pink
      accent: pink
      toggle:
-        icon: material/lightbulb-auto-outline
+        icon: material/brightness-4
        name: "Switch to system preference"
  features:
    - search.suggest
diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index f93f951c8..d202a53f0 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -585,6 +585,7 @@ async def main():
                        model_name='gpt-4o',
                        timestamp=datetime.datetime(...),
                        kind='response',
+                        vendor_id=None,
                    )
                ),
                End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -1854,6 +1855,7 @@ async def main():
                        model_name='gpt-4o',
                        timestamp=datetime.datetime(...),
                        kind='response',
+                        vendor_id=None,
                    )
                ),
                End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
@@ -1999,6 +2001,7 @@ async def main():
                        model_name='gpt-4o',
                        timestamp=datetime.datetime(...),
                        kind='response',
+                        vendor_id=None,
                    )
                ),
                End(data=FinalResult(output='Paris', tool_name=None, tool_call_id=None)),
diff --git a/pydantic_ai_slim/pydantic_ai/direct.py b/pydantic_ai_slim/pydantic_ai/direct.py
index 71f02e720..004bdba33 100644
--- a/pydantic_ai_slim/pydantic_ai/direct.py
+++ b/pydantic_ai_slim/pydantic_ai/direct.py
@@ -52,6 +52,7 @@ async def main():
        model_name='claude-3-5-haiku-latest',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    )
    '''
    ```
@@ -108,6 +109,7 @@ def model_request_sync(
        model_name='claude-3-5-haiku-latest',
        timestamp=datetime.datetime(...),
        kind='response',
+        vendor_id=None,
    )
    '''
    ```
diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py
index dc88eef0c..de5ff292f 100644
--- a/pydantic_ai_slim/pydantic_ai/messages.py
+++ b/pydantic_ai_slim/pydantic_ai/messages.py
@@ -553,6 +553,9 @@ class ModelResponse:
    model_name: str | None = None
    """The name of the model that generated the response."""

+    finish_reasons: list[str] = field(default_factory=list)
+    """The reasons why the model finished generating the response, one for each part of the response."""
+
    timestamp: datetime = field(default_factory=_now_utc)
    """The timestamp of the response.

@@ -562,6 +565,16 @@ class ModelResponse:
    kind: Literal['response'] = 'response'
    """Message type identifier, this is available on all parts as a discriminator."""

+    vendor_details: dict[str, Any] | None = field(default=None, repr=False)
+    """Additional vendor-specific details in a serializable format.
+
+    This allows storing selected vendor-specific data that isn't mapped to standard ModelResponse fields.
+    For OpenAI models, this may include 'logprobs', 'finish_reason', etc.
+    """
+
+    vendor_id: str | None = None
+    """Vendor ID as specified by the model provider. This can be used to track the specific request to the model."""
+
    def otel_events(self) -> list[Event]:
        """Return OpenTelemetry events for the response."""
        result: list[Event] = []
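Note (not part of the diff): a minimal sketch of how the new `ModelResponse` fields read back after a run — the model name is a placeholder, and `finish_reasons`/`vendor_details` are only populated where a provider mapping exists:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse

agent = Agent('openai:gpt-4o')  # placeholder model name
result = agent.run_sync('What is the capital of France?')

for message in result.all_messages():
    if isinstance(message, ModelResponse):
        print(message.vendor_id)       # provider request/message ID, e.g. 'chatcmpl-...'; None if not supplied
        print(message.finish_reasons)  # e.g. ['STOP'] where a provider mapping populates it, else []
        print(message.vendor_details)  # e.g. {'logprobs': [...]} for OpenAI when logprobs are requested
```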
diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
index 4152e635e..68d3f1ab6 100644
--- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py
+++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py
@@ -262,7 +262,7 @@ def _process_response(self, response: AnthropicMessage) -> ModelResponse:
                )
            )

-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model)
+        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, vendor_id=response.id)

    async def _process_streamed_response(self, response: AsyncStream[RawMessageStreamEvent]) -> StreamedResponse:
        peekable_response = _utils.PeekableAsyncStream(response)
diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
index 43dea39f2..1b86f0291 100644
--- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py
+++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py
@@ -271,7 +271,8 @@ async def _process_response(self, response: ConverseResponseTypeDef) -> ModelResponse:
            response_tokens=response['usage']['outputTokens'],
            total_tokens=response['usage']['totalTokens'],
        )
-        return ModelResponse(items, usage=u, model_name=self.model_name)
+        vendor_id = response.get('ResponseMetadata', {}).get('RequestId', None)
+        return ModelResponse(items, usage=u, model_name=self.model_name, vendor_id=vendor_id)

    @overload
    async def _messages_create(
diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py
index 4390bc7d6..2b00e21ca 100644
--- a/pydantic_ai_slim/pydantic_ai/models/gemini.py
+++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py
@@ -273,9 +273,21 @@ def _process_response(self, response: _GeminiResponse) -> ModelResponse:
                'Content field missing from Gemini response', str(response)
            )
        parts = response['candidates'][0]['content']['parts']
+        finish_reasons = [
+            finish_reason
+            for finish_reason in [response['candidates'][0].get('finish_reason')]
+            if finish_reason is not None
+        ]
+        vendor_id = response.get('vendor_id', None)
        usage = _metadata_as_usage(response)
        usage.requests = 1
-        return _process_response_from_parts(parts, response.get('model_version', self._model_name), usage)
+        return _process_response_from_parts(
+            parts,
+            response.get('model_version', self._model_name),
+            usage,
+            vendor_id=vendor_id,
+            finish_reasons=finish_reasons,
+        )

    async def _process_streamed_response(self, http_response: HTTPResponse) -> StreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
@@ -597,7 +609,11 @@ def _function_call_part_from_call(tool: ToolCallPart) -> _GeminiFunctionCallPart:


def _process_response_from_parts(
-    parts: Sequence[_GeminiPartUnion], model_name: GeminiModelName, usage: usage.Usage
+    parts: Sequence[_GeminiPartUnion],
+    model_name: GeminiModelName,
+    usage: usage.Usage,
+    vendor_id: str | None,
+    finish_reasons: list[str],
) -> ModelResponse:
    items: list[ModelResponsePart] = []
    for part in parts:
@@ -609,7 +625,9 @@ def _process_response_from_parts(
            raise UnexpectedModelBehavior(
                f'Unsupported response from Gemini, expected all parts to be function calls or text, got: {part!r}'
            )
-    return ModelResponse(parts=items, usage=usage, model_name=model_name)
+    return ModelResponse(
+        parts=items, usage=usage, model_name=model_name, vendor_id=vendor_id, finish_reasons=finish_reasons
+    )


class _GeminiFunctionCall(TypedDict):
@@ -721,6 +739,7 @@ class _GeminiResponse(TypedDict):
    usage_metadata: NotRequired[Annotated[_GeminiUsageMetaData, pydantic.Field(alias='usageMetadata')]]
    prompt_feedback: NotRequired[Annotated[_GeminiPromptFeedback, pydantic.Field(alias='promptFeedback')]]
    model_version: NotRequired[Annotated[str, pydantic.Field(alias='modelVersion')]]
+    vendor_id: NotRequired[Annotated[str, pydantic.Field(alias='responseId')]]


class _GeminiCandidates(TypedDict):
diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py
index a7effb7d3..47b0f693d 100644
--- a/pydantic_ai_slim/pydantic_ai/models/groq.py
+++ b/pydantic_ai_slim/pydantic_ai/models/groq.py
@@ -239,7 +239,9 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(tool_name=c.function.name, args=c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+        )

    async def _process_streamed_response(self, response: AsyncStream[chat.ChatCompletionChunk]) -> GroqStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py
index 11e8fff48..ef76996af 100644
--- a/pydantic_ai_slim/pydantic_ai/models/mistral.py
+++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py
@@ -325,7 +325,9 @@ def _process_response(self, response: MistralChatCompletionResponse) -> ModelResponse:
                tool = self._map_mistral_to_pydantic_tool_call(tool_call=tool_call)
                parts.append(tool)

-        return ModelResponse(parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            parts, usage=_map_usage(response), model_name=response.model, timestamp=timestamp, vendor_id=response.id
+        )

    async def _process_streamed_response(
        self,
diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py
index 6e999fdef..14156ab33 100644
--- a/pydantic_ai_slim/pydantic_ai/models/openai.py
+++ b/pydantic_ai_slim/pydantic_ai/models/openai.py
@@ -104,6 +104,12 @@ class OpenAIModelSettings(ModelSettings, total=False):
    result in faster responses and fewer tokens used on reasoning in a response.
    """

+    openai_logprobs: bool
+    """Include log probabilities in the response."""
+
+    openai_top_logprobs: int
+    """Include log probabilities of the top n tokens in the response."""
+
    openai_user: str
    """A unique identifier representing the end-user, which can help OpenAI monitor and detect abuse.

@@ -287,6 +293,8 @@ async def _completions_create(
            frequency_penalty=model_settings.get('frequency_penalty', NOT_GIVEN),
            logit_bias=model_settings.get('logit_bias', NOT_GIVEN),
            reasoning_effort=model_settings.get('openai_reasoning_effort', NOT_GIVEN),
+            logprobs=model_settings.get('openai_logprobs', NOT_GIVEN),
+            top_logprobs=model_settings.get('openai_top_logprobs', NOT_GIVEN),
            user=model_settings.get('openai_user', NOT_GIVEN),
            extra_headers=extra_headers,
            extra_body=model_settings.get('extra_body'),
@@ -301,12 +309,38 @@ def _process_response(self, response: chat.ChatCompletion) -> ModelResponse:
        timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc)
        choice = response.choices[0]
        items: list[ModelResponsePart] = []
+        vendor_details: dict[str, Any] | None = None
+
+        # Add logprobs to vendor_details if available
+        if choice.logprobs is not None and choice.logprobs.content:
+            # Convert logprobs to a serializable format
+            vendor_details = {
+                'logprobs': [
+                    {
+                        'token': lp.token,
+                        'bytes': lp.bytes,
+                        'logprob': lp.logprob,
+                        'top_logprobs': [
+                            {'token': tlp.token, 'bytes': tlp.bytes, 'logprob': tlp.logprob} for tlp in lp.top_logprobs
+                        ],
+                    }
+                    for lp in choice.logprobs.content
+                ],
+            }
+
        if choice.message.content is not None:
            items.append(TextPart(choice.message.content))
        if choice.message.tool_calls is not None:
            for c in choice.message.tool_calls:
                items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id))
-        return ModelResponse(items, usage=_map_usage(response), model_name=response.model, timestamp=timestamp)
+        return ModelResponse(
+            items,
+            usage=_map_usage(response),
+            model_name=response.model,
+            timestamp=timestamp,
+            vendor_details=vendor_details,
+            vendor_id=response.id,
+        )

    async def _process_streamed_response(self, response: AsyncStream[ChatCompletionChunk]) -> OpenAIStreamedResponse:
        """Process a streamed response, and prepare a streaming response to return."""
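Note (not part of the diff): a rough usage sketch for the new `openai_logprobs`/`openai_top_logprobs` settings, assuming `OPENAI_API_KEY` is configured; the model name and prompt are illustrative:

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelResponse
from pydantic_ai.models.openai import OpenAIModel, OpenAIModelSettings

agent = Agent(OpenAIModel('gpt-4o'))
result = agent.run_sync(
    'Say hello',
    model_settings=OpenAIModelSettings(openai_logprobs=True, openai_top_logprobs=3),
)

response = result.all_messages()[-1]
if isinstance(response, ModelResponse) and response.vendor_details is not None:
    # Shape mirrors what _process_response builds above: one entry per content token.
    for lp in response.vendor_details['logprobs']:
        print(lp['token'], lp['logprob'], lp['top_logprobs'])
```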
diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py
index bec8cfc77..2eaf52f58 100644
--- a/pydantic_graph/pydantic_graph/graph.py
+++ b/pydantic_graph/pydantic_graph/graph.py
@@ -10,7 +10,6 @@

import logfire_api
import typing_extensions
-from opentelemetry.trace import Span
from typing_extensions import deprecated
from typing_inspection import typing_objects

@@ -212,7 +211,7 @@ async def iter(
        state: StateT = None,
        deps: DepsT = None,
        persistence: BaseStatePersistence[StateT, RunEndT] | None = None,
-        span: AbstractContextManager[Span] | None = None,
+        span: AbstractContextManager[AbstractSpan] | None = None,
        infer_name: bool = True,
    ) -> AsyncIterator[GraphRun[StateT, DepsT, RunEndT]]:
        """A contextmanager which can be used to iterate over the graph's nodes as they are executed.
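Note (not part of the diff): because the new fields are ordinary dataclass fields, they also appear in serialized message history (see the updated `tests/test_agent.py` snapshot below). A sketch using the documented `ModelMessagesTypeAdapter` round-trip; the model name is a placeholder:

```python
from pydantic_core import to_jsonable_python

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessagesTypeAdapter

agent = Agent('openai:gpt-4o')  # placeholder model name
result = agent.run_sync('Tell me a joke.')

# The dumped dicts now carry 'vendor_id' and 'vendor_details' keys (None when nothing was captured),
# and they survive validation back into message objects like any other ModelResponse field.
payload = to_jsonable_python(result.all_messages())
restored = ModelMessagesTypeAdapter.validate_python(payload)
```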
diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py
index 16965ae85..9318cfbe1 100644
--- a/tests/models/test_anthropic.py
+++ b/tests/models/test_anthropic.py
@@ -184,6 +184,7 @@ async def test_sync_request_text_response(allow_model_requests: None):
                ),
                model_name='claude-3-5-haiku-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
@@ -197,6 +198,7 @@ async def test_sync_request_text_response(allow_model_requests: None):
                ),
                model_name='claude-3-5-haiku-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -287,6 +289,7 @@ async def test_request_structured_response(allow_model_requests: None):
                ),
                model_name='claude-3-5-haiku-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -356,6 +359,7 @@ async def get_location(loc_name: str) -> str:
                ),
                model_name='claude-3-5-haiku-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -384,6 +388,7 @@ async def get_location(loc_name: str) -> str:
                ),
                model_name='claude-3-5-haiku-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -406,6 +411,7 @@ async def get_location(loc_name: str) -> str:
                ),
                model_name='claude-3-5-haiku-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -754,6 +760,7 @@ async def get_image() -> BinaryContent:
                ),
                model_name='claude-3-5-sonnet-20241022',
                timestamp=IsDatetime(),
+                vendor_id='msg_01BPu4UTHXhqtR1TvsRhBLYY',
            ),
            ModelRequest(
                parts=[
@@ -792,6 +799,7 @@ async def get_image() -> BinaryContent:
                ),
                model_name='claude-3-5-sonnet-20241022',
                timestamp=IsDatetime(),
+                vendor_id='msg_01Ua6uyZUF15YV3G1PusaqSq',
            ),
        ]
    )
@@ -924,6 +932,7 @@ def simple_instructions():
                ),
                model_name='claude-3-opus-20240229',
                timestamp=IsDatetime(),
+                vendor_id='msg_01U58nruzfn9BrXrrF2hhb4m',
            ),
        ]
    )
diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py
index 8d551abe4..143bda135 100644
--- a/tests/models/test_groq.py
+++ b/tests/models/test_groq.py
@@ -147,6 +147,7 @@ async def test_request_simple_success(allow_model_requests: None):
                usage=Usage(requests=1),
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
@@ -154,6 +155,7 @@ async def test_request_simple_success(allow_model_requests: None):
                usage=Usage(requests=1),
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -206,6 +208,7 @@ async def test_request_structured_response(allow_model_requests: None):
                usage=Usage(requests=1),
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -293,6 +296,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3),
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -315,6 +319,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6),
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -331,6 +336,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1),
                model_name='llama-3.3-70b-versatile-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -567,6 +573,7 @@ async def get_image() -> BinaryContent:
                usage=Usage(requests=1, request_tokens=192, response_tokens=8, total_tokens=200),
                model_name='meta-llama/llama-4-scout-17b-16e-instruct',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-3c327c89-e9f5-4aac-a5d5-190e6f6f25c9',
            ),
            ModelRequest(
                parts=[
@@ -590,6 +597,7 @@ async def get_image() -> BinaryContent:
                usage=Usage(requests=1, request_tokens=2552, response_tokens=11, total_tokens=2563),
                model_name='meta-llama/llama-4-scout-17b-16e-instruct',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-82dfad42-6a28-4089-82c3-c8633f626c0d',
            ),
        ]
    )
@@ -669,6 +677,7 @@ async def test_groq_model_instructions(allow_model_requests: None, groq_api_key: str):
                usage=Usage(requests=1, request_tokens=48, response_tokens=8, total_tokens=56),
                model_name='llama-3.3-70b-versatile',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-7586b6a9-fb4b-4ec7-86a0-59f0a77844cf',
            ),
        ]
    )
diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py
index 0778773be..b1a33d32e 100644
--- a/tests/models/test_mistral.py
+++ b/tests/models/test_mistral.py
@@ -216,6 +216,7 @@ async def test_multiple_completions(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=IsNow(tz=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
@@ -223,6 +224,7 @@ async def test_multiple_completions(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -267,6 +269,7 @@ async def test_three_completions(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(parts=[UserPromptPart(content='hello again', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
@@ -274,6 +277,7 @@ async def test_three_completions(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(parts=[UserPromptPart(content='final message', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
@@ -281,6 +285,7 @@ async def test_three_completions(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -399,6 +404,7 @@ class CityLocation(BaseModel):
                usage=Usage(requests=1, request_tokens=1, response_tokens=2, total_tokens=3),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -459,6 +465,7 @@ class CityLocation(BaseModel):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -518,6 +525,7 @@ async def test_request_output_type_with_arguments_str_response(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -1085,6 +1093,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -1107,6 +1116,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -1123,6 +1133,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -1225,6 +1236,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -1247,6 +1259,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=3, response_tokens=2, total_tokens=6),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -1269,6 +1282,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3),
                model_name='mistral-large-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -1787,6 +1801,7 @@ async def get_image() -> BinaryContent:
                usage=Usage(requests=1, request_tokens=65, response_tokens=16, total_tokens=81),
                model_name='pixtral-12b-latest',
                timestamp=IsDatetime(),
+                vendor_id='fce6d16a4e5940edb24ae16dd0369947',
            ),
            ModelRequest(
                parts=[
@@ -1814,6 +1829,7 @@ async def get_image() -> BinaryContent:
                usage=Usage(requests=1, request_tokens=2931, response_tokens=70, total_tokens=3001),
                model_name='pixtral-12b-latest',
                timestamp=IsDatetime(),
+                vendor_id='26e7de193646460e8904f8e604a60dc1',
            ),
        ]
    )
@@ -1851,6 +1867,7 @@ async def test_image_url_input(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=IsDatetime(),
+                vendor_id='123',
            ),
        ]
    )
@@ -1883,6 +1900,7 @@ async def test_image_as_binary_content_input(allow_model_requests: None):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=IsDatetime(),
+                vendor_id='123',
            ),
        ]
    )
@@ -1943,6 +1961,7 @@ async def test_mistral_model_instructions(allow_model_requests: None, mistral_api_key: str):
                usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=1),
                model_name='mistral-large-123',
                timestamp=IsDatetime(),
+                vendor_id='123',
            ),
        ]
    )
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 75a0160b3..bc146fcd4 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -40,7 +40,7 @@ with try_import() as imports_successful:
    from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI
    from openai.types import chat
-    from openai.types.chat.chat_completion import Choice
+    from openai.types.chat.chat_completion import Choice, ChoiceLogprobs
    from openai.types.chat.chat_completion_chunk import (
        Choice as ChunkChoice,
        ChoiceDelta,
@@ -49,6 +49,7 @@
    )
    from openai.types.chat.chat_completion_message import ChatCompletionMessage
    from openai.types.chat.chat_completion_message_tool_call import Function
+    from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
    from openai.types.completion_usage import CompletionUsage, PromptTokensDetails

    from pydantic_ai.models.openai import (
@@ -129,10 +130,15 @@ def get_mock_chat_completion_kwargs(async_open_ai: AsyncOpenAI) -> list[dict[str, Any]]:
        raise RuntimeError('Not a MockOpenAI instance')


-def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:
+def completion_message(
+    message: ChatCompletionMessage, *, usage: CompletionUsage | None = None, logprobs: ChoiceLogprobs | None = None
+) -> chat.ChatCompletion:
+    choices = [Choice(finish_reason='stop', index=0, message=message)]
+    if logprobs:
+        choices = [Choice(finish_reason='stop', index=0, message=message, logprobs=logprobs)]
    return chat.ChatCompletion(
        id='123',
-        choices=[Choice(finish_reason='stop', index=0, message=message)],
+        choices=choices,
        created=1704067200,  # 2024-01-01
        model='gpt-4o-123',
        object='chat.completion',
@@ -141,7 +147,9 @@ def completion_message(message: ChatCompletionMessage, *, usage: CompletionUsage | None = None) -> chat.ChatCompletion:


async def test_request_simple_success(allow_model_requests: None):
-    c = completion_message(ChatCompletionMessage(content='world', role='assistant'))
+    c = completion_message(
+        ChatCompletionMessage(content='world', role='assistant'),
+    )
    mock_client = MockOpenAI.create_mock(c)
    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
    agent = Agent(m)
@@ -164,6 +172,7 @@ async def test_request_simple_success(allow_model_requests: None):
                usage=Usage(requests=1),
                model_name='gpt-4o-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(parts=[UserPromptPart(content='hello', timestamp=IsNow(tz=timezone.utc))]),
            ModelResponse(
@@ -171,6 +180,7 @@ async def test_request_simple_success(allow_model_requests: None):
                usage=Usage(requests=1),
                model_name='gpt-4o-123',
                timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -244,6 +254,7 @@ async def test_request_structured_response(allow_model_requests: None):
                usage=Usage(requests=1),
                model_name='gpt-4o-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -335,6 +346,7 @@ async def get_location(loc_name: str) -> str:
                ),
                model_name='gpt-4o-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -359,6 +371,7 @@ async def get_location(loc_name: str) -> str:
                ),
                model_name='gpt-4o-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
            ModelRequest(
                parts=[
@@ -375,6 +388,7 @@ async def get_location(loc_name: str) -> str:
                usage=Usage(requests=1),
                model_name='gpt-4o-123',
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
+                vendor_id='123',
            ),
        ]
    )
@@ -715,6 +729,7 @@ async def get_image() -> ImageUrl:
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRmTHlrARTzAHK1na9s80xDlQGYPX',
            ),
            ModelRequest(
                parts=[
@@ -752,6 +767,7 @@ async def get_image() -> ImageUrl:
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRmTI0Y2zmkGw27kLarhsmiFQTGxR',
            ),
        ]
    )
@@ -796,6 +812,7 @@ async def get_image() -> BinaryContent:
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlkLhPc87BdohVobEJJCGq3rUAG2',
            ),
            ModelRequest(
                parts=[
@@ -831,6 +848,7 @@ async def get_image() -> BinaryContent:
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlkORPA5rXMV3uzcOcgK4eQFKCVW',
            ),
        ]
    )
@@ -1466,6 +1484,7 @@ async def test_openai_instructions(allow_model_requests: None, openai_api_key: str):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BJjf61mLb9z5H45ClJzbx0UWKwjo1',
            ),
        ]
    )
@@ -1514,6 +1533,7 @@ async def get_temperature(city: str) -> float:
                ),
                model_name='gpt-4.1-mini-2025-04-14',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BMxEwRA0p0gJ52oKS7806KAlfMhqq',
            ),
            ModelRequest(
                parts=[
@@ -1540,6 +1560,47 @@ async def get_temperature(city: str) -> float:
                ),
                model_name='gpt-4.1-mini-2025-04-14',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BMxEx6B8JEj6oDC45MOWKp0phg8UP',
            ),
        ]
    )
+
+
+@pytest.mark.vcr()
+async def test_openai_instructions_with_logprobs(allow_model_requests: None):
+    # Create a mock response with logprobs
+    c = completion_message(
+        ChatCompletionMessage(content='world', role='assistant'),
+        logprobs=ChoiceLogprobs(
+            content=[
+                ChatCompletionTokenLogprob(
+                    token='world', logprob=-0.6931, top_logprobs=[], bytes=[119, 111, 114, 108, 100]
+                )
+            ],
+        ),
+    )
+
+    mock_client = MockOpenAI.create_mock(c)
+    m = OpenAIModel(
+        'gpt-4o',
+        provider=OpenAIProvider(openai_client=mock_client),
+    )
+    agent = Agent(
+        m,
+        instructions='You are a helpful assistant.',
+    )
+    result = await agent.run(
+        'What is the capital of Minas Gerais?',
+        model_settings=OpenAIModelSettings(openai_logprobs=True),
+    )
+    messages = result.all_messages()
+    response = cast(Any, messages[1])
+    assert response.vendor_details is not None
+    assert response.vendor_details['logprobs'] == [
+        {
+            'token': 'world',
+            'logprob': -0.6931,
+            'bytes': [119, 111, 114, 108, 100],
+            'top_logprobs': [],
+        }
+    ]
diff --git a/tests/test_agent.py b/tests/test_agent.py
index 9c0261b7b..5e93df4e5 100644
--- a/tests/test_agent.py
+++ b/tests/test_agent.py
@@ -1755,8 +1755,10 @@ def test_binary_content_all_messages_json():
                'details': None,
            },
            'model_name': 'test',
+            'vendor_id': None,
            'timestamp': IsStr(),
            'kind': 'response',
+            'vendor_details': None,
        },
    ]
)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 679bd3ebb..4c3c37f25 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -63,6 +63,7 @@ def test_agent_flag(
    create_test_module: Callable[..., None],
):
    env.remove('OPENAI_API_KEY')
+    env.set('COLUMNS', '150')
    test_agent = Agent(TestModel(custom_output_text='Hello from custom agent'))
    create_test_module(custom_agent=test_agent)

@@ -98,6 +99,7 @@ def test_agent_flag_set_model(
    create_test_module: Callable[..., None],
):
    env.set('OPENAI_API_KEY', 'xxx')
+    env.set('COLUMNS', '150')
    custom_agent = Agent(TestModel(custom_output_text='Hello from custom agent'))
    create_test_module(custom_agent=custom_agent)
diff --git a/tests/test_mcp.py b/tests/test_mcp.py
index e23d42227..101a63723 100644
--- a/tests/test_mcp.py
+++ b/tests/test_mcp.py
@@ -121,6 +121,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlnvvqIPFofAtKqtQKMWZkgXhzlT',
            ),
            ModelRequest(
                parts=[
@@ -149,6 +150,7 @@ async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlnyjUo5wlyqvdNdM5I8vIWjo1qF',
            ),
        ]
    )
@@ -222,6 +224,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlo3e1Ud2lnvkddMilmwC7LAemiy',
            ),
            ModelRequest(
                parts=[
@@ -254,6 +257,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlo41LxqBYgGKWgGrQn67fQacOLp',
            ),
        ]
    )
@@ -295,6 +299,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo',
            ),
            ModelRequest(
                parts=[
@@ -323,6 +328,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR',
            ),
        ]
    )
@@ -366,6 +372,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlo7KYJVXuNZ5lLLdYcKZDsX2CHb',
            ),
            ModelRequest(
                parts=[
@@ -405,6 +412,7 @@ async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloBGHh27w3fQKwxq4fX2cPuZJa9',
            ),
        ]
    )
@@ -444,6 +452,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloGQJWIX0Qk7gtNzF4s2Fez0O29',
            ),
            ModelRequest(
                parts=[
@@ -479,6 +488,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloJHR654fSD0fcvLWZxtKtn0pag',
            ),
        ]
    )
@@ -516,6 +526,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloOs7Bb2tq8wJyy9Rv7SQ7L65a7',
            ),
            ModelRequest(
                parts=[
@@ -544,6 +555,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloPczU1HSCWnreyo21DdNtdOM7L',
            ),
        ]
    )
@@ -587,6 +599,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloSNg7aGSp1rXDkhInjMIUHKd7A',
            ),
            ModelRequest(
                parts=[
@@ -619,6 +632,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloTvSkFeX4DZKQLqfH9KbQkWlpt',
            ),
            ModelRequest(
                parts=[
@@ -651,6 +665,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloU3MhnqNEqujs28a3ofRbs7VPF',
            ),
        ]
    )
@@ -688,6 +703,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloX2RokWc9j9PAXAuNXGR73WNqY',
            ),
            ModelRequest(
                parts=[
@@ -716,6 +732,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloYWGujk8yE94gfVSsM1T1Ol2Ej',
            ),
        ]
    )
@@ -759,6 +776,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRlobKLgm6vf79c9O8sloZaYx3coC',
            ),
            ModelRequest(
                parts=[
@@ -803,6 +821,7 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent):
                ),
                model_name='gpt-4o-2024-08-06',
                timestamp=IsDatetime(),
+                vendor_id='chatcmpl-BRloepWR5NJpTgSqFBGTSPeM1SWm8',
            ),
        ]
    )