Skip to content

Commit e3323b2

Browse files
committed
New graph visualization tools
1 parent 8b1d326 commit e3323b2

File tree

1 file changed

+142
-0
lines changed

1 file changed

+142
-0
lines changed

data_prototype/introspection.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
import graphlib
5+
from pprint import pformat
6+
7+
import matplotlib.pyplot as plt
8+
9+
from .conversion_edge import Edge, Graph
10+
from .description import Desc
11+
12+
13+
@dataclass
14+
class VisNode:
15+
keys: list[str]
16+
coordinates: list[str]
17+
parents: list[VisNode] = field(default_factory=list)
18+
children: list[VisNode] = field(default_factory=list)
19+
x: int = 0
20+
y: int = 0
21+
22+
def __eq__(self, other):
23+
return self.keys == other.keys and self.coordinates == other.coordinates
24+
25+
def format(self):
26+
return pformat({k: v for k, v in zip(self.keys, self.coordinates)}, width=20)
27+
28+
29+
@dataclass
30+
class VisEdge:
31+
name: str
32+
parent: VisNode
33+
child: VisNode
34+
35+
36+
def _position_subgraph(
37+
subgraph: tuple(set[str], list[Edge])
38+
) -> tuple[list[VisNode], list[VisEdge]]:
39+
# Build graph
40+
nodes: list[VisNode] = []
41+
edges: list[VisEdge] = []
42+
43+
q: list[dict[str, Desc]] = [e.input for e in subgraph[1]]
44+
explored: set[tuple[tuple[str, str], ...]] = set()
45+
explored.add(tuple(sorted(((k, v.coordinates) for k, v in q[0].items()))))
46+
47+
for e in subgraph[1]:
48+
nodes.append(
49+
VisNode(list(e.input.keys()), [x.coordinates for x in e.input.values()])
50+
)
51+
52+
while q:
53+
n = q.pop()
54+
vn = VisNode(list(n.keys()), [x.coordinates for x in n.values()])
55+
for nn in nodes:
56+
if vn == nn:
57+
vn = nn
58+
59+
for e in subgraph[1]:
60+
# Shortcut default edges appearing all over the place
61+
if e.input == {} and vn.keys != []:
62+
continue
63+
if Desc.compatible(n, e.input):
64+
w = e.output
65+
vw = VisNode(list(w.keys()), [x.coordinates for x in w.values()])
66+
for nn in nodes:
67+
if vw == nn:
68+
vw = nn
69+
70+
if vw not in nodes:
71+
nodes.append(vw)
72+
explored.add(
73+
tuple(sorted(((k, v.coordinates) for k, v in w.items())))
74+
)
75+
q.append(w)
76+
if vw != vn:
77+
edges.append(VisEdge(e.name, vn, vw))
78+
vw.parents.append(vn)
79+
vn.children.append(vw)
80+
81+
# adapt graph for total ording
82+
def hash_node(n):
83+
return (tuple(n.keys), tuple(n.coordinates))
84+
85+
to_graph = {hash_node(n): set() for n in nodes}
86+
for e in edges:
87+
to_graph[hash_node(e.child)] |= {hash_node(e.parent)}
88+
89+
# evaluate total ordering
90+
topological_sorter = graphlib.TopologicalSorter(to_graph)
91+
92+
# position horizontally by 1+ highest parent, vertically by 1+ highest sibling
93+
def get_node(n):
94+
for node in nodes:
95+
if n[0] == tuple(node.keys) and n[1] == tuple(node.coordinates):
96+
return node
97+
98+
static_order = list(topological_sorter.static_order())
99+
100+
for n in static_order:
101+
node = get_node(n)
102+
if node.parents != []:
103+
node.y = max(p.y for p in node.parents) + 1
104+
x_pos = {}
105+
for n in static_order:
106+
node = get_node(n)
107+
if node.y in x_pos:
108+
node.x = x_pos[node.y]
109+
x_pos[node.y] += 1.25
110+
else:
111+
x_pos[node.y] = 1.25
112+
113+
return nodes, edges
114+
115+
116+
def draw_graph(graph: Graph, ax=None):
117+
if ax is None:
118+
fig, ax = plt.subplots()
119+
120+
origin_y = 0
121+
122+
for sg in graph._subgraphs:
123+
nodes, edges = _position_subgraph(sg)
124+
# Draw nodes
125+
for node in nodes:
126+
ax.annotate(
127+
node.format(), (node.x, node.y + origin_y), bbox={"boxstyle": "round"}
128+
)
129+
130+
# Draw edges
131+
for edge in edges:
132+
ax.annotate(
133+
"",
134+
(edge.child.x, edge.child.y + origin_y),
135+
(edge.parent.x, edge.parent.y + origin_y),
136+
arrowprops={"arrowstyle": "->"},
137+
)
138+
mid_x = (edge.child.x + edge.parent.x) / 2
139+
mid_y = (edge.child.y + edge.parent.y) / 2
140+
ax.text(mid_x, mid_y + origin_y, edge.name)
141+
142+
origin_y += max(node.y for node in nodes) + 1

0 commit comments

Comments
 (0)