diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index c20d4bdcff..1347056bac 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -515,6 +515,22 @@ def add_compile_configvars(): in_c_key=False, ) + config.add( + "optimizer_interactive", + "If True, we interrupt after every optimization being applied and display how the graph changed", + BoolParam(False), + in_c_key=False, + ) + + config.add( + "optimizer_interactive_skip_rewrites", + ( + "Do not interrupt after changes from optimizers with these names. Separate names with '," + ), + StrParam(""), + in_c_key=False, + ) + config.add( "on_opt_error", ( diff --git a/pytensor/graph/features.py b/pytensor/graph/features.py index 013e127aaf..ce2fb791f6 100644 --- a/pytensor/graph/features.py +++ b/pytensor/graph/features.py @@ -3,6 +3,7 @@ import time import warnings from collections import OrderedDict +from difflib import Differ from functools import partial from io import StringIO @@ -563,8 +564,19 @@ def replace_all_validate( ): chk = fgraph.checkpoint() + interactive = config.optimizer_interactive + if verbose is None: - verbose = config.optimizer_verbose + verbose = config.optimizer_verbose or interactive + + if interactive: + differ = Differ() + bef = pytensor.dprint( + fgraph, file="str", print_type=True, id_type="", print_topo_order=False + ) + skip_rewrites = config.optimizer_interactive_skip_rewrites.replace( + " ", "" + ).split(",") for r, new_r in replacements: try: @@ -611,6 +623,22 @@ def replace_all_validate( print( f"rewriting: rewrite {reason} replaces {r} of {r.owner} with {new_r} of {new_r.owner}" ) + if interactive and str(reason) not in skip_rewrites: + aft = pytensor.dprint( + fgraph, + file="str", + print_type=True, + id_type="", + print_topo_order=False, + ) + if bef != aft: + diff = list( + differ.compare( + bef.splitlines(keepends=True), aft.splitlines(keepends=True) + ) + ) + sys.stdout.writelines(diff) + input("Press any key to continue") # The return is needed by replace_all_validate_remove return chk diff --git a/pytensor/printing.py b/pytensor/printing.py index e7f9738426..525c085613 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -119,6 +119,7 @@ def debugprint( print_destroy_map: bool = False, print_view_map: bool = False, print_fgraph_inputs: bool = False, + print_topo_order: bool = True, ) -> Union[str, TextIO]: r"""Print a graph as text. @@ -175,6 +176,8 @@ def debugprint( Whether to print the `view_map`\s of printed objects print_fgraph_inputs Print the inputs of `FunctionGraph`\s. + print_topo_order + Whether to print the toposort ordering of nodes Returns ------- @@ -231,7 +234,10 @@ def debugprint( else: storage_maps.extend([None for item in obj.maker.fgraph.outputs]) topo = obj.maker.fgraph.toposort() - topo_orders.extend([topo for item in obj.maker.fgraph.outputs]) + if print_topo_order: + topo_orders.extend([topo for item in obj.maker.fgraph.outputs]) + else: + topo_orders.extend([None for item in obj.maker.fgraph.outputs]) elif isinstance(obj, FunctionGraph): if print_fgraph_inputs: inputs_to_print.extend(obj.inputs) @@ -241,7 +247,10 @@ def debugprint( [getattr(obj, "storage_map", None) for item in obj.outputs] ) topo = obj.toposort() - topo_orders.extend([topo for item in obj.outputs]) + if print_topo_order: + topo_orders.extend([topo for item in obj.outputs]) + else: + topo_orders.extend([None for item in obj.outputs]) elif isinstance(obj, (int, float, np.ndarray)): print(obj, file=_file) elif isinstance(obj, (In, Out)):