From 51cee3c6bd6ea598279f7f362333caf8a7516937 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 18:32:36 -0400 Subject: [PATCH 1/3] Add cycle detection to visualization --- src/agents/extensions/visualization.py | 49 +++++++++++++++++--------- tests/test_visualization.py | 15 ++++++++ 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 888e262c..6db076d8 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_nodes( + agent: Agent, parent: Optional[Agent] = None, visited: set[str] | None = None +) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the nodes. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] # Start and end the graph - parts.append( - '"__start__" [label="__start__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - '"__end__" [label="__end__", shape=ellipse, style=filled, ' - "fillcolor=lightblue, width=0.5, height=0.3];" - ) - # Ensure parent agent node is colored if not parent: + parts.append( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + ) + # Ensure parent agent node is colored parts.append( f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" @@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: f"fillcolor=lightyellow, width=1.5, height=0.8];" ) if isinstance(handoff, Agent): - parts.append( - f'"{handoff.name}" [label="{handoff.name}", ' - f"shape=box, style=filled, style=rounded, " - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - parts.append(get_all_nodes(handoff)) + if handoff.name not in visited: + parts.append( + f'"{handoff.name}" [label="{handoff.name}", ' + f"shape=box, style=filled, style=rounded, " + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + parts.append(get_all_nodes(handoff, agent, visited)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: +def get_all_edges( + agent: Agent, parent: Optional[Agent] = None, visited: set[str] | None = None +) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: Returns: str: The DOT format string representing the edges. """ + if visited is None: + visited = set() + if agent.name in visited: + return "" + visited.add(agent.name) + parts = [] if not parent: @@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: if isinstance(handoff, Agent): parts.append(f""" "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + parts.append(get_all_edges(handoff, agent, visited)) if not agent.handoffs and not isinstance(agent, Tool): # type: ignore parts.append(f'"{agent.name}" -> "__end__";') diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 6aa86774..8bce897e 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -134,3 +134,18 @@ def test_draw_graph(mock_agent): '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source ) + + +def test_cycle_detection(): + agent_a = Agent(name="A") + agent_b = Agent(name="B") + agent_a.handoffs.append(agent_b) + agent_b.handoffs.append(agent_a) + + nodes = get_all_nodes(agent_a) + edges = get_all_edges(agent_a) + + assert nodes.count('"A" [label="A"') == 1 + assert nodes.count('"B" [label="B"') == 1 + assert '"A" -> "B"' in edges + assert '"B" -> "A"' in edges From 8a991434154946b06f68d5fd66ae664d000f91e5 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 19:11:19 -0400 Subject: [PATCH 2/3] Update visualization.py --- src/agents/extensions/visualization.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 6db076d8..b6a93803 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Optional import graphviz # type: ignore From b3cd7263509ca9068d9fd7fb93269e9ccb642b4d Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Wed, 21 May 2025 19:17:07 -0400 Subject: [PATCH 3/3] Update visualization.py --- src/agents/extensions/visualization.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index b6a93803..be762a33 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,5 +1,4 @@ from __future__ import annotations -from typing import Optional import graphviz # type: ignore @@ -33,7 +32,7 @@ def get_main_graph(agent: Agent) -> str: def get_all_nodes( - agent: Agent, parent: Optional[Agent] = None, visited: set[str] | None = None + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None ) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -92,7 +91,7 @@ def get_all_nodes( def get_all_edges( - agent: Agent, parent: Optional[Agent] = None, visited: set[str] | None = None + agent: Agent, parent: Agent | None = None, visited: set[str] | None = None ) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -135,7 +134,7 @@ def get_all_edges( return "".join(parts) -def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: +def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source: """ Draws the graph for the given agent and optionally saves it as a PNG file.