From b3c972b2182f845a6cbc179434d78cc1251d7dc0 Mon Sep 17 00:00:00 2001
From: Seth Gilchrist
Date: Sat, 10 May 2025 14:40:38 -0700
Subject: [PATCH 1/4] global per-step input filter for non-streaming run

---
 src/agents/run.py | 102 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 102 insertions(+)

diff --git a/src/agents/run.py b/src/agents/run.py
index 849da7bf..df87b70f 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -2,7 +2,9 @@
 
 import asyncio
 import copy
+from collections.abc import Callable, Iterable
 from dataclasses import dataclass, field
+from inspect import iscoroutinefunction
 from typing import Any, cast
 
 from openai.types.responses import ResponseCompletedEvent
@@ -70,6 +72,20 @@ class RunConfig:
     agent. See the documentation in `Handoff.input_filter` for more details.
     """
 
+    run_step_input_filter: Callable[
+        [str | list[TResponseInputItem]],
+        str | list[TResponseInputItem]
+    ] | None = None
+    """A global input filter to apply between agent steps. If set, the input to the agent will be
+    passed through this function before being sent to the model. This is useful for modifying the
+    input to the model, for example, to manage the context window size."""
+
+    run_step_input_filter_raise_error: bool = False
+    """What to do if the input filter raises an exception. If False (the default), we'll continue
+    with the original input. If True, we'll raise the exception. This is useful for debugging, but
+    generally you want to set this to False so that the agent can continue running even if
+    the input filter fails."""
+
     input_guardrails: list[InputGuardrail[Any]] | None = None
     """A list of input guardrails to run on the initial run input."""
 
@@ -214,6 +230,12 @@ async def run(
                     f"Running agent {current_agent.name} (turn {current_turn})",
                 )
 
+                original_input = await cls._run_step_input_filter(
+                    original_input=original_input,
+                    run_config=run_config,
+                    span=current_span,
+                )
+
                 if current_turn == 1:
                     input_guardrail_results, turn_result = await asyncio.gather(
                         cls._run_input_guardrails(
@@ -966,3 +988,83 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
             return agent.model
 
         return run_config.model_provider.get_model(agent.model)
+
+    @classmethod
+    async def _run_step_input_filter(
+        cls,
+        original_input: str | list[TResponseInputItem],
+        run_config: RunConfig,
+        span: Span[AgentSpanData],
+    ) -> str | list[TResponseInputItem]:
+        filter = run_config.run_step_input_filter
+        _raise = run_config.run_step_input_filter_raise_error
+
+        def is_acceptable_response(
+            response: object
+        ) -> bool:
+            return (
+                isinstance(response, str)
+                or (
+                    isinstance(response, Iterable)
+                    and all(
+                        "type" in item for item in response  # minimal check for ResponseInputItem
+                    )
+                )
+            )
+
+        if not filter:
+            return original_input
+
+        if not callable(filter):
+            _error_tracing.attach_error_to_span(
+                span,
+                SpanError(
+                    message="Input step filter is not callable",
+                    data={"input_step_filter": filter},
+                ),
+            )
+            if _raise:
+                raise ModelBehaviorError(
+                    "Input step filter is not callable"
+                )
+            return original_input
+        try:
+            if iscoroutinefunction(filter):
+                input_filter_response = await filter(original_input)
+            else:
+                input_filter_response = filter(original_input)
+        except Exception as e:
+            _error_tracing.attach_error_to_span(
+                span,
+                SpanError(
+                    message="Input step filter raised an exception",
+                    data={
+                        "input_step_filter": filter,
+                        "exception": str(e),
+                    },
+                ),
+            )
+            if _raise:
+                raise ModelBehaviorError(
+                    "Input step filter raised an exception"
+                ) from e
+            return original_input
+
+        if not is_acceptable_response(input_filter_response):
+            _error_tracing.attach_error_to_span(
+                span,
+                SpanError(
+                    message=(
+                        "Input step filter did not return a string "
+                        "or list of ResponseInputItems")
+                    ,
+                    data={"input_step_filter": filter, "response": input_filter_response},
+                ),
+            )
+            if _raise:
+                raise ModelBehaviorError(
+                    "Input step filter did not return a string or list"
+                )
+            return original_input
+
+        return input_filter_response
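
A minimal usage sketch for the hook this patch adds (illustrative only, not part of the diff; it assumes the SDK's public Agent/Runner API and a configured OpenAI key, and the 10-item budget is an arbitrary example value):

    import asyncio

    from agents import Agent, Runner
    from agents.run import RunConfig


    def keep_recent_items(input_items):
        # The filter receives either the original string input or the accumulated
        # list of response input items; strings pass through untouched.
        if isinstance(input_items, str):
            return input_items
        # Keep only the most recent items to bound the context window.
        return input_items[-10:]


    async def main():
        agent = Agent(name="assistant", instructions="Answer briefly.")
        result = await Runner.run(
            agent,
            "Hello",
            run_config=RunConfig(run_step_input_filter=keep_recent_items),
        )
        print(result.final_output)


    asyncio.run(main())
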
From 9542bb508d4ce6d817f0d80f0bc69ebd488a9d9e Mon Sep 17 00:00:00 2001
From: Seth Gilchrist
Date: Sat, 10 May 2025 14:54:18 -0700
Subject: [PATCH 2/4] global per-step input filter for streaming run

---
 .gitignore        |  3 +++
 src/agents/run.py |  7 +++++++
 2 files changed, 10 insertions(+)

diff --git a/.gitignore b/.gitignore
index 2e9b9237..b3d5789c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -143,3 +143,6 @@ cython_debug/
 # PyPI configuration file
 .pypirc
 .aider*
+
+# VSCode Local history
+.history
\ No newline at end of file
diff --git a/src/agents/run.py b/src/agents/run.py
index df87b70f..a9231a05 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -568,6 +568,13 @@ async def _run_streamed_impl(
                     streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                     break
 
+
+                streamed_result.input = await cls._run_step_input_filter(
+                    original_input=streamed_result.input,
+                    run_config=run_config,
+                    span=current_span
+                )
+
                 if current_turn == 1:
                     # Run the input guardrails in the background and put the results on the queue
                     streamed_result._input_guardrails_task = asyncio.create_task(
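
The streaming path reuses the same hook, applied to streamed_result.input before each turn. A sketch of the streamed equivalent (again illustrative, assuming the public run_streamed/stream_events API):

    import asyncio

    from agents import Agent, Runner
    from agents.run import RunConfig


    def keep_recent_items(input_items):
        # Same filter as the non-streaming sketch: trim list inputs, pass strings through.
        return input_items if isinstance(input_items, str) else input_items[-10:]


    async def main():
        agent = Agent(name="assistant", instructions="Answer briefly.")
        result = Runner.run_streamed(
            agent,
            "Hello",
            run_config=RunConfig(run_step_input_filter=keep_recent_items),
        )
        async for _event in result.stream_events():
            pass  # handle events as usual; the filter runs between turns


    asyncio.run(main())
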
"function_call_output", + "id": "output_2", + }, + { # Message + "role": "user", "status": "completed", "type": "message", "content": "Message 3" + }, + { # Message + "role": "system", "status": "completed", "type": "message", "content": "Message 4" + } + ] + +@pytest.fixture +def run_config(): + return RunConfig() + +@pytest.fixture +def span(): + return agent_span(name="test_span") + +@pytest.mark.asyncio +async def test_run_input_step_filter_not_callable(response_input_items, run_config, span): + input_filter = "This is not callable" + run_config.run_step_input_filter = input_filter + + # returns input by default + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == response_input_items + + # raises error if run_step_input_filter_raise_error is True + run_config.run_step_input_filter_raise_error = True + with pytest.raises(ModelBehaviorError): + await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + + +@pytest.mark.asyncio +async def test_run_input_step_filter_not_set(response_input_items, run_config, span): + # returns input by default + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == response_input_items + + +@pytest.mark.asyncio +async def test_run_input_step_filter_output(response_input_items, run_config, span): + # invalid output type + def input_filter(*args, **kwargs): + return 5 + run_config.run_step_input_filter = input_filter + + # returns input by default + response = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert response == response_input_items + + # raises error if run_step_input_filter_raise_error is True + run_config.run_step_input_filter_raise_error = True + with pytest.raises(ModelBehaviorError): + await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + + # string output is okay + def input_filter_str_output(*args, **kwargs): + return "This is a string output" + run_config.run_step_input_filter = input_filter_str_output + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == "This is a string output" + + # list of dicts with "type" + def input_filter_dict_output(*args, **kwargs): + return [ + { + "type": "message", + "role": "user", + "content": "This is a user message" + }, + { + "type": "message", + "role": "system", + "content": "This is a system message" + } + ] + run_config.run_step_input_filter = input_filter_dict_output + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert len(result) == 2 + assert result == input_filter_dict_output() + +@pytest.mark.asyncio +async def test_run_input_step_filter_error(response_input_items, run_config, span): + def input_filter(*args, **kwargs): + raise Exception("This is an error") + run_config.run_step_input_filter = input_filter + + # returns input by default + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == response_input_items + + # raises error if run_step_input_filter_raise_error is True + run_config.run_step_input_filter_raise_error = True + with 
From 3859f2522a37c86332f1bf6f1f38800304bf9c87 Mon Sep 17 00:00:00 2001
From: Seth Gilchrist
Date: Mon, 12 May 2025 15:07:46 -0700
Subject: [PATCH 4/4] make format changes

---
 src/agents/models/openai_chatcompletions.py |  2 +-
 src/agents/run.py                           | 45 ++++--------
 src/agents/voice/model.py                   |  2 +
 tests/test_extra_headers.py                 |  9 ++-
 tests/test_run_step_processing.py           | 76 +++++++++++--------
 5 files changed, 70 insertions(+), 64 deletions(-)

diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py
index 89619f83..90ad6daf 100644
--- a/src/agents/models/openai_chatcompletions.py
+++ b/src/agents/models/openai_chatcompletions.py
@@ -252,7 +252,7 @@ async def _fetch_response(
             stream_options=self._non_null_or_not_given(stream_options),
             store=self._non_null_or_not_given(store),
             reasoning_effort=self._non_null_or_not_given(reasoning_effort),
-            extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) },
+            extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
             extra_query=model_settings.extra_query,
             extra_body=model_settings.extra_body,
             metadata=self._non_null_or_not_given(model_settings.metadata),
diff --git a/src/agents/run.py b/src/agents/run.py
index a9231a05..23720170 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -72,10 +72,9 @@ class RunConfig:
     agent. See the documentation in `Handoff.input_filter` for more details.
     """
 
-    run_step_input_filter: Callable[
-        [str | list[TResponseInputItem]],
-        str | list[TResponseInputItem]
-    ] | None = None
+    run_step_input_filter: (
+        Callable[[str | list[TResponseInputItem]], str | list[TResponseInputItem]] | None
+    ) = None
     """A global input filter to apply between agent steps. If set, the input to the agent will be
     passed through this function before being sent to the model. This is useful for modifying the
     input to the model, for example, to manage the context window size."""
@@ -568,11 +567,8 @@ async def _run_streamed_impl(
                     streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
                     break
 
-
                 streamed_result.input = await cls._run_step_input_filter(
-                    original_input=streamed_result.input,
-                    run_config=run_config,
-                    span=current_span
+                    original_input=streamed_result.input, run_config=run_config, span=current_span
                 )
 
                 if current_turn == 1:
@@ -1006,16 +1002,12 @@ async def _run_step_input_filter(
         filter = run_config.run_step_input_filter
         _raise = run_config.run_step_input_filter_raise_error
 
-        def is_acceptable_response(
-            response: object
-        ) -> bool:
-            return (
-                isinstance(response, str)
-                or (
-                    isinstance(response, Iterable)
-                    and all(
-                        "type" in item for item in response  # minimal check for ResponseInputItem
-                    )
+        def is_acceptable_response(response: object) -> bool:
+            return isinstance(response, str) or (
+                isinstance(response, Iterable)
+                and all(
+                    "type" in item
+                    for item in response  # minimal check for ResponseInputItem
                 )
             )
 
@@ -1031,9 +1023,7 @@ def is_acceptable_response(
                 ),
             )
             if _raise:
-                raise ModelBehaviorError(
-                    "Input step filter is not callable"
-                )
+                raise ModelBehaviorError("Input step filter is not callable")
             return original_input
         try:
             if iscoroutinefunction(filter):
@@ -1052,9 +1042,7 @@ def is_acceptable_response(
                 ),
             )
             if _raise:
-                raise ModelBehaviorError(
-                    "Input step filter raised an exception"
-                ) from e
+                raise ModelBehaviorError("Input step filter raised an exception") from e
             return original_input
 
         if not is_acceptable_response(input_filter_response):
@@ -1062,16 +1050,13 @@ def is_acceptable_response(
                 span,
                 SpanError(
                     message=(
-                        "Input step filter did not return a string "
-                        "or list of ResponseInputItems")
-                    ,
+                        "Input step filter did not return a string or list of ResponseInputItems"
+                    ),
                     data={"input_step_filter": filter, "response": input_filter_response},
                 ),
             )
             if _raise:
-                raise ModelBehaviorError(
-                    "Input step filter did not return a string or list"
-                )
+                raise ModelBehaviorError("Input step filter did not return a string or list")
             return original_input
 
         return input_filter_response
diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py
index c36a4de7..b048a452 100644
--- a/src/agents/voice/model.py
+++ b/src/agents/voice/model.py
@@ -17,9 +17,11 @@
 TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"]
 """Exportable type for the TTSModelSettings voice enum"""
 
+
 @dataclass
 class TTSModelSettings:
     """Settings for a TTS model."""
+
     voice: TTSVoice | None = None
     """
     The voice to use for the TTS model. If not provided, the default voice for the respective model
diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py
index f29c2540..8efa95a7 100644
--- a/tests/test_extra_headers.py
+++ b/tests/test_extra_headers.py
@@ -17,21 +17,21 @@ class DummyResponses:
         async def create(self, **kwargs):
             nonlocal called_kwargs
             called_kwargs = kwargs
+
             class DummyResponse:
                 id = "dummy"
                 output = []
                 usage = type(
                     "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
                 )()
+
             return DummyResponse()
 
     class DummyClient:
         def __init__(self):
             self.responses = DummyResponses()
 
-
-
-    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+    model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
     extra_headers = {"X-Test-Header": "test-value"}
     await model.get_response(
         system_instructions=None,
@@ -47,7 +47,6 @@ def __init__(self):
     assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value"
 
 
-
 @pytest.mark.allow_call_model_methods
 @pytest.mark.asyncio
 async def test_extra_headers_passed_to_openai_client():
@@ -76,7 +75,7 @@ def __init__(self):
             self.chat = type("_Chat", (), {"completions": DummyCompletions()})()
             self.base_url = "https://api.openai.com"
 
-    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore
+    model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient())  # type: ignore
     extra_headers = {"X-Test-Header": "test-value"}
     await model.get_response(
         system_instructions=None,
"role": "system", + "status": "completed", + "type": "message", + "content": "Message 4", }, - { # Message - "role": "system", "status": "completed", "type": "message", "content": "Message 4" - } ] + @pytest.fixture def run_config(): return RunConfig() + @pytest.fixture def span(): return agent_span(name="test_span") + @pytest.mark.asyncio async def test_run_input_step_filter_not_callable(response_input_items, run_config, span): input_filter = "This is not callable" @@ -567,6 +588,7 @@ async def test_run_input_step_filter_output(response_input_items, run_config, sp # invalid output type def input_filter(*args, **kwargs): return 5 + run_config.run_step_input_filter = input_filter # returns input by default @@ -589,6 +611,7 @@ def input_filter(*args, **kwargs): # string output is okay def input_filter_str_output(*args, **kwargs): return "This is a string output" + run_config.run_step_input_filter = input_filter_str_output result = await Runner._run_step_input_filter( original_input=response_input_items, @@ -600,17 +623,10 @@ def input_filter_str_output(*args, **kwargs): # list of dicts with "type" def input_filter_dict_output(*args, **kwargs): return [ - { - "type": "message", - "role": "user", - "content": "This is a user message" - }, - { - "type": "message", - "role": "system", - "content": "This is a system message" - } + {"type": "message", "role": "user", "content": "This is a user message"}, + {"type": "message", "role": "system", "content": "This is a system message"}, ] + run_config.run_step_input_filter = input_filter_dict_output result = await Runner._run_step_input_filter( original_input=response_input_items, @@ -620,10 +636,12 @@ def input_filter_dict_output(*args, **kwargs): assert len(result) == 2 assert result == input_filter_dict_output() + @pytest.mark.asyncio async def test_run_input_step_filter_error(response_input_items, run_config, span): def input_filter(*args, **kwargs): raise Exception("This is an error") + run_config.run_step_input_filter = input_filter # returns input by default @@ -649,6 +667,7 @@ async def test_run_input_step_filter(response_input_items, run_config, span): # test sync function def input_filter(input_items): return [item for item in input_items if item.get("role", "") == "user"] + run_config.run_step_input_filter = input_filter result = await Runner._run_step_input_filter( @@ -662,8 +681,9 @@ def input_filter(input_items): # test async function async def input_filter_async(input_items): return await asyncio.to_thread( - lambda : [item for item in input_items if item.get("role", "") == "user"] + lambda: [item for item in input_items if item.get("role", "") == "user"] ) + run_config.run_step_input_filter = input_filter_async result = await Runner._run_step_input_filter( original_input=response_input_items,