From 717ba1a73e0a3f5c043e8c8e65af9b05e11390e8 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Thu, 15 Jun 2023 21:56:00 -0500 Subject: [PATCH 1/3] Add rewrite for Sum(MakeVector) --- pytensor/tensor/rewriting/basic.py | 25 +++++++++++++++++++++++++ tests/tensor/rewriting/test_basic.py | 16 ++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 58a5918c12..96796a4dc0 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -43,6 +43,7 @@ from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to +from pytensor.tensor.math import Sum, add from pytensor.tensor.math import all as at_all from pytensor.tensor.math import eq from pytensor.tensor.shape import Shape_i @@ -956,6 +957,30 @@ def local_join_make_vector(fgraph, node): return [ret] +@register_specialize +@register_canonicalize +@register_useless +@node_rewriter([Sum]) +def local_sum_make_vector(fgraph, node): + """A sum of a MakeVector node is just the sum of the elements.""" + (array,) = node.inputs + + if array.owner is None: + return + + if not isinstance(array.owner.op, MakeVector): + return + + if node.op.axis not in [None, 0, -1]: + return + + elements = array.owner.inputs + dtype = node.op.acc_dtype + element_sum = add(*[cast(value, dtype) for value in elements]) + + return [as_tensor_variable(element_sum)] + + @register_useless("local_remove_switch_const_cond") @register_canonicalize("fast_compile", "local_remove_switch_const_cond") @register_specialize diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index fe2b795907..c6b0ee8c69 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -19,6 +19,7 @@ from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.printing import debugprint, pprint from pytensor.raise_op import Assert, CheckAndRaise +from pytensor.scalar.basic import Add from pytensor.tensor.basic import ( Alloc, Join, @@ -102,6 +103,7 @@ values_eq_approx_remove_nan, vector, ) +from pytensor.tensor.var import TensorVariable from tests import unittest_tools as utt @@ -1300,6 +1302,20 @@ def test_local_join_make_vector(): assert check_stack_trace(f, ops_to_check="all") +def test_local_sum_make_vector(): + a, b, c = scalars("abc") + mv = MakeVector(config.floatX) + output = mv(a, b, c).sum() + + func = function([a, b, c], output) + + elemwise = func.maker.fgraph.outputs[0].owner + # The MakeVector op should be optimized away, so we just + # take the sum of the scalars. + assert elemwise.inputs[0].name == "a" + assert isinstance(elemwise.inputs[0], TensorVariable) + + @pytest.mark.parametrize( "dtype", [ From 287087478c566ad761e54fea8b694c70e15ccb40 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Fri, 16 Jun 2023 11:49:46 -0500 Subject: [PATCH 2/3] Improve test_local_sum_make_vector rewrite --- pytensor/tensor/rewriting/basic.py | 14 +++++++++----- tests/tensor/rewriting/test_basic.py | 24 +++++++++++++++--------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 96796a4dc0..8cd40469db 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -971,14 +971,18 @@ def local_sum_make_vector(fgraph, node): if not isinstance(array.owner.op, MakeVector): return - if node.op.axis not in [None, 0, -1]: - return + if node.op.axis == (): + return [array] + + # If this is not the case the sum is invalid + assert node.op.axis is None or node.op.axis == (0,) elements = array.owner.inputs - dtype = node.op.acc_dtype - element_sum = add(*[cast(value, dtype) for value in elements]) + acc_dtype = node.op.acc_dtype + out_dtype = node.op.dtype + element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype) - return [as_tensor_variable(element_sum)] + return [element_sum] @register_useless("local_remove_switch_const_cond") diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index c6b0ee8c69..9b79db5f28 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -12,14 +12,13 @@ from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.configdefaults import config -from pytensor.graph.basic import equal_computations +from pytensor.graph.basic import equal_computations, vars_between from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.printing import debugprint, pprint from pytensor.raise_op import Assert, CheckAndRaise -from pytensor.scalar.basic import Add from pytensor.tensor.basic import ( Alloc, Join, @@ -32,6 +31,7 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import ( + Sum, add, bitwise_and, bitwise_or, @@ -103,7 +103,6 @@ values_eq_approx_remove_nan, vector, ) -from pytensor.tensor.var import TensorVariable from tests import unittest_tools as utt @@ -1307,13 +1306,20 @@ def test_local_sum_make_vector(): mv = MakeVector(config.floatX) output = mv(a, b, c).sum() - func = function([a, b, c], output) + output = rewrite_graph(output) + between = vars_between([a, b, c], [output]) + for var in between: + assert (var.owner is None) or (not isinstance(var.owner.op, MakeVector)) - elemwise = func.maker.fgraph.outputs[0].owner - # The MakeVector op should be optimized away, so we just - # take the sum of the scalars. - assert elemwise.inputs[0].name == "a" - assert isinstance(elemwise.inputs[0], TensorVariable) + # Check for empty sum + a, b, c = scalars("abc") + mv = MakeVector(config.floatX) + output = mv(a, b, c).sum(axis=[]) + + output = rewrite_graph(output) + between = vars_between([a, b, c], [output]) + for var in between: + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) @pytest.mark.parametrize( From 7429e945d4eaa5761fa4792079b4dee583197380 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Tue, 11 Jul 2023 20:03:31 -0500 Subject: [PATCH 3/3] fix(rewrite): Handle sum of empty make vector --- pytensor/tensor/rewriting/basic.py | 11 +++++++++-- tests/tensor/rewriting/test_basic.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 8cd40469db..ff103e9fc1 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -975,12 +975,19 @@ def local_sum_make_vector(fgraph, node): return [array] # If this is not the case the sum is invalid - assert node.op.axis is None or node.op.axis == (0,) + assert node.op.axis is None or node.op.axis == (0,) or node.op.axis == (-1,) elements = array.owner.inputs acc_dtype = node.op.acc_dtype out_dtype = node.op.dtype - element_sum = cast(add(*[cast(value, acc_dtype) for value in elements]), out_dtype) + if len(elements) == 0: + element_sum = zeros(dtype=out_dtype, shape=()) + elif len(elements) == 1: + element_sum = cast(elements[0], out_dtype) + else: + element_sum = cast( + add(*[cast(value, acc_dtype) for value in elements]), out_dtype + ) return [element_sum] diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 9b79db5f28..3c3f917bc9 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1321,6 +1321,23 @@ def test_local_sum_make_vector(): for var in between: assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) + # Check empty MakeVector + mv = MakeVector(config.floatX) + output = mv().sum() + + output = rewrite_graph(output) + between = vars_between([a, b, c], [output]) + for var in between: + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) + + mv = MakeVector(config.floatX) + output = mv(a).sum() + + output = rewrite_graph(output) + between = vars_between([a, b, c], [output]) + for var in between: + assert (var.owner is None) or (not isinstance(var.owner.op, Sum)) + @pytest.mark.parametrize( "dtype",