1
- from typing import Optional
1
+ from __future__ import annotations
2
2
3
3
import graphviz # type: ignore
4
4
@@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
31
31
return "" .join (parts )
32
32
33
33
34
- def get_all_nodes (agent : Agent , parent : Optional [Agent ] = None ) -> str :
34
+ def get_all_nodes (
35
+ agent : Agent , parent : Agent | None = None , visited : set [str ] | None = None
36
+ ) -> str :
35
37
"""
36
38
Recursively generates the nodes for the given agent and its handoffs in DOT format.
37
39
@@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
41
43
Returns:
42
44
str: The DOT format string representing the nodes.
43
45
"""
46
+ if visited is None :
47
+ visited = set ()
48
+ if agent .name in visited :
49
+ return ""
50
+ visited .add (agent .name )
51
+
44
52
parts = []
45
53
46
54
# Start and end the graph
47
- parts .append (
48
- '"__start__" [label="__start__", shape=ellipse, style=filled, '
49
- "fillcolor=lightblue, width=0.5, height=0.3];"
50
- '"__end__" [label="__end__", shape=ellipse, style=filled, '
51
- "fillcolor=lightblue, width=0.5, height=0.3];"
52
- )
53
- # Ensure parent agent node is colored
54
55
if not parent :
56
+ parts .append (
57
+ '"__start__" [label="__start__", shape=ellipse, style=filled, '
58
+ "fillcolor=lightblue, width=0.5, height=0.3];"
59
+ '"__end__" [label="__end__", shape=ellipse, style=filled, '
60
+ "fillcolor=lightblue, width=0.5, height=0.3];"
61
+ )
62
+ # Ensure parent agent node is colored
55
63
parts .append (
56
64
f'"{ agent .name } " [label="{ agent .name } ", shape=box, style=filled, '
57
65
"fillcolor=lightyellow, width=1.5, height=0.8];"
@@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
71
79
f"fillcolor=lightyellow, width=1.5, height=0.8];"
72
80
)
73
81
if isinstance (handoff , Agent ):
74
- parts .append (
75
- f'"{ handoff .name } " [label="{ handoff .name } ", '
76
- f"shape=box, style=filled, style=rounded, "
77
- f"fillcolor=lightyellow, width=1.5, height=0.8];"
78
- )
79
- parts .append (get_all_nodes (handoff ))
82
+ if handoff .name not in visited :
83
+ parts .append (
84
+ f'"{ handoff .name } " [label="{ handoff .name } ", '
85
+ f"shape=box, style=filled, style=rounded, "
86
+ f"fillcolor=lightyellow, width=1.5, height=0.8];"
87
+ )
88
+ parts .append (get_all_nodes (handoff , agent , visited ))
80
89
81
90
return "" .join (parts )
82
91
83
92
84
- def get_all_edges (agent : Agent , parent : Optional [Agent ] = None ) -> str :
93
+ def get_all_edges (
94
+ agent : Agent , parent : Agent | None = None , visited : set [str ] | None = None
95
+ ) -> str :
85
96
"""
86
97
Recursively generates the edges for the given agent and its handoffs in DOT format.
87
98
@@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
92
103
Returns:
93
104
str: The DOT format string representing the edges.
94
105
"""
106
+ if visited is None :
107
+ visited = set ()
108
+ if agent .name in visited :
109
+ return ""
110
+ visited .add (agent .name )
111
+
95
112
parts = []
96
113
97
114
if not parent :
@@ -109,15 +126,15 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
109
126
if isinstance (handoff , Agent ):
110
127
parts .append (f"""
111
128
"{ agent .name } " -> "{ handoff .name } ";""" )
112
- parts .append (get_all_edges (handoff , agent ))
129
+ parts .append (get_all_edges (handoff , agent , visited ))
113
130
114
131
if not agent .handoffs and not isinstance (agent , Tool ): # type: ignore
115
132
parts .append (f'"{ agent .name } " -> "__end__";' )
116
133
117
134
return "" .join (parts )
118
135
119
136
120
- def draw_graph (agent : Agent , filename : Optional [ str ] = None ) -> graphviz .Source :
137
+ def draw_graph (agent : Agent , filename : str | None = None ) -> graphviz .Source :
121
138
"""
122
139
Draws the graph for the given agent and optionally saves it as a PNG file.
123
140
0 commit comments