diff --git a/pytensor/_version.py b/pytensor/_version.py
index a61f4f2658..0942d41cf8 100644
--- a/pytensor/_version.py
+++ b/pytensor/_version.py
@@ -92,7 +92,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
                 env=env,
                 stdout=subprocess.PIPE,
                 stderr=(subprocess.PIPE if hide_stderr else None),
-                **popen_kwargs
+                **popen_kwargs,
             )
             break
         except OSError:
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 8aecf1a902..4ce03f253d 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -250,11 +250,34 @@ def apply(self, fgraph):
 # misc special cases for speed that break canonicalization
 optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
 
+# Turn tensor operations into scalar operations where possible.
+# This is currently marked as numba-only, but that could be changed
+# in the future.
+optdb.register("scalarize", EquilibriumDB(), "numba_only", position=3.1)
+
 # misc special cases for speed that are dependent on the device.
 optdb.register(
     "specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
 )  # must be after gpu stuff at 48.5
 
+# Must be before add_destroy_handler
+optdb.register(
+    "elemwise_fusion",
+    SequenceDB(),
+    "fast_run",
+    "fusion",
+    "local_elemwise_fusion",
+    position=48.7,
+)
+
+optdb.register(
+    "post_fusion",
+    EquilibriumDB(),
+    "fast_run",
+    "fast_compile",
+    position=48.8,
+)
+
 # especially constant merge
 optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
@@ -441,19 +464,44 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 # FunctionMaker, the Mode will be taken from this dictionary using the
 # string as the key
 # Use VM_linker to allow lazy evaluation by default.
-FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
+FAST_COMPILE = Mode(
+    VMLinker(use_cloop=False, c_thunks=False),
+    RewriteDatabaseQuery(
+        include=["fast_compile"],
+        exclude=["numba_only"],
+    ),
+)
 if config.cxx:
-    FAST_RUN = Mode("cvm", "fast_run")
+    FAST_RUN = Mode(
+        "cvm",
+        RewriteDatabaseQuery(
+            include=["fast_run"],
+            exclude=["numba_only"],
+        ),
+    )
 else:
-    FAST_RUN = Mode("vm", "fast_run")
+    FAST_RUN = Mode(
+        "vm",
+        RewriteDatabaseQuery(
+            include=["fast_run"],
+            exclude=["numba_only"],
+        ),
+    )
 
 JAX = Mode(
     JAXLinker(),
-    RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "jax"],
+        exclude=["cxx_only", "BlasOpt", "numba_only"],
+    ),
 )
+
 NUMBA = Mode(
     NumbaLinker(),
-    RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "numba_only"],
+        exclude=["cxx_only", "BlasOpt"],
+    ),
 )
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 87639afccc..1e9855c8e1 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -186,6 +186,23 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
     return node_rewriter
 
 
+def register_scalarize(
+    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
+):
+    if isinstance(node_rewriter, str):
+
+        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
+            return register_scalarize(inner_rewriter, node_rewriter, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or node_rewriter.__name__
+        compile.optdb["scalarize"].register(
+            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
+        )
+        return node_rewriter
+
+
 def register_uncanonicalize(
     node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
 ):
@@ -226,30 +243,36 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
 
 @register_canonicalize
 @register_specialize
+@register_scalarize
 @node_rewriter([TensorFromScalar])
 def local_tensor_scalar_tensor(fgraph, node):
     """tensor_from_scalar(scalar_from_tensor(x)) -> x"""
-    if isinstance(node.op, TensorFromScalar):
-        s = node.inputs[0]
-        if s.owner and isinstance(s.owner.op, ScalarFromTensor):
-            t = s.owner.inputs[0]
+    s = node.inputs[0]
+    if s.owner and isinstance(s.owner.op, ScalarFromTensor):
+        t = s.owner.inputs[0]
 
-            # We don't need to copy over any stack traces here
-            return [t]
+        # We don't need to copy over any stack traces here
+        return [t]
 
 
 @register_canonicalize
 @register_specialize
+@register_scalarize
 @node_rewriter([ScalarFromTensor])
 def local_scalar_tensor_scalar(fgraph, node):
-    """scalar_from_tensor(tensor_from_scalar(x)) -> x"""
-    if isinstance(node.op, ScalarFromTensor):
-        t = node.inputs[0]
-        if t.owner and isinstance(t.owner.op, TensorFromScalar):
-            s = t.owner.inputs[0]
-
-            # We don't need to copy over any stack traces here
-            return [s]
+    """scalar_from_tensor(tensor_from_scalar(x)) -> x
+
+    and scalar_from_tensor(TensorConstant(x)) -> x
+    """
+    t = node.inputs[0]
+    if t.owner and isinstance(t.owner.op, TensorFromScalar):
+        s = t.owner.inputs[0]
+
+        # We don't need to copy over any stack traces here
+        return [s]
+    if isinstance(t, TensorConstant):
+        assert t.ndim == 0
+        return [aes.constant(t.value.item(), t.name, t.dtype)]
 
 
 @register_specialize("local_alloc_elemwise")
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index e9952a3908..813bfac2b4 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -18,8 +18,8 @@
     in2out,
     node_rewriter,
 )
-from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
+from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -380,6 +380,100 @@ def is_dimshuffle_useless(new_order, input):
     return is_useless
 
 
+@node_rewriter([Elemwise])
+def elemwise_to_scalar(fgraph, node):
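+    """Rewrite an Elemwise whose inputs are all 0-d tensors as its scalar op.
+
+    E.g. ``Elemwise{add}(x, y)`` on 0-d inputs becomes roughly
+    ``tensor_from_scalar(add(scalar_from_tensor(x), scalar_from_tensor(y)))``.
+    """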
+    op = node.op
+    if not all(input.ndim == 0 for input in node.inputs):
+        return False
+
+    scalars = [aes.as_scalar(input) for input in node.inputs]
+
+    # TODO: copy stack traces, something like
+    # copy_stack_trace(node.outputs[0], new_res)
+    return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]
+
+
+compile.optdb["scalarize"].register(
+    "local_elemwise_to_scalar",
+    elemwise_to_scalar,
+    "fast_run",
+    "fast_compile",
+    "numba_only",
+)
+
+
+@node_rewriter([Elemwise])
+def push_elemwise_constants(fgraph, node):
+    """Push scalar constants in the inputs of an Elemwise into the
+    inner graph of its Composite scalar op.
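+
+    For instance ``Composite{x * y}(t, constant(2.0))`` becomes
+    ``Composite{x * 2.0}(t)``, so the constant is evaluated inside the
+    compiled scalar function instead of being passed as a tensor input.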
+ """ + op = node.op + if any(op.inplace_pattern): + return False + + if not isinstance(node.op.scalar_op, aes.Composite): + return False + + def is_constant_scalar(x): + return isinstance(x, TensorConstant) and all(x.broadcastable) + + push_idxs = [] + push_values = [] + keep_values = [] + for i, input in enumerate(node.inputs): + if is_constant_scalar(input): + push_idxs.append(i) + val = input.value + push_values.append(aes.constant(val.item(), dtype=val.dtype)) + elif ( + input.owner + and isinstance(input.owner.op, DimShuffle) + and is_constant_scalar(input.owner.inputs[0]) + ): + push_idxs.append(i) + val = input.owner.inputs[0].value + push_values.append(aes.constant(val.item(), dtype=val.dtype)) + else: + keep_values.append(input) + + if not push_values: + return False + + inner_graph = node.op.scalar_op.fgraph + to_replace = [input for i, input in enumerate(inner_graph.inputs) if i in push_idxs] + + # Clone the inner graph, it might be used somewhere else + inner_graph, mapping = inner_graph.clone_get_equiv() + inner_graph.replace_all( + (mapping[old], new) for old, new in zip(to_replace, push_values) + ) + + new_inputs = [ + input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs + ] + return ( + Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs)) + .make_node(*keep_values) + .outputs + ) + + +compile.optdb["post_fusion"].register( + "push_elemwise_constants", + push_elemwise_constants, + "numba_only", +) + + @register_canonicalize @register_specialize @node_rewriter([DimShuffle]) @@ -898,34 +983,13 @@ def print_profile(cls, stream, prof, level=0): print(blanc, " time_toposort", prof[7], file=stream) -if config.tensor__local_elemwise_fusion: - # Must be after gpu(48.5) and before AddDestroyHandler(49.5) - fuse_seqopt = SequenceDB() - fuse_seqopt.register( - "composite_elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), - "fast_run", - "fusion", - position=1, - ) - compile.optdb.register( # type: ignore - "elemwise_fusion", - fuse_seqopt, - "fast_run", - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) -else: - compile.optdb.register( # type: ignore - "elemwise_fusion", - FusionOptimizer(local_elemwise_fusion), - "fusion", - "local_elemwise_fusion", - "FusionOptimizer", - position=49, - ) +compile.optdb["elemwise_fusion"].register( # type: ignore + "composite_elemwise_fusion", + FusionOptimizer(local_elemwise_fusion), + "fast_run", + "fusion", + position=1, +) @register_canonicalize diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 644b9f56c0..ca4c6a4fcf 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -8,6 +8,7 @@ import pytensor.scalar.basic as aes import pytensor.scalar.math as aes_math +from pytensor import compile from pytensor.graph.basic import Constant, Variable from pytensor.graph.rewriting.basic import ( NodeRewriter, @@ -85,13 +86,14 @@ encompasses_broadcastable, local_fill_sink, register_canonicalize, + register_scalarize, register_specialize, register_specialize_device, register_stabilize, register_uncanonicalize, register_useless, ) -from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt +from pytensor.tensor.rewriting.elemwise import FusionOptimizer from pytensor.tensor.shape import Shape, Shape_i from pytensor.tensor.subtensor import Subtensor from pytensor.tensor.type import ( @@ -1567,6 +1569,18 @@ def local_op_of_op(fgraph, node): return [combined(node_inps.owner.inputs[0])] +@register_scalarize 
+    (array,) = node.inputs
+    if not array.owner or not isinstance(array.owner.op, MakeVector):
+        return False
+
+    values = array.owner.inputs
+    summed = aes.add(*values)
+    return [as_tensor_variable(summed)]
+
+
 ALL_REDUCE = (
     [
         CAReduce,
@@ -2922,7 +2937,7 @@ def local_add_mul_fusion(fgraph, node):
     return [output]
 
 
-fuse_seqopt.register(
+compile.optdb["elemwise_fusion"].register(
     "local_add_mul_fusion",
     FusionOptimizer(local_add_mul_fusion),
     "fast_run",
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index a252d0f446..8b75e7bff5 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -469,8 +469,8 @@ def local_subtensor_lift(fgraph, node):
     return [rbcast_subt_x]
 
 
-@register_canonicalize
-@register_specialize
+@register_stabilize("cxx_only")
+@register_canonicalize("cxx_only")
 @node_rewriter([Subtensor])
 def local_subtensor_merge(fgraph, node):
     """
diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py
index 04bb3aefd8..4e0ad9850b 100644
--- a/tests/link/numba/test_scan.py
+++ b/tests/link/numba/test_scan.py
@@ -10,7 +10,6 @@
 from pytensor.scan.op import Scan
 from pytensor.scan.utils import until
 from pytensor.tensor import log, vector
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import RandomStream
 from tests import unittest_tools as utt
 from tests.link.numba.test_basic import compare_numba_and_py
@@ -437,8 +436,4 @@ def test_inner_graph_optimized():
         node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
     ]
     inner_scan_nodes = scan_node.op.fgraph.apply_nodes
-    assert len(inner_scan_nodes) == 1
-    (inner_scan_node,) = scan_node.op.fgraph.apply_nodes
-    assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
-        inner_scan_node.op.scalar_op, Log1p
-    )
+    assert any(isinstance(node.op, Log1p) for node in inner_scan_nodes)