Skip to content

Commit db462e3

Browse files
authored
Fix visualization recursion with cycle detection (#737)
## Summary - avoid infinite recursion in visualization by tracking visited agents - test cycle detection in graph utility ## Testing - `make mypy` - `make tests` Resolves #668
1 parent 1364f44 commit db462e3

File tree

2 files changed

+50
-18
lines changed

2 files changed

+50
-18
lines changed

src/agents/extensions/visualization.py

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from __future__ import annotations
22

33
import graphviz # type: ignore
44

@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
3131
return "".join(parts)
3232

3333

34-
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
34+
def get_all_nodes(
35+
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
36+
) -> str:
3537
"""
3638
Recursively generates the nodes for the given agent and its handoffs in DOT format.
3739
@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
4143
Returns:
4244
str: The DOT format string representing the nodes.
4345
"""
46+
if visited is None:
47+
visited = set()
48+
if agent.name in visited:
49+
return ""
50+
visited.add(agent.name)
51+
4452
parts = []
4553

4654
# Start and end the graph
47-
parts.append(
48-
'"__start__" [label="__start__", shape=ellipse, style=filled, '
49-
"fillcolor=lightblue, width=0.5, height=0.3];"
50-
'"__end__" [label="__end__", shape=ellipse, style=filled, '
51-
"fillcolor=lightblue, width=0.5, height=0.3];"
52-
)
53-
# Ensure parent agent node is colored
5455
if not parent:
56+
parts.append(
57+
'"__start__" [label="__start__", shape=ellipse, style=filled, '
58+
"fillcolor=lightblue, width=0.5, height=0.3];"
59+
'"__end__" [label="__end__", shape=ellipse, style=filled, '
60+
"fillcolor=lightblue, width=0.5, height=0.3];"
61+
)
62+
# Ensure parent agent node is colored
5563
parts.append(
5664
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
5765
"fillcolor=lightyellow, width=1.5, height=0.8];"
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
7179
f"fillcolor=lightyellow, width=1.5, height=0.8];"
7280
)
7381
if isinstance(handoff, Agent):
74-
parts.append(
75-
f'"{handoff.name}" [label="{handoff.name}", '
76-
f"shape=box, style=filled, style=rounded, "
77-
f"fillcolor=lightyellow, width=1.5, height=0.8];"
78-
)
79-
parts.append(get_all_nodes(handoff))
82+
if handoff.name not in visited:
83+
parts.append(
84+
f'"{handoff.name}" [label="{handoff.name}", '
85+
f"shape=box, style=filled, style=rounded, "
86+
f"fillcolor=lightyellow, width=1.5, height=0.8];"
87+
)
88+
parts.append(get_all_nodes(handoff, agent, visited))
8089

8190
return "".join(parts)
8291

8392

84-
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
93+
def get_all_edges(
94+
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
95+
) -> str:
8596
"""
8697
Recursively generates the edges for the given agent and its handoffs in DOT format.
8798
@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
92103
Returns:
93104
str: The DOT format string representing the edges.
94105
"""
106+
if visited is None:
107+
visited = set()
108+
if agent.name in visited:
109+
return ""
110+
visited.add(agent.name)
111+
95112
parts = []
96113

97114
if not parent:
@@ -109,15 +126,15 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
109126
if isinstance(handoff, Agent):
110127
parts.append(f"""
111128
"{agent.name}" -> "{handoff.name}";""")
112-
parts.append(get_all_edges(handoff, agent))
129+
parts.append(get_all_edges(handoff, agent, visited))
113130

114131
if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
115132
parts.append(f'"{agent.name}" -> "__end__";')
116133

117134
return "".join(parts)
118135

119136

120-
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
137+
def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
121138
"""
122139
Draws the graph for the given agent and optionally saves it as a PNG file.
123140

tests/test_visualization.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,18 @@ def test_draw_graph(mock_agent):
134134
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
135135
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
136136
)
137+
138+
139+
def test_cycle_detection():
140+
agent_a = Agent(name="A")
141+
agent_b = Agent(name="B")
142+
agent_a.handoffs.append(agent_b)
143+
agent_b.handoffs.append(agent_a)
144+
145+
nodes = get_all_nodes(agent_a)
146+
edges = get_all_edges(agent_a)
147+
148+
assert nodes.count('"A" [label="A"') == 1
149+
assert nodes.count('"B" [label="B"') == 1
150+
assert '"A" -> "B"' in edges
151+
assert '"B" -> "A"' in edges

0 commit comments

Comments
 (0)