Skip to content

Commit b3e2f9b

Browse files
committed
Add interactive optimization mode
1 parent b8831aa commit b3e2f9b

File tree

3 files changed

+56
-3
lines changed

3 files changed

+56
-3
lines changed

pytensor/configdefaults.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,22 @@ def add_compile_configvars():
515515
in_c_key=False,
516516
)
517517

518+
config.add(
519+
"optimizer_interactive",
520+
"If True, we interrupt after every optimization being applied and display how the graph changed",
521+
BoolParam(False),
522+
in_c_key=False,
523+
)
524+
525+
config.add(
526+
"optimizer_interactive_skip_rewrites",
527+
(
528+
"Do not interrupt after changes from optimizers with these names. Separate names with ',"
529+
),
530+
StrParam(""),
531+
in_c_key=False,
532+
)
533+
518534
config.add(
519535
"on_opt_error",
520536
(

pytensor/graph/features.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
import warnings
55
from collections import OrderedDict
6+
from difflib import Differ
67
from functools import partial
78
from io import StringIO
89

@@ -563,8 +564,19 @@ def replace_all_validate(
563564
):
564565
chk = fgraph.checkpoint()
565566

567+
interactive = config.optimizer_interactive
568+
566569
if verbose is None:
567-
verbose = config.optimizer_verbose
570+
verbose = config.optimizer_verbose or interactive
571+
572+
if interactive:
573+
differ = Differ()
574+
bef = pytensor.dprint(
575+
fgraph, file="str", print_type=True, id_type="", print_topo_order=False
576+
)
577+
skip_rewrites = config.optimizer_interactive_skip_rewrites.replace(
578+
" ", ""
579+
).split(",")
568580

569581
for r, new_r in replacements:
570582
try:
@@ -611,6 +623,22 @@ def replace_all_validate(
611623
print(
612624
f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}"
613625
)
626+
if interactive and str(reason) not in skip_rewrites:
627+
aft = pytensor.dprint(
628+
fgraph,
629+
file="str",
630+
print_type=True,
631+
id_type="",
632+
print_topo_order=False,
633+
)
634+
if bef != aft:
635+
diff = list(
636+
differ.compare(
637+
bef.splitlines(keepends=True), aft.splitlines(keepends=True)
638+
)
639+
)
640+
sys.stdout.writelines(diff)
641+
input("Press any key to continue")
614642

615643
# The return is needed by replace_all_validate_remove
616644
return chk

pytensor/printing.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def debugprint(
119119
print_destroy_map: bool = False,
120120
print_view_map: bool = False,
121121
print_fgraph_inputs: bool = False,
122+
print_topo_order: bool = True,
122123
) -> Union[str, TextIO]:
123124
r"""Print a graph as text.
124125
@@ -175,6 +176,8 @@ def debugprint(
175176
Whether to print the `view_map`\s of printed objects
176177
print_fgraph_inputs
177178
Print the inputs of `FunctionGraph`\s.
179+
print_topo_order
180+
Whether to print the toposort ordering of nodes
178181
179182
Returns
180183
-------
@@ -231,7 +234,10 @@ def debugprint(
231234
else:
232235
storage_maps.extend([None for item in obj.maker.fgraph.outputs])
233236
topo = obj.maker.fgraph.toposort()
234-
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
237+
if print_topo_order:
238+
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
239+
else:
240+
topo_orders.extend([None for item in obj.maker.fgraph.outputs])
235241
elif isinstance(obj, FunctionGraph):
236242
if print_fgraph_inputs:
237243
inputs_to_print.extend(obj.inputs)
@@ -241,7 +247,10 @@ def debugprint(
241247
[getattr(obj, "storage_map", None) for item in obj.outputs]
242248
)
243249
topo = obj.toposort()
244-
topo_orders.extend([topo for item in obj.outputs])
250+
if print_topo_order:
251+
topo_orders.extend([topo for item in obj.outputs])
252+
else:
253+
topo_orders.extend([None for item in obj.outputs])
245254
elif isinstance(obj, (int, float, np.ndarray)):
246255
print(obj, file=_file)
247256
elif isinstance(obj, (In, Out)):

0 commit comments

Comments
 (0)