From ccca97d3581e4fdaea254afe82b9d216686b87e4 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Fri, 9 Dec 2022 18:46:38 -0600
Subject: [PATCH 1/3] Create new rewrites for elemwise

There is no need for an Elemwise Op if all inputs have rank 0. And we
don't need to use scalar constants as inputs of the Elemwise; they can
be inputs of the scalar_op instead.

---
 pytensor/_version.py                   |   2 +-
 pytensor/compile/mode.py               |  15 ++-
 pytensor/tensor/rewriting/elemwise.py  | 130 +++++++++++++++++++------
 pytensor/tensor/rewriting/math.py      |   5 +-
 pytensor/tensor/rewriting/subtensor.py |   1 -
 5 files changed, 119 insertions(+), 34 deletions(-)

diff --git a/pytensor/_version.py b/pytensor/_version.py
index a61f4f2658..0942d41cf8 100644
--- a/pytensor/_version.py
+++ b/pytensor/_version.py
@@ -92,7 +92,7 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=
             env=env,
             stdout=subprocess.PIPE,
             stderr=(subprocess.PIPE if hide_stderr else None),
-            **popen_kwargs
+            **popen_kwargs,
         )
         break
     except OSError:
diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 8aecf1a902..2add4a321d 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -255,6 +255,16 @@ def apply(self, fgraph):
     "specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
 )  # must be after gpu stuff at 48.5
 
+# Must be before add_destroy_handler
+optdb.register(
+    "elemwise_fusion",
+    SequenceDB(),
+    "fast_run",
+    "fusion",
+    "local_elemwise_fusion",
+    position=49,
+)
+
 # especially constant merge
 optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
 
@@ -453,7 +463,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 )
 NUMBA = Mode(
     NumbaLinker(),
-    RewriteDatabaseQuery(include=["fast_run"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "fast_run_numba", "fast_compile_numba"],
+        exclude=["cxx_only", "BlasOpt"],
+    ),
 )
 
 
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index e9952a3908..7697b2e533 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -18,8 +18,8 @@
     in2out,
     node_rewriter,
 )
-from pytensor.graph.rewriting.db import SequenceDB
 from pytensor.graph.utils import InconsistencyError, MethodNotDefined, TestValueError
+from pytensor.tensor import as_tensor_variable
 from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
@@ -380,6 +380,99 @@ def is_dimshuffle_useless(new_order, input):
     return is_useless
 
 
+@node_rewriter([Elemwise])
+def local_elemwise_lift_scalars(fgraph, node):
+    op = node.op
+
+    if not isinstance(op, Elemwise):
+        return False
+
+    if not all(input.ndim == 0 for input in node.inputs):
+        return False
+
+    scalars = [aes.as_scalar(input) for input in node.inputs]
+
+    # TODO Something like
+    # copy_stack_trace(node.outputs[0], new_res)
+    return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]
+
+
+compile.optdb["specialize"].register(
+    "local_elemwise_lift_scalars",
+    local_elemwise_lift_scalars,
+    "fast_run_numba",
+    "fast_compile_numba",
+)
+
+
+@node_rewriter([Elemwise])
+def push_elemwise_constants(fgraph, node):
+    """Push constant scalar inputs of an Elemwise down to inputs of the
+    contained scalar op.
+    """
+    op = node.op
+
+    if not isinstance(op, Elemwise):
+        return False
+
+    if op.inplace_pattern:
+        return False
+
+    if not isinstance(node.op.scalar_op, aes.Composite):
+        return False
+
+    def is_constant_scalar(x):
+        return isinstance(x, TensorConstant) and all(x.broadcastable)
+
+    push_idxs = []
+    push_values = []
+    keep_values = []
+    for i, input in enumerate(node.inputs):
+        if is_constant_scalar(input):
+            push_idxs.append(i)
+            val = input.value
+            push_values.append(aes.constant(val.item(), dtype=val.dtype))
+        elif (
+            input.owner
+            and isinstance(input.owner.op, DimShuffle)
+            and is_constant_scalar(input.owner.inputs[0])
+        ):
+            push_idxs.append(i)
+            val = input.owner.inputs[0].value
+            push_values.append(aes.constant(val.item(), dtype=val.dtype))
+        else:
+            keep_values.append(input)
+
+    if not push_values:
+        return False
+
+    inner_graph = node.op.scalar_op.fgraph
+    to_replace = [input for i, input in enumerate(inner_graph.inputs) if i in push_idxs]
+
+    # Clone the inner graph; it might be used somewhere else
+    inner_graph, mapping = inner_graph.clone_get_equiv()
+    inner_graph.replace_all(
+        (mapping[old], new) for old, new in zip(to_replace, push_values)
+    )
+
+    new_inputs = [
+        input for i, input in enumerate(inner_graph.inputs) if i not in push_idxs
+    ]
+    return (
+        Elemwise(scalar_op=aes.Composite(new_inputs, inner_graph.outputs))
+        .make_node(*keep_values)
+        .outputs
+    )
+
+
+compile.optdb["specialize"].register(
+    "push_elemwise_constants",
+    push_elemwise_constants,
+    "fast_run_numba",
+    "fast_compile_numba",
+)
+
+
 @register_canonicalize
 @register_specialize
 @node_rewriter([DimShuffle])
@@ -898,34 +991,13 @@ def print_profile(cls, stream, prof, level=0):
             print(blanc, " time_toposort", prof[7], file=stream)
 
 
-if config.tensor__local_elemwise_fusion:
-    # Must be after gpu(48.5) and before AddDestroyHandler(49.5)
-    fuse_seqopt = SequenceDB()
-    fuse_seqopt.register(
-        "composite_elemwise_fusion",
-        FusionOptimizer(local_elemwise_fusion),
-        "fast_run",
-        "fusion",
-        position=1,
-    )
-    compile.optdb.register(  # type: ignore
-        "elemwise_fusion",
-        fuse_seqopt,
-        "fast_run",
-        "fusion",
-        "local_elemwise_fusion",
-        "FusionOptimizer",
-        position=49,
-    )
-else:
-    compile.optdb.register(  # type: ignore
-        "elemwise_fusion",
-        FusionOptimizer(local_elemwise_fusion),
-        "fusion",
-        "local_elemwise_fusion",
-        "FusionOptimizer",
-        position=49,
-    )
+compile.optdb["elemwise_fusion"].register(  # type: ignore
+    "composite_elemwise_fusion",
+    FusionOptimizer(local_elemwise_fusion),
+    "fast_run",
+    "fusion",
+    position=1,
+)
 
 
 @register_canonicalize
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index 644b9f56c0..cb14f51752 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -8,6 +8,7 @@
 
 import pytensor.scalar.basic as aes
 import pytensor.scalar.math as aes_math
+from pytensor import compile
 from pytensor.graph.basic import Constant, Variable
 from pytensor.graph.rewriting.basic import (
     NodeRewriter,
@@ -91,7 +92,7 @@
     register_uncanonicalize,
     register_useless,
 )
-from pytensor.tensor.rewriting.elemwise import FusionOptimizer, fuse_seqopt
+from pytensor.tensor.rewriting.elemwise import FusionOptimizer
 from pytensor.tensor.shape import Shape, Shape_i
 from pytensor.tensor.subtensor import Subtensor
 from pytensor.tensor.type import (
@@ -2922,7 +2923,7 @@ def local_add_mul_fusion(fgraph, node):
             return [output]
 
 
-fuse_seqopt.register(
+compile.optdb["elemwise_fusion"].register(
     "local_add_mul_fusion",
     FusionOptimizer(local_add_mul_fusion),
     "fast_run",
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index a252d0f446..9b3717b8f9 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -469,7 +469,6 @@ def local_subtensor_lift(fgraph, node):
         return [rbcast_subt_x]
 
 
-@register_canonicalize
 @register_specialize
 @node_rewriter([Subtensor])
 def local_subtensor_merge(fgraph, node):

From 178496542209ab21f46e5ea37d8ca0688341d5ea Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Sun, 11 Dec 2022 16:21:08 -0600
Subject: [PATCH 2/3] Move push_elemwise_constants to post_fusion pass

---
 pytensor/compile/mode.py              | 10 +++++++++-
 pytensor/tensor/rewriting/elemwise.py |  3 +--
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index 2add4a321d..a027317a81 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -262,7 +262,15 @@ def apply(self, fgraph):
     "fast_run",
     "fusion",
     "local_elemwise_fusion",
-    position=49,
+    position=48.7,
+)
+
+optdb.register(
+    "post_fusion",
+    EquilibriumDB(),
+    "fast_run",
+    "fast_compile",
+    position=48.8,
 )
 
 # especially constant merge
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index 7697b2e533..8db39eb02d 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -411,7 +411,6 @@ def push_elemwise_constants(fgraph, node):
     contained scalar op.
     """
     op = node.op
-
     if not isinstance(op, Elemwise):
         return False
 
@@ -465,7 +464,7 @@ def is_constant_scalar(x):
     )
 
 
-compile.optdb["specialize"].register(
+compile.optdb["post_fusion"].register(
     "push_elemwise_constants",
     push_elemwise_constants,
     "fast_run_numba",

From b0c4462a6a41d4d89f83db19d62e8e27e1056709 Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Mon, 12 Dec 2022 18:46:00 -0600
Subject: [PATCH 3/3] Create scalarize rewrite pass

---
 pytensor/compile/mode.py               | 37 ++++++++++++++++---
 pytensor/tensor/rewriting/basic.py     | 51 +++++++++++++++++++-------
 pytensor/tensor/rewriting/elemwise.py  | 23 ++++--------
 pytensor/tensor/rewriting/math.py      | 13 +++++++
 pytensor/tensor/rewriting/subtensor.py |  3 +-
 tests/link/numba/test_scan.py          |  7 +---
 6 files changed, 93 insertions(+), 41 deletions(-)

diff --git a/pytensor/compile/mode.py b/pytensor/compile/mode.py
index a027317a81..4ce03f253d 100644
--- a/pytensor/compile/mode.py
+++ b/pytensor/compile/mode.py
@@ -250,6 +250,11 @@ def apply(self, fgraph):
 # misc special cases for speed that break canonicalization
 optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
 
+# Turn tensor operations into scalar operations where possible.
+# This is currently marked as numba-only, but this could be changed
+# in the future.
+optdb.register("scalarize", EquilibriumDB(), "numba_only", position=3.1)
+
 # misc special cases for speed that are dependent on the device.
 optdb.register(
     "specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
@@ -459,20 +464,42 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 # FunctionMaker, the Mode will be taken from this dictionary using the
 # string as the key
 # Use VM_linker to allow lazy evaluation by default.
-FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
+FAST_COMPILE = Mode(
+    VMLinker(use_cloop=False, c_thunks=False),
+    RewriteDatabaseQuery(
+        include=["fast_compile"],
+        exclude=["numba_only"],
+    ),
+)
 if config.cxx:
-    FAST_RUN = Mode("cvm", "fast_run")
+    FAST_RUN = Mode(
+        "cvm",
+        RewriteDatabaseQuery(
+            include=["fast_run"],
+            exclude=["numba_only"],
+        ),
+    )
 else:
-    FAST_RUN = Mode("vm", "fast_run")
+    FAST_RUN = Mode(
+        "vm",
+        RewriteDatabaseQuery(
+            include=["fast_run"],
+            exclude=["numba_only"],
+        ),
+    )
 
 JAX = Mode(
     JAXLinker(),
-    RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
+    RewriteDatabaseQuery(
+        include=["fast_run", "jax"],
+        exclude=["cxx_only", "BlasOpt", "numba_only"],
+    ),
 )
+
 NUMBA = Mode(
     NumbaLinker(),
     RewriteDatabaseQuery(
-        include=["fast_run", "fast_run_numba", "fast_compile_numba"],
+        include=["fast_run", "numba_only"],
         exclude=["cxx_only", "BlasOpt"],
     ),
 )
diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py
index 87639afccc..1e9855c8e1 100644
--- a/pytensor/tensor/rewriting/basic.py
+++ b/pytensor/tensor/rewriting/basic.py
@@ -186,6 +186,23 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
         return node_rewriter
 
 
+def register_scalarize(
+    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
+):
+    if isinstance(node_rewriter, str):
+
+        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
+            return register_scalarize(inner_rewriter, node_rewriter, *tags, **kwargs)
+
+        return register
+    else:
+        name = kwargs.pop("name", None) or node_rewriter.__name__
+        compile.optdb["scalarize"].register(
+            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
+        )
+        return node_rewriter
+
+
 def register_uncanonicalize(
     node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
 ):
@@ -226,30 +243,36 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
 
 @register_canonicalize
 @register_specialize
+@register_scalarize
 @node_rewriter([TensorFromScalar])
 def local_tensor_scalar_tensor(fgraph, node):
     """tensor_from_scalar(scalar_from_tensor(x)) -> x"""
-    if isinstance(node.op, TensorFromScalar):
-        s = node.inputs[0]
-        if s.owner and isinstance(s.owner.op, ScalarFromTensor):
-            t = s.owner.inputs[0]
+    s = node.inputs[0]
+    if s.owner and isinstance(s.owner.op, ScalarFromTensor):
+        t = s.owner.inputs[0]
 
-            # We don't need to copy over any stack traces here
-            return [t]
+        # We don't need to copy over any stack traces here
+        return [t]
 
 
 @register_canonicalize
 @register_specialize
+@register_scalarize
 @node_rewriter([ScalarFromTensor])
 def local_scalar_tensor_scalar(fgraph, node):
-    """scalar_from_tensor(tensor_from_scalar(x)) -> x"""
-    if isinstance(node.op, ScalarFromTensor):
-        t = node.inputs[0]
-        if t.owner and isinstance(t.owner.op, TensorFromScalar):
-            s = t.owner.inputs[0]
-
-            # We don't need to copy over any stack traces here
-            return [s]
+    """scalar_from_tensor(tensor_from_scalar(x)) -> x
+
+    and scalar_from_tensor(TensorConstant(x)) -> x
+    """
+    t = node.inputs[0]
+    if t.owner and isinstance(t.owner.op, TensorFromScalar):
+        s = t.owner.inputs[0]
+
+        # We don't need to copy over any stack traces here
+        return [s]
+    if isinstance(t, TensorConstant):
+        assert t.ndim == 0
+        return [aes.constant(t.value.item(), t.name, t.dtype)]
 
 
 @register_specialize("local_alloc_elemwise")
diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py
index 8db39eb02d 100644
--- a/pytensor/tensor/rewriting/elemwise.py
+++ b/pytensor/tensor/rewriting/elemwise.py
@@ -381,12 +381,8 @@ def is_dimshuffle_useless(new_order, input):
 
 @node_rewriter([Elemwise])
-def local_elemwise_lift_scalars(fgraph, node):
+def elemwise_to_scalar(fgraph, node):
     op = node.op
-
-    if not isinstance(op, Elemwise):
-        return False
-
     if not all(input.ndim == 0 for input in node.inputs):
         return False
 
@@ -397,11 +393,12 @@ def local_elemwise_lift_scalars(fgraph, node):
     return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]
 
 
-compile.optdb["specialize"].register(
-    "local_elemwise_lift_scalars",
-    local_elemwise_lift_scalars,
-    "fast_run_numba",
-    "fast_compile_numba",
+compile.optdb["scalarize"].register(
+    "local_elemwise_to_scalar",
+    elemwise_to_scalar,
+    "fast_run",
+    "fast_compile",
+    "numba_only",
 )
 
 
@@ -411,9 +408,6 @@ def push_elemwise_constants(fgraph, node):
     contained scalar op.
     """
     op = node.op
-    if not isinstance(op, Elemwise):
-        return False
-
     if op.inplace_pattern:
         return False
 
@@ -467,8 +461,7 @@ def is_constant_scalar(x):
 compile.optdb["post_fusion"].register(
     "push_elemwise_constants",
     push_elemwise_constants,
-    "fast_run_numba",
-    "fast_compile_numba",
+    "numba_only",
 )
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index cb14f51752..ca4c6a4fcf 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -86,6 +86,7 @@
     encompasses_broadcastable,
     local_fill_sink,
     register_canonicalize,
+    register_scalarize,
     register_specialize,
     register_specialize_device,
     register_stabilize,
@@ -1568,6 +1569,18 @@ def local_op_of_op(fgraph, node):
                 return [combined(node_inps.owner.inputs[0])]
 
 
+@register_scalarize
+@node_rewriter([Sum])
+def local_sum_of_makevector(fgraph, node):
+    (array,) = node.inputs
+    if not array.owner or not isinstance(array.owner.op, MakeVector):
+        return False
+
+    values = array.owner.inputs
+    summed = aes.add(*values)
+    return [as_tensor_variable(summed)]
+
+
 ALL_REDUCE = (
     [
         CAReduce,
diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py
index 9b3717b8f9..8b75e7bff5 100644
--- a/pytensor/tensor/rewriting/subtensor.py
+++ b/pytensor/tensor/rewriting/subtensor.py
@@ -469,7 +469,8 @@ def local_subtensor_lift(fgraph, node):
         return [rbcast_subt_x]
 
 
-@register_specialize
+@register_stabilize("cxx_only")
+@register_canonicalize("cxx_only")
 @node_rewriter([Subtensor])
 def local_subtensor_merge(fgraph, node):
     """
diff --git a/tests/link/numba/test_scan.py b/tests/link/numba/test_scan.py
index 04bb3aefd8..4e0ad9850b 100644
--- a/tests/link/numba/test_scan.py
+++ b/tests/link/numba/test_scan.py
@@ -10,7 +10,6 @@
 from pytensor.scan.op import Scan
 from pytensor.scan.utils import until
 from pytensor.tensor import log, vector
-from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.utils import RandomStream
 from tests import unittest_tools as utt
 from tests.link.numba.test_basic import compare_numba_and_py
@@ -437,8 +436,4 @@ def test_inner_graph_optimized():
         node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
     ]
     inner_scan_nodes = scan_node.op.fgraph.apply_nodes
-    assert len(inner_scan_nodes) == 1
-    (inner_scan_node,) = scan_node.op.fgraph.apply_nodes
-    assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
-        inner_scan_node.op.scalar_op, Log1p
-    )
+    assert any(isinstance(node.op, Log1p) for node in inner_scan_nodes)
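
Illustration (not part of the patches): a minimal sketch of what the new
scalarize pass is meant to achieve on a graph where every input has rank 0.
It assumes the series above is applied and that the NUMBA mode picks up the
"numba_only" rewrites registered in pytensor/compile/mode.py; exactly which
nodes remain in the compiled graph depends on which other rewrites fire.

    import pytensor
    import pytensor.tensor as at
    from pytensor.tensor.elemwise import Elemwise

    # Both inputs have rank 0, so the Elemwise nodes built by `(x + y) * 2`
    # have nothing to broadcast or loop over.
    x = at.scalar("x")
    y = at.scalar("y")

    f = pytensor.function([x, y], (x + y) * 2, mode="NUMBA")

    # After elemwise_to_scalar, the compiled graph should apply the scalar
    # ops directly (plus ScalarFromTensor/TensorFromScalar conversions)
    # rather than wrapping them in Elemwise nodes.
    remaining = f.maker.fgraph.apply_nodes
    print(any(isinstance(node.op, Elemwise) for node in remaining))  # expected: False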