diff --git a/pytensor/graph/rewriting/utils.py b/pytensor/graph/rewriting/utils.py index e1234c19e8..d9c1f39bfd 100644 --- a/pytensor/graph/rewriting/utils.py +++ b/pytensor/graph/rewriting/utils.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Optional, Union, cast import pytensor +from pytensor import compile from pytensor.graph.basic import ( Apply, Variable, @@ -11,7 +12,8 @@ vars_between, ) from pytensor.graph.fg import FunctionGraph -from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.graph.rewriting.basic import NodeRewriter +from pytensor.graph.rewriting.db import RewriteDatabase, RewriteDatabaseQuery if TYPE_CHECKING: @@ -238,3 +240,21 @@ def get_clients_at_depth( else: assert var.owner is not None yield var.owner + + +def register_canonicalize( + node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs +): + if isinstance(node_rewriter, str): + + def register(inner_rewriter: Union[RewriteDatabase, NodeRewriter]): + return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) + + return register + else: + name = kwargs.pop("name", None) or getattr(node_rewriter, "__name__", None) + + compile.optdb["canonicalize"].register( + name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs + ) + return node_rewriter diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index a2a4ccc2f7..c9be35b1d0 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -41,6 +41,7 @@ node_rewriter, ) from pytensor.graph.rewriting.db import RewriteDatabase +from pytensor.graph.rewriting.utils import register_canonicalize from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.tensor.basic import ( Alloc, @@ -153,23 +154,6 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): return node_rewriter -def register_canonicalize( - node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs -): - if isinstance(node_rewriter, str): - - def register(inner_rewriter: Union[RewriteDatabase, Rewriter]): - return register_canonicalize(inner_rewriter, node_rewriter, *tags, **kwargs) - - return register - else: - name = kwargs.pop("name", None) or node_rewriter.__name__ - compile.optdb["canonicalize"].register( - name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs - ) - return node_rewriter - - def register_stabilize( node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs ):