diff --git a/.gitignore b/.gitignore index 2e9b9237..b3d5789c 100644 --- a/.gitignore +++ b/.gitignore @@ -143,3 +143,6 @@ cython_debug/ # PyPI configuration file .pypirc .aider* + +# VSCode Local history +.history \ No newline at end of file diff --git a/src/agents/models/openai_chatcompletions.py b/src/agents/models/openai_chatcompletions.py index 89619f83..90ad6daf 100644 --- a/src/agents/models/openai_chatcompletions.py +++ b/src/agents/models/openai_chatcompletions.py @@ -252,7 +252,7 @@ async def _fetch_response( stream_options=self._non_null_or_not_given(stream_options), store=self._non_null_or_not_given(store), reasoning_effort=self._non_null_or_not_given(reasoning_effort), - extra_headers={ **HEADERS, **(model_settings.extra_headers or {}) }, + extra_headers={**HEADERS, **(model_settings.extra_headers or {})}, extra_query=model_settings.extra_query, extra_body=model_settings.extra_body, metadata=self._non_null_or_not_given(model_settings.metadata), diff --git a/src/agents/run.py b/src/agents/run.py index 849da7bf..23720170 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,7 +2,9 @@ import asyncio import copy +from collections.abc import Callable, Iterable from dataclasses import dataclass, field +from inspect import iscoroutinefunction from typing import Any, cast from openai.types.responses import ResponseCompletedEvent @@ -70,6 +72,19 @@ class RunConfig: agent. See the documentation in `Handoff.input_filter` for more details. """ + run_step_input_filter: ( + Callable[[str | list[TResponseInputItem]], str | list[TResponseInputItem]] | None + ) = None + """A global input filter to apply between agent steps. If set, the input to the agent will be + passed through this function before being sent to the model. This is useful for modifying the + input to the model, for example, to manage the context window size.""" + + run_step_input_filter_raise_error: bool = False + """What to do if the input filter raises an exception. 
If False (the default), we'll continue + with the original input. If True, we'll raise the exception. This is useful for debugging, but + generally you want to set this to False so that the agent can continue running even if + the input filter fails.""" + input_guardrails: list[InputGuardrail[Any]] | None = None """A list of input guardrails to run on the initial run input.""" @@ -214,6 +229,12 @@ async def run( f"Running agent {current_agent.name} (turn {current_turn})", ) + original_input = await cls._run_step_input_filter( + original_input=original_input, + run_config=run_config, + span=current_span, + ) + if current_turn == 1: input_guardrail_results, turn_result = await asyncio.gather( cls._run_input_guardrails( @@ -546,6 +567,10 @@ async def _run_streamed_impl( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) break + streamed_result.input = await cls._run_step_input_filter( + original_input=streamed_result.input, run_config=run_config, span=current_span + ) + if current_turn == 1: # Run the input guardrails in the background and put the results on the queue streamed_result._input_guardrails_task = asyncio.create_task( @@ -966,3 +991,72 @@ def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: return agent.model return run_config.model_provider.get_model(agent.model) + + @classmethod + async def _run_step_input_filter( + cls, + original_input: str | list[TResponseInputItem], + run_config: RunConfig, + span: Span[AgentSpanData], + ) -> str | list[TResponseInputItem]: + filter = run_config.run_step_input_filter + _raise = run_config.run_step_input_filter_raise_error + + def is_acceptable_response(response: object) -> bool: + return isinstance(response, str) or ( + isinstance(response, Iterable) + and all( + "type" in item + for item in response # minimal check for ResponseInputItem + ) + ) + + if not filter: + return original_input + + if not callable(filter): + _error_tracing.attach_error_to_span( + span, + SpanError( + 
message="Input step filter is not callable", + data={"input_step_filter": filter}, + ), + ) + if _raise: + raise ModelBehaviorError("Input step filter is not callable") + return original_input + try: + if iscoroutinefunction(filter): + input_filter_response = await filter(original_input) + else: + input_filter_response = filter(original_input) + except Exception as e: + _error_tracing.attach_error_to_span( + span, + SpanError( + message="Input step filter raised an exception", + data={ + "input_step_filter": filter, + "exception": str(e), + }, + ), + ) + if _raise: + raise ModelBehaviorError("Input step filter raised an exception") from e + return original_input + + if not is_acceptable_response(input_filter_response): + _error_tracing.attach_error_to_span( + span, + SpanError( + message=( + "Input step filter did not return a string or list of ResponseInputItems" + ), + data={"input_step_filter": filter, "response": input_filter_response}, + ), + ) + if _raise: + raise ModelBehaviorError("Input step filter did not return a string or list") + return original_input + + return input_filter_response diff --git a/src/agents/voice/model.py b/src/agents/voice/model.py index c36a4de7..b048a452 100644 --- a/src/agents/voice/model.py +++ b/src/agents/voice/model.py @@ -17,9 +17,11 @@ TTSVoice = Literal["alloy", "ash", "coral", "echo", "fable", "onyx", "nova", "sage", "shimmer"] """Exportable type for the TTSModelSettings voice enum""" + @dataclass class TTSModelSettings: """Settings for a TTS model.""" + voice: TTSVoice | None = None """ The voice to use for the TTS model. 
If not provided, the default voice for the respective model diff --git a/tests/test_extra_headers.py b/tests/test_extra_headers.py index f29c2540..8efa95a7 100644 --- a/tests/test_extra_headers.py +++ b/tests/test_extra_headers.py @@ -17,21 +17,21 @@ class DummyResponses: async def create(self, **kwargs): nonlocal called_kwargs called_kwargs = kwargs + class DummyResponse: id = "dummy" output = [] usage = type( "Usage", (), {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0} )() + return DummyResponse() class DummyClient: def __init__(self): self.responses = DummyResponses() - - - model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIResponsesModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, @@ -47,7 +47,6 @@ def __init__(self): assert called_kwargs["extra_headers"]["X-Test-Header"] == "test-value" - @pytest.mark.allow_call_model_methods @pytest.mark.asyncio async def test_extra_headers_passed_to_openai_client(): @@ -76,7 +75,7 @@ def __init__(self): self.chat = type("_Chat", (), {"completions": DummyCompletions()})() self.base_url = "https://api.openai.com" - model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore + model = OpenAIChatCompletionsModel(model="gpt-4", openai_client=DummyClient()) # type: ignore extra_headers = {"X-Test-Header": "test-value"} await model.get_response( system_instructions=None, diff --git a/tests/test_run_step_processing.py b/tests/test_run_step_processing.py index 2ea98f06..3b1877db 100644 --- a/tests/test_run_step_processing.py +++ b/tests/test_run_step_processing.py @@ -1,5 +1,7 @@ from __future__ import annotations +import asyncio + import pytest from openai.types.responses import ( ResponseComputerToolCall, @@ -24,6 +26,8 @@ Usage, ) from agents._run_impl import RunImpl +from agents.run import RunConfig +from 
agents.tracing.create import agent_span from .test_responses import ( get_final_output_message, @@ -476,3 +480,215 @@ async def test_tool_and_handoff_parsed_correctly(): assert handoff.handoff.tool_name == Handoff.default_tool_name(agent_1) assert handoff.handoff.tool_description == Handoff.default_tool_description(agent_1) assert handoff.handoff.agent_name == agent_1.name + + +@pytest.fixture +def response_input_items(): + return [ + { # Message + "role": "system", + "status": "completed", + "type": "message", + "content": "Message 1", + }, + { # Message + "role": "user", + "status": "completed", + "type": "message", + "content": "Message 2", + }, + { # ComputerCallOutput + "call_id": "call_1", + "output": { # ResponseComputerToolCallOutputScreenshotParam + "id": "screenshot_1", + "type": "screenshot", + "url": "http://example.com/screenshot.png", + }, + "type": "computer_call_output", + "id": "output_1", + "acknowledged_safety_checks": [ + { # ComputerCallOutputAcknowledgedSafetyCheck + "id": "check_1", + "code": "code_1", + "message": "message_1", + } + ], + }, + { # FunctionCallOutput + "call_id": "call_2", + "output": { # ResponseFunctionWebSearch + "id": "web_search_1", + "type": "web_search_call", + "status": "completed", + }, + "type": "function_call_output", + "id": "output_2", + }, + { # Message + "role": "user", + "status": "completed", + "type": "message", + "content": "Message 3", + }, + { # Message + "role": "system", + "status": "completed", + "type": "message", + "content": "Message 4", + }, + ] + + +@pytest.fixture +def run_config(): + return RunConfig() + + +@pytest.fixture +def span(): + return agent_span(name="test_span") + + +@pytest.mark.asyncio +async def test_run_input_step_filter_not_callable(response_input_items, run_config, span): + input_filter = "This is not callable" + run_config.run_step_input_filter = input_filter + + # returns input by default + result = await Runner._run_step_input_filter( + original_input=response_input_items, + 
run_config=run_config, + span=span, + ) + assert result == response_input_items + + # raises error if run_step_input_filter_raise_error is True + run_config.run_step_input_filter_raise_error = True + with pytest.raises(ModelBehaviorError): + await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + + +@pytest.mark.asyncio +async def test_run_input_step_filter_not_set(response_input_items, run_config, span): + # returns input by default + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == response_input_items + + +@pytest.mark.asyncio +async def test_run_input_step_filter_output(response_input_items, run_config, span): + # invalid output type + def input_filter(*args, **kwargs): + return 5 + + run_config.run_step_input_filter = input_filter + + # returns input by default + response = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert response == response_input_items + + # raises error if run_step_input_filter_raise_error is True + run_config.run_step_input_filter_raise_error = True + with pytest.raises(ModelBehaviorError): + await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + + # string output is okay + def input_filter_str_output(*args, **kwargs): + return "This is a string output" + + run_config.run_step_input_filter = input_filter_str_output + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == "This is a string output" + + # list of dicts with "type" + def input_filter_dict_output(*args, **kwargs): + return [ + {"type": "message", "role": "user", "content": "This is a user message"}, + {"type": "message", "role": "system", "content": "This is a system message"}, + ] + + 
run_config.run_step_input_filter = input_filter_dict_output + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert len(result) == 2 + assert result == input_filter_dict_output() + + +@pytest.mark.asyncio +async def test_run_input_step_filter_error(response_input_items, run_config, span): + def input_filter(*args, **kwargs): + raise Exception("This is an error") + + run_config.run_step_input_filter = input_filter + + # returns input by default + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert result == response_input_items + + # raises error if run_step_input_filter_raise_error is True + run_config.run_step_input_filter_raise_error = True + with pytest.raises(ModelBehaviorError): + await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + + +@pytest.mark.asyncio +async def test_run_input_step_filter(response_input_items, run_config, span): + # test sync function + def input_filter(input_items): + return [item for item in input_items if item.get("role", "") == "user"] + + run_config.run_step_input_filter = input_filter + + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert all(item["role"] == "user" for item in result) + assert len(result) == 2 + + # test async function + async def input_filter_async(input_items): + return await asyncio.to_thread( + lambda: [item for item in input_items if item.get("role", "") == "user"] + ) + + run_config.run_step_input_filter = input_filter_async + result = await Runner._run_step_input_filter( + original_input=response_input_items, + run_config=run_config, + span=span, + ) + assert all(item["role"] == "user" for item in result) + assert len(result) == 2