Skip to content

Attach partial run data on errors #747

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
AgentsException,
ErrorRunData,
InputGuardrailTripwireTriggered,
MaxTurnsExceeded,
ModelBehaviorError,
Expand Down Expand Up @@ -173,6 +174,7 @@ def enable_verbose_stdout_logging():
"Environment",
"Button",
"AgentsException",
"ErrorRunData",
"InputGuardrailTripwireTriggered",
"OutputGuardrailTripwireTriggered",
"MaxTurnsExceeded",
Expand Down
28 changes: 27 additions & 1 deletion src/agents/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
from typing import TYPE_CHECKING
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from .agent import Agent
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ModelResponse, RunItem, TResponseInputItem
from .run_context import RunContextWrapper


@dataclass
class ErrorRunData:
    """Data collected from an agent run when an exception occurs.

    Instances are attached to ``AgentsException.run_data`` so that callers
    catching the exception can still inspect what the run produced before it
    failed.

    All annotations are written as strings on purpose: most of the referenced
    types are imported only under ``TYPE_CHECKING``, so evaluating them when
    the class body executes would raise ``NameError`` at import time.
    """

    # The input the run was started with.
    input: "str | list[TResponseInputItem]"
    # Items generated during the run up to the point of failure.
    new_items: "list[RunItem]"
    # Raw model responses received up to the point of failure.
    raw_responses: "list[ModelResponse]"
    # The agent that was current when the exception was raised.
    last_agent: "Agent[Any]"
    # The context wrapper used for the run.
    context_wrapper: "RunContextWrapper[Any]"
    # Guardrail results gathered before the failure.
    input_guardrail_results: "list[InputGuardrailResult]"
    output_guardrail_results: "list[OutputGuardrailResult]"


class AgentsException(Exception):
    """Base class for all exceptions in the Agents SDK.

    Every instance starts with ``run_data`` set to ``None``; code that catches
    the exception while a run is in progress may assign an ``ErrorRunData``
    snapshot to it before re-raising.
    """

    # Quoted so the annotation is not evaluated when the class body runs
    # (``X | None`` on classes is a TypeError on interpreters without PEP 604).
    run_data: "ErrorRunData | None"

    def __init__(self, *args: object) -> None:
        """Initialise the exception with ``args`` and no attached run data."""
        super().__init__(*args)
        self.run_data = None


class MaxTurnsExceeded(AgentsException):
"""Exception raised when the maximum number of turns is exceeded."""
Expand All @@ -15,6 +38,7 @@ class MaxTurnsExceeded(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class ModelBehaviorError(AgentsException):
Expand All @@ -26,6 +50,7 @@ class ModelBehaviorError(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class UserError(AgentsException):
Expand All @@ -35,6 +60,7 @@ class UserError(AgentsException):

def __init__(self, message: str):
self.message = message
super().__init__(message)


class InputGuardrailTripwireTriggered(AgentsException):
Expand Down
61 changes: 58 additions & 3 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from ._run_impl import QueueCompleteSentinel
from .agent import Agent
from .agent_output import AgentOutputSchemaBase
from .exceptions import InputGuardrailTripwireTriggered, MaxTurnsExceeded
from .exceptions import (
AgentsException,
ErrorRunData,
InputGuardrailTripwireTriggered,
MaxTurnsExceeded,
)
from .guardrail import InputGuardrailResult, OutputGuardrailResult
from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
from .logger import logger
Expand Down Expand Up @@ -208,28 +213,78 @@ async def stream_events(self) -> AsyncIterator[StreamEvent]:

def _check_errors(self):
    """Record in ``self._stored_exception`` any error the run has hit so far.

    Checks, in order: the turn limit, completed input-guardrail results
    waiting in the queue, and the three background tasks. Exceptions created
    here always get a run-data snapshot; exceptions taken from a finished
    task get one only if they are ``AgentsException`` instances that do not
    already carry ``run_data``. Later checks overwrite earlier ones.
    """
    if self.current_turn > self.max_turns:
        exc = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
        exc.run_data = self._build_error_run_data()
        self._stored_exception = exc

    # Fetch all the completed guardrail results from the queue and raise if needed
    while not self._input_guardrail_queue.empty():
        guardrail_result = self._input_guardrail_queue.get_nowait()
        if guardrail_result.output.tripwire_triggered:
            exc = InputGuardrailTripwireTriggered(guardrail_result)
            exc.run_data = self._build_error_run_data()
            self._stored_exception = exc

    # Check the tasks for any exceptions. The three tasks are handled
    # identically, so iterate instead of repeating the same stanza.
    for task in (
        self._run_impl_task,
        self._input_guardrails_task,
        self._output_guardrails_task,
    ):
        if task and task.done():
            exc = task.exception()
            if exc and isinstance(exc, Exception):
                # Do not clobber run data that was attached closer to the failure.
                if isinstance(exc, AgentsException) and exc.run_data is None:
                    exc.run_data = self._build_error_run_data()
                self._stored_exception = exc

def _build_error_run_data(self) -> ErrorRunData:
    """Return an ``ErrorRunData`` built from this result's current state."""
    return ErrorRunData(
        input=self.input,
        new_items=self.new_items,
        raw_responses=self.raw_responses,
        last_agent=self.current_agent,
        context_wrapper=self.context_wrapper,
        input_guardrail_results=self.input_guardrail_results,
        output_guardrail_results=self.output_guardrail_results,
    )

def _cleanup_tasks(self):
Expand Down
25 changes: 25 additions & 0 deletions src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .agent_output import AgentOutputSchema, AgentOutputSchemaBase
from .exceptions import (
AgentsException,
ErrorRunData,
InputGuardrailTripwireTriggered,
MaxTurnsExceeded,
ModelBehaviorError,
Expand Down Expand Up @@ -283,6 +284,17 @@ async def run(
raise AgentsException(
f"Unknown next step type: {type(turn_result.next_step)}"
)
except AgentsException as exc:
exc.run_data = ErrorRunData(
input=original_input,
new_items=generated_items,
raw_responses=model_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=input_guardrail_results,
output_guardrail_results=[],
)
raise
finally:
if current_span:
current_span.finish(reset_current=True)
Expand Down Expand Up @@ -609,6 +621,19 @@ async def _run_streamed_impl(
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
elif isinstance(turn_result.next_step, NextStepRunAgain):
pass
except AgentsException as exc:
streamed_result.is_complete = True
streamed_result._event_queue.put_nowait(QueueCompleteSentinel())
exc.run_data = ErrorRunData(
input=streamed_result.input,
new_items=streamed_result.new_items,
raw_responses=streamed_result.raw_responses,
last_agent=current_agent,
context_wrapper=context_wrapper,
input_guardrail_results=streamed_result.input_guardrail_results,
output_guardrail_results=streamed_result.output_guardrail_results,
)
raise
except Exception as e:
if current_span:
_error_tracing.attach_error_to_span(
Expand Down
42 changes: 42 additions & 0 deletions tests/test_error_run_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import json
import pytest

from agents import Agent, Runner, MaxTurnsExceeded, ErrorRunData
from .fake_model import FakeModel
from .test_responses import get_text_message, get_function_tool, get_function_tool_call


@pytest.mark.asyncio
async def test_run_error_includes_data():
    """A non-streamed run that exceeds ``max_turns`` exposes partial run data."""
    fake_model = FakeModel()
    foo_tool = get_function_tool("foo", "res")
    agent = Agent(name="test", model=fake_model, tools=[foo_tool])

    # Turn one asks for a tool call, so a second turn is needed and the
    # max_turns=1 limit below is exceeded.
    tool_call_turn = [
        get_text_message("1"),
        get_function_tool_call("foo", json.dumps({"a": "b"})),
    ]
    final_turn = [get_text_message("done")]
    fake_model.add_multiple_turn_outputs([tool_call_turn, final_turn])

    with pytest.raises(MaxTurnsExceeded) as exc_info:
        await Runner.run(agent, input="hello", max_turns=1)

    run_data = exc_info.value.run_data
    assert isinstance(run_data, ErrorRunData)
    assert run_data.last_agent == agent
    assert len(run_data.raw_responses) == 1
    assert len(run_data.new_items) > 0


@pytest.mark.asyncio
async def test_streamed_run_error_includes_data():
    """A streamed run that exceeds ``max_turns`` exposes partial run data."""
    fake_model = FakeModel()
    foo_tool = get_function_tool("foo", "res")
    agent = Agent(name="test", model=fake_model, tools=[foo_tool])

    # Turn one asks for a tool call, so a second turn is needed and the
    # max_turns=1 limit below is exceeded.
    tool_call_turn = [
        get_text_message("1"),
        get_function_tool_call("foo", json.dumps({"a": "b"})),
    ]
    final_turn = [get_text_message("done")]
    fake_model.add_multiple_turn_outputs([tool_call_turn, final_turn])

    streamed = Runner.run_streamed(agent, input="hello", max_turns=1)
    with pytest.raises(MaxTurnsExceeded) as exc_info:
        # Drain the stream; the error surfaces while iterating events.
        async for _event in streamed.stream_events():
            pass

    run_data = exc_info.value.run_data
    assert isinstance(run_data, ErrorRunData)
    assert run_data.last_agent == agent
    assert len(run_data.raw_responses) == 1
    assert len(run_data.new_items) > 0
Loading