Skip to content

Commit cca20eb

Browse files
committed
QoL improvements to InteractiveRewrite widget
1 parent 3876e73 commit cca20eb

File tree

1 file changed

+28
-15
lines changed

1 file changed

+28
-15
lines changed

pytensor/ipython.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from IPython.display import display
55

66
from pytensor.graph import FunctionGraph, Variable, rewrite_graph
7-
from pytensor.graph.features import FullHistory
7+
from pytensor.graph.features import AlreadyThere, FullHistory
88

99

1010
class CodeBlockWidget(anywidget.AnyWidget):
@@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget):
4545

4646
class InteractiveRewrite:
4747
"""
48-
A class that wraps a graph history object with interactive widgets
49-
to navigate through history and display the graph at each step.
50-
51-
Includes an option to display the reason for the last change.
48+
Visualize a graph history through a series of rewrites.
5249
"""
5350

54-
def __init__(self, fg, display_reason=True):
51+
def __init__(
52+
self,
53+
fg,
54+
display_reason=True,
55+
rewrite_options: dict | None = None,
56+
dprint_options: dict | None = None,
57+
):
5558
"""
56-
Initialize with a history object that has a goto method
57-
and tracks a FunctionGraph.
58-
5959
Parameters:
6060
-----------
6161
fg : FunctionGraph (or Variables)
6262
The function graph to track
6363
display_reason : bool, optional
6464
Whether to display the reason for each rewrite
65+
rewrite_options : dict, optional
66+
Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)}
67+
print_options : dict, optional
68+
Print options passed to `debugprint` used to generate the text representation of the graph.
69+
Useful options are {'print_shape': True, 'print_op_info': True}
6570
"""
71+
self.dprint_options = dprint_options or {}
72+
self.rewrite_options = rewrite_options or dict(
73+
include=("fast_run",), exclude=("inplace",)
74+
)
6675
self.history = FullHistory(callback=self._history_callback)
6776
if not isinstance(fg, FunctionGraph):
6877
outs = [fg] if isinstance(fg, Variable) else fg
6978
fg = FunctionGraph(outputs=outs)
70-
fg.attach_feature(self.history)
79+
try:
80+
fg.attach_feature(self.history)
81+
except AlreadyThere:
82+
self.history.end()
7183

7284
self.updating_from_callback = False # Flag to prevent recursion
7385
self.code_widget = CodeBlockWidget(content="")
@@ -163,7 +175,7 @@ def _update_display(self):
163175
reason = ""
164176
else:
165177
reason = self.history.fw[self.history.pointer].reason
166-
reason = getattr(reason, "name", str(reason))
178+
reason = getattr(reason, "name", None) or str(reason)
167179

168180
self.reason_label.value = f"""
169181
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
@@ -172,7 +184,9 @@ def _update_display(self):
172184
"""
173185

174186
# Update the graph display
175-
self.code_widget.content = self.history.fg.dprint(file="str")
187+
self.code_widget.content = self.history.fg.dprint(
188+
file="str", **self.dprint_options
189+
)
176190

177191
# Update slider range if history length has changed
178192
history_len = len(self.history.fw) + 1
@@ -189,14 +203,13 @@ def _update_display(self):
189203
f"History: {self.history.pointer + 1}/{history_len - 1}"
190204
)
191205

192-
def rewrite(self, *args, include=("fast_run",), exclude=("inplace",), **kwargs):
206+
def rewrite(self, *args, **kwargs):
193207
"""Apply rewrites to the current graph"""
194208
rewrite_graph(
195209
self.history.fg,
196210
*args,
197-
include=include,
198-
exclude=exclude,
199211
**kwargs,
212+
**self.rewrite_options,
200213
clone=False,
201214
)
202215
self._update_display()

0 commit comments

Comments
 (0)