Skip to content

Added RunErrorDetails object for MaxTurnsExceeded exception #743

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
MaxTurnsExceeded,
ModelBehaviorError,
OutputGuardrailTripwireTriggered,
RunErrorDetails,
UserError,
)
from .guardrail import (
Expand Down Expand Up @@ -204,6 +205,7 @@ def enable_verbose_stdout_logging():
"AgentHooks",
"RunContextWrapper",
"TContext",
"RunErrorDetails",
"RunResult",
"RunResultStreaming",
"RunConfig",
Expand Down
41 changes: 36 additions & 5 deletions src/agents/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,39 @@
from typing import TYPE_CHECKING
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .agent import Agent
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ModelResponse, RunItem, TResponseInputItem
from .run_context import RunContextWrapper

from .util._pretty_print import pretty_print_run_error_details


@dataclass
class RunErrorDetails:
"""Data collected from an agent run when an exception occurs."""
input: str | list[TResponseInputItem]
new_items: list[RunItem]
raw_responses: list[ModelResponse]
last_agent: Agent[Any]
context_wrapper: RunContextWrapper[Any]
input_guardrail_results: list[InputGuardrailResult]
output_guardrail_results: list[OutputGuardrailResult]

def __str__(self) -> str:
return pretty_print_run_error_details(self)


class AgentsException(Exception):
"""Base class for all exceptions in the Agents SDK."""
run_data: RunErrorDetails | None

def __init__(self, *args: object) -> None:
super().__init__(*args)
self.run_data = None


class MaxTurnsExceeded(AgentsException):
Expand All @@ -15,6 +43,7 @@ class MaxTurnsExceeded(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class ModelBehaviorError(AgentsException):
Expand All @@ -26,6 +55,7 @@ class ModelBehaviorError(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class UserError(AgentsException):
Expand All @@ -35,15 +65,16 @@ class UserError(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class InputGuardrailTripwireTriggered(AgentsException):
"""Exception raised when a guardrail tripwire is triggered."""

guardrail_result: "InputGuardrailResult"
guardrail_result: InputGuardrailResult
"""The result data of the guardrail that was triggered."""

def __init__(self, guardrail_result: "InputGuardrailResult"):
def __init__(self, guardrail_result: InputGuardrailResult):
self.guardrail_result = guardrail_result
super().__init__(
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
Expand All @@ -53,10 +84,10 @@ def __init__(self, guardrail_result: "InputGuardrailResult"):
class OutputGuardrailTripwireTriggered(AgentsException):
"""Exception raised when a guardrail tripwire is triggered."""

guardrail_result: "OutputGuardrailResult"
guardrail_result: OutputGuardrailResult
"""The result data of the guardrail that was triggered."""

def __init__(self, guardrail_result: "OutputGuardrailResult"):
def __init__(self, guardrail_result: OutputGuardrailResult):
self.guardrail_result = guardrail_result
super().__init__(
f"Guardrail {guardrail_result.guardrail.__class__.__name__} triggered tripwire"
Expand Down
85 changes: 72 additions & 13 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,22 @@
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .exceptions import (
AgentsException,
InputGuardrailTripwireTriggered,
MaxTurnsExceeded,
RunErrorDetails,
)
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
from .run_context import RunContextWrapper
from .stream_events import StreamEvent
from .tracing import Trace
from .util._pretty_print import pretty_print_result, pretty_print_run_result_streaming
from .util._pretty_print import (
pretty_print_result,
pretty_print_run_result_streaming,
)

if TYPE_CHECKING:
from ._run_impl import QueueCompleteSentinel
Expand Down Expand Up @@ -208,29 +216,79 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:

def _check_errors(self):
if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
max_turns_exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
max_turns_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = max_turns_exc

# Fetch all the completed guardrail results from the queue and raise if needed
while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
tripwire_exc = InputGuardrailTripwireTriggered(guardrail_result)
tripwire_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = tripwire_exc

# Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done():
exc = self._run_impl_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
run_impl_exc = self._run_impl_task.exception()
if run_impl_exc and isinstance(run_impl_exc, Exception):
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
run_impl_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = run_impl_exc

if self._input_guardrails_task and self._input_guardrails_task.done():
exc = self._input_guardrails_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
in_guard_exc = self._input_guardrails_task.exception()
if in_guard_exc and isinstance(in_guard_exc, Exception):
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
in_guard_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = in_guard_exc

if self._output_guardrails_task and self._output_guardrails_task.done():
exc = self._output_guardrails_task.exception()
if exc and isinstance(exc, Exception):
self._stored_exception = exc
out_guard_exc = self._output_guardrails_task.exception()
if out_guard_exc and isinstance(out_guard_exc, Exception):
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
out_guard_exc.run_data = RunErrorDetails(
input=self.input,
new_items=self.new_items,
raw_responses=self.raw_responses,
last_agent=self.current_agent,
context_wrapper=self.context_wrapper,
input_guardrail_results=self.input_guardrail_results,
output_guardrail_results=self.output_guardrail_results,
)
self._stored_exception = out_guard_exc

def _cleanup_tasks(self):
if self._run_impl_task and not self._run_impl_task.done():
Expand All @@ -244,3 +302,4 @@ def _cleanup_tasks(self):

def __str__(self) -> str:
return pretty_print_run_result_streaming(self)

30 changes: 29 additions & 1 deletion src/agents/run.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from __future__ import annotations

import asyncio
Expand Down Expand Up @@ -26,6 +27,7 @@
MaxTurnsExceeded,
ModelBehaviorError,
OutputGuardrailTripwireTriggered,
RunErrorDetails,
)
from .guardrail import InputGuardrail, InputGuardrailResult, OutputGuardrail, OutputGuardrailResult
from .handoffs import Handoff, HandoffInputFilter, handoff
Expand Down Expand Up @@ -208,7 +210,9 @@ async def run(
data={"max_turns": max_turns},
),
)
raise MaxTurnsExceeded(f"Max turns ({max_turns}) exceeded")
raise MaxTurnsExceeded(
f"Max turns ({max_turns}) exceeded"
)

logger.debug(
f"Running agent {current_agent.name} (turn {current_turn})",
Expand Down Expand Up @@ -283,6 +287,17 @@ async def run(
raise AgentsException(
f"Unknown next step type: {type(turn_result.next_step)}"
)
except AgentsException as exc:
exc.run_data = RunErrorDetails(
input=original_input,
new_items=generated_items,
raw_responses=model_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=[]
)
raise
finally:
if current_span:
current_span.finish(reset_current=True)
Expand Down Expand Up @@ -609,6 +624,19 @@ async def _run_streamed_impl(
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
elif isinstance(turn_result.next_step, NextStepRunAgain):
pass
except AgentsException as exc:
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
exc.run_data = RunErrorDetails(
input=streamed_result.input,
new_items=streamed_result.new_items,
raw_responses=streamed_result.raw_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=streamed_result.input_guardrail_results,
output_guardrail_results=streamed_result.output_guardrail_results,
)
raise
except Exception as e:
if current_span:
_error_tracing.attach_error_to_span(
Expand Down
12 changes: 12 additions & 0 deletions src/agents/util/_pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel

if TYPE_CHECKING:
from ..exceptions import RunErrorDetails
from ..result import RunResult, RunResultBase, RunResultStreaming


Expand Down Expand Up @@ -38,6 +39,17 @@ def pretty_print_result(result: "RunResult") -> str:
return output


def pretty_print_run_error_details(result: "RunErrorDetails") -> str:
output = "RunErrorDetails:"
output += f'\n- Last agent: Agent(name="{result.last_agent.name}", ...)'
output += f"\n- {len(result.new_items)} new item(s)"
output += f"\n- {len(result.raw_responses)} raw response(s)"
output += f"\n- {len(result.input_guardrail_results)} input guardrail result(s)"
output += "\n(See `RunErrorDetails` for more details)"

return output


def pretty_print_run_result_streaming(result: "RunResultStreaming") -> str:
output = "RunResultStreaming:"
output += f'\n- Current agent: Agent(name="{result.current_agent.name}", ...)'
Expand Down
44 changes: 44 additions & 0 deletions tests/test_run_error_details.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import json

import pytest

from agents import Agent, MaxTurnsExceeded, RunErrorDetails, Runner

from .fake_model import FakeModel
from .test_responses import get_function_tool, get_function_tool_call, get_text_message


@pytest.mark.asyncio
async def test_run_error_includes_data():
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
])
with pytest.raises(MaxTurnsExceeded) as exc:
await Runner.run(agent, input="hello", max_turns=1)
data = exc.value.run_data
assert isinstance(data, RunErrorDetails)
assert data.last_agent == agent
assert len(data.raw_responses) == 1
assert len(data.new_items) > 0


@pytest.mark.asyncio
async def test_streamed_run_error_includes_data():
model = FakeModel()
agent = Agent(name="test", model=model, tools=[get_function_tool("foo", "res")])
model.add_multiple_turn_outputs([
[get_text_message("1"), get_function_tool_call("foo", json.dumps({"a": "b"}))],
[get_text_message("done")],
])
result = Runner.run_streamed(agent, input="hello", max_turns=1)
with pytest.raises(MaxTurnsExceeded) as exc:
async for _ in result.stream_events():
pass
data = exc.value.run_data
assert isinstance(data, RunErrorDetails)
assert data.last_agent == agent
assert len(data.raw_responses) == 1
assert len(data.new_items) > 0
4 changes: 0 additions & 4 deletions tests/test_tracing_errors_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ async def test_tool_call_error():
"children": [
{
"type": "agent",
"error": {
"message": "Error in agent run",
"data": {"error": "Invalid JSON input for tool foo: bad_json"},
},
"data": {
"name": "test_agent",
"handoffs": [],
Expand Down