diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 888e262c..be762a33 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,4 +1,4 @@ -from typing import Optional +from __future__ import annotations import graphviz # type: ignore @@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_nodes( + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None +) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the nodes. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] # Start and end the graph - parts.append( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - ) - # Ensure parent agent node is colored if not parent: + parts.append( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + ) + # Ensure parent agent node is colored parts.append( f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" @@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: f"fillcolor=lightyellow, width=1.5, height=0.8];" ) if isinstance(handoff, Agent): - parts.append( - f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, style=rounded, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - parts.append(get_all_nodes(handoff)) + if handoff.name not in visited: + parts.append( + f'"{handoff.name}" [label="{handoff.name}", ' + f"shape=box, style=filled, style=rounded, " + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + parts.append(get_all_nodes(handoff, agent, visited)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_edges( + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None +) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the edges. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] if not parent: @@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: if isinstance(handoff, Agent): parts.append(f""" "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + parts.append(get_all_edges(handoff, agent, visited)) if not agent.handoffs and not isinstance(agent, Tool): # type: ignore parts.append(f'"{agent.name}" -> "__end__";') @@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: return "".join(parts) -def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: +def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source: """ Draws the graph for the given agent and optionally saves it as a PNG file. diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6aa86774..8bce897e 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -134,3 +134,18 @@ def test_draw_graph(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source ) + + +def test_cycle_detection(): + agent_a = Agent(name="A") + agent_b = Agent(name="B") + agent_a.handoffs.append(agent_b) + agent_b.handoffs.append(agent_a) + + nodes = get_all_nodes(agent_a) + edges = get_all_edges(agent_a) + + assert nodes.count('"A" [label="A"') == 1 + assert nodes.count('"B" [label="B"') == 1 + assert '"A" -> "B"' in edges + assert '"B" -> "A"' in edges