4
4
from IPython .display import display
5
5
6
6
from pytensor .graph import FunctionGraph , Variable , rewrite_graph
7
- from pytensor .graph .features import FullHistory
7
+ from pytensor .graph .features import AlreadyThere , FullHistory
8
8
9
9
10
10
class CodeBlockWidget (anywidget .AnyWidget ):
@@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget):
45
45
46
46
class InteractiveRewrite :
47
47
"""
48
- A class that wraps a graph history object with interactive widgets
49
- to navigate through history and display the graph at each step.
50
-
51
- Includes an option to display the reason for the last change.
48
+ Visualize a graph history through a series of rewrites.
52
49
"""
53
50
54
- def __init__ (self , fg , display_reason = True ):
51
+ def __init__ (
52
+ self ,
53
+ fg ,
54
+ display_reason = True ,
55
+ rewrite_options : dict | None = None ,
56
+ dprint_options : dict | None = None ,
57
+ ):
55
58
"""
56
- Initialize with a history object that has a goto method
57
- and tracks a FunctionGraph.
58
-
59
59
Parameters:
60
60
-----------
61
61
fg : FunctionGraph (or Variables)
62
62
The function graph to track
63
63
display_reason : bool, optional
64
64
Whether to display the reason for each rewrite
65
+ rewrite_options : dict, optional
66
+ Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)}
67
+ print_options : dict, optional
68
+ Print options passed to `debugprint` used to generate the text representation of the graph.
69
+ Useful options are {'print_shape': True, 'print_op_info': True}
65
70
"""
71
+ self .dprint_options = dprint_options or {}
72
+ self .rewrite_options = rewrite_options or dict (
73
+ include = ("fast_run" ,), exclude = ("inplace" ,)
74
+ )
66
75
self .history = FullHistory (callback = self ._history_callback )
67
76
if not isinstance (fg , FunctionGraph ):
68
77
outs = [fg ] if isinstance (fg , Variable ) else fg
69
78
fg = FunctionGraph (outputs = outs )
70
- fg .attach_feature (self .history )
79
+ try :
80
+ fg .attach_feature (self .history )
81
+ except AlreadyThere :
82
+ self .history .end ()
71
83
72
84
self .updating_from_callback = False # Flag to prevent recursion
73
85
self .code_widget = CodeBlockWidget (content = "" )
@@ -163,7 +175,7 @@ def _update_display(self):
163
175
reason = ""
164
176
else :
165
177
reason = self .history .fw [self .history .pointer ].reason
166
- reason = getattr (reason , "name" , str (reason ) )
178
+ reason = getattr (reason , "name" , None ) or str (reason )
167
179
168
180
self .reason_label .value = f"""
169
181
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
@@ -172,7 +184,9 @@ def _update_display(self):
172
184
"""
173
185
174
186
# Update the graph display
175
- self .code_widget .content = self .history .fg .dprint (file = "str" )
187
+ self .code_widget .content = self .history .fg .dprint (
188
+ file = "str" , ** self .dprint_options
189
+ )
176
190
177
191
# Update slider range if history length has changed
178
192
history_len = len (self .history .fw ) + 1
@@ -189,14 +203,13 @@ def _update_display(self):
189
203
f"History: { self .history .pointer + 1 } /{ history_len - 1 } "
190
204
)
191
205
192
- def rewrite (self , * args , include = ( "fast_run" ,), exclude = ( "inplace" ,), ** kwargs ):
206
+ def rewrite (self , * args , ** kwargs ):
193
207
"""Apply rewrites to the current graph"""
194
208
rewrite_graph (
195
209
self .history .fg ,
196
210
* args ,
197
- include = include ,
198
- exclude = exclude ,
199
211
** kwargs ,
212
+ ** self .rewrite_options ,
200
213
clone = False ,
201
214
)
202
215
self ._update_display ()
0 commit comments