From 70033c99dfbab9a47814a859a3f76455051ce062 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 12:09:46 +0200 Subject: [PATCH 1/9] Clarify behavior of Elemwise second --- pytensor/tensor/basic.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 8090d6d6a8..762e145c8b 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -765,7 +765,12 @@ def switch(cond, ift, iff): @scalar_elemwise def second(a, b): - """Create a matrix by filling the shape of a with b""" + """Create a matrix by filling the broadcasted shapes of a and b with the values of b + + Equivalent to `np.broadcast_arrays(a, b)[1]` + Equivalent to `np.array(a).fill(b)` when b is a scalar value. + + """ fill = second From 28037cda884fe41961bd1145972b317b28182dc7 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 12:28:14 +0200 Subject: [PATCH 2/9] Use second for broadcast_arrays and remove fill_chain helper --- pytensor/tensor/extra_ops.py | 16 ++++++++++++++-- pytensor/tensor/rewriting/math.py | 31 +++++++++++++++---------------- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index c13e514d53..09e8bf5551 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -23,7 +23,7 @@ from pytensor.scalar import upcast from pytensor.tensor import as_tensor_variable from pytensor.tensor import basic as at -from pytensor.tensor import get_vector_length +from pytensor.tensor.basic import get_vector_length, second from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all @@ -1780,7 +1780,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: The arrays to broadcast. 
""" - return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args) + + def broadcast_with_others(a, others): + for other in others: + a = second(other, a) + return a + + brodacasted_vars = [] + for i, a in enumerate(args): + # We use indexing and not identity in case there are duplicated variables + others = [a for j, a in enumerate(args) if j != i] + brodacasted_vars.append(broadcast_with_others(a, others)) + + return brodacasted_vars __all__ = [ diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index c585b70096..3c86d1ba6e 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -38,6 +38,7 @@ ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.math import ( All, Any, @@ -148,12 +149,6 @@ def get_constant(v): return v -def fill_chain(new_out, orig_inputs): - for i in orig_inputs: - new_out = fill(i, new_out) - return [new_out] - - @register_canonicalize @register_stabilize @node_rewriter([Dot]) @@ -1136,7 +1131,7 @@ def same(x, y): new = cast(new, out.type.dtype) if new.type.broadcastable != out.type.broadcastable: - new = fill_chain(new, node.inputs)[0] + new = broadcast_arrays(new, *node.inputs)[0] if (new.type.dtype == out.type.dtype) and ( new.type.broadcastable == out.type.broadcastable @@ -1961,7 +1956,9 @@ def local_mul_zero(fgraph, node): # print 'MUL by value', value, node.inputs if value == 0: # print '... returning zeros' - return fill_chain(_asarray(0, dtype=otype.dtype), node.inputs) + return [ + broadcast_arrays(_asarray(0, dtype=otype.dtype), *node.inputs)[0] + ] # TODO: Add this to the canonicalization to reduce redundancy. 
@@ -2260,12 +2257,12 @@ def local_add_specialize(fgraph, node): # Reuse call to constant for cache() cst = constant(np.zeros((1,) * ndim, dtype=dtype)) assert cst.type.broadcastable == (True,) * ndim - return fill_chain(cst, node.inputs) + return [broadcast_arrays(cst, *node.inputs)[0]] if len(new_inputs) == 1: - ret = fill_chain(new_inputs[0], node.inputs) + ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]] else: - ret = fill_chain(add(*new_inputs), node.inputs) + ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]] # The dtype should not be changed. It can happen if the input # that was forcing upcasting was equal to 0. @@ -2383,7 +2380,7 @@ def local_log1p(fgraph, node): ninp = nonconsts[0] if ninp.dtype != log_arg.type.dtype: ninp = ninp.astype(node.outputs[0].dtype) - return fill_chain(log1p(ninp), scalar_inputs) + return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]] elif log_arg.owner and log_arg.owner.op == sub: one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) @@ -3578,10 +3575,12 @@ def local_reciprocal_1_plus_exp(fgraph, node): if len(nonconsts) == 1: if nonconsts[0].owner and nonconsts[0].owner.op == exp: if scalars_ and np.allclose(np.sum(scalars_), 1): - out = fill_chain( - sigmoid(neg(nonconsts[0].owner.inputs[0])), - scalar_inputs, - ) + out = [ + broadcast_arrays( + sigmoid(neg(nonconsts[0].owner.inputs[0])), + *scalar_inputs, + )[0] + ] # keep combined stack traces of # exp(x): nonconsts[0], # 1 + exp(x): reciprocal_arg, From add8d5f904562c3f528efe4640ea582f07c5bd88 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 12:53:04 +0200 Subject: [PATCH 3/9] Refactor encompasses_broadcastable to broadcasted_by --- pytensor/tensor/rewriting/basic.py | 29 ++++++++++++++--------------- pytensor/tensor/rewriting/math.py | 10 +++------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 
ff103e9fc1..aeac3b8351 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -49,7 +49,7 @@ from pytensor.tensor.shape import Shape_i from pytensor.tensor.sort import TopKOp from pytensor.tensor.type import DenseTensorType, TensorType -from pytensor.tensor.var import TensorConstant +from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.utils import NoDuplicateOptWarningFilter @@ -61,27 +61,26 @@ _logger.addFilter(NoDuplicateOptWarningFilter()) -def encompasses_broadcastable(b1, b2): - """ +def broadcasted_by(x: TensorVariable, y: TensorVariable) -> bool: + """Check whether x would be broadcasted by y in an Elemwise operation Parameters ---------- - b1 - The broadcastable attribute of a tensor type. - b2 - The broadcastable attribute of a tensor type. + x: TensorVariable + The variable that may be broadcasted by y + y: TensorVariable + The variable that may broadcast x Returns ------- - bool - True if the broadcastable patterns b1 and b2 are such that b2 is - broadcasted to b1's shape and not the opposite. 
- + broadcasted_by: bool """ - if len(b1) < len(b2): - return False - b1 = b1[-len(b2) :] - return not any(v1 and not v2 for v1, v2 in zip(b1, b2)) + bx = x.type.broadcastable + by = y.type.broadcastable + if len(bx) < len(by): + return True + bx = bx[-len(by) :] + return any(bx_dim and not by_dim for bx_dim, by_dim in zip(bx, by)) def merge_broadcastables(broadcastables): diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 3c86d1ba6e..29f64c86d0 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -85,7 +85,7 @@ from pytensor.tensor.math import true_div from pytensor.tensor.rewriting.basic import ( broadcast_like, - encompasses_broadcastable, + broadcasted_by, local_fill_sink, register_canonicalize, register_specialize, @@ -2049,9 +2049,7 @@ def local_pow_specialize(fgraph, node): xsym = node.inputs[0] ysym = node.inputs[1] y = get_constant(ysym) - if (y is not None) and encompasses_broadcastable( - xsym.type.broadcastable, ysym.type.broadcastable - ): + if (y is not None) and not broadcasted_by(xsym, ysym): rval = None if np.all(y == 2): @@ -2107,9 +2105,7 @@ def local_pow_to_nested_squaring(fgraph, node): y = y[0] except IndexError: pass - if (y is not None) and encompasses_broadcastable( - xsym.type.broadcastable, ysym.type.broadcastable - ): + if (y is not None) and not broadcasted_by(xsym, ysym): rval = None # 512 is too small for the cpu and too big for some gpu! 
if abs(y) == int(abs(y)) and abs(y) <= 512: From dd8462a4d0bd891185e05dac8bda163c6e3e32a8 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 14 Jul 2023 10:18:00 +0200 Subject: [PATCH 4/9] Rename broadcast_like to alloc_like --- pytensor/tensor/rewriting/basic.py | 17 +++++++++-------- pytensor/tensor/rewriting/elemwise.py | 4 ++-- pytensor/tensor/rewriting/math.py | 18 +++++++++--------- 3 files changed, 20 insertions(+), 19 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index aeac3b8351..93a022dec5 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -8,6 +8,7 @@ import pytensor.scalar.basic as aes from pytensor import compile from pytensor.compile.ops import ViewOp +from pytensor.graph import FunctionGraph from pytensor.graph.basic import Constant, Variable from pytensor.graph.rewriting.basic import ( NodeRewriter, @@ -87,13 +88,13 @@ def merge_broadcastables(broadcastables): return [all(bcast) for bcast in zip(*broadcastables)] -def broadcast_like(value, template, fgraph, dtype=None): - """ - Return a Variable with the same shape and dtype as the template, - filled by broadcasting value through it. `value` will be cast as - necessary. - - """ +def alloc_like( + value: TensorVariable, + template: TensorVariable, + fgraph: FunctionGraph, + dtype=None, +) -> TensorVariable: + """Fill value to the same shape and dtype as the template via alloc.""" value = as_tensor_variable(value) if value.type.is_super(template.type): return value @@ -438,7 +439,7 @@ def local_fill_to_alloc(fgraph, node): # In this case, we assume that some broadcasting is needed (otherwise # the condition above would've been true), so we replace the `fill` # with an `Alloc`. 
- o = broadcast_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype) + o = alloc_like(values_ref, shape_ref, fgraph, dtype=values_ref.dtype) copy_stack_trace(node.outputs[0], o) return [o] diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 64a37cb340..f6f87f3a8f 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -34,7 +34,7 @@ from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import exp from pytensor.tensor.rewriting.basic import ( - broadcast_like, + alloc_like, register_canonicalize, register_specialize, ) @@ -1242,7 +1242,7 @@ def local_inline_composite_constants(fgraph, node): # Some of the inlined constants were broadcasting the output shape if node.outputs[0].type.broadcastable != new_outputs[0].type.broadcastable: new_outputs = [ - broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph) + alloc_like(new_out, template=node.outputs[0], fgraph=fgraph) for new_out in new_outputs ] diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 29f64c86d0..13f51b2d99 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -84,7 +84,7 @@ from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import true_div from pytensor.tensor.rewriting.basic import ( - broadcast_like, + alloc_like, broadcasted_by, local_fill_sink, register_canonicalize, @@ -1973,7 +1973,7 @@ def local_div_to_reciprocal(fgraph, node): new_out = cast(new_out, dtype=out.dtype) # The ones could have forced a specific length if not out.type.is_super(new_out.type): - new_out = broadcast_like(new_out, out, fgraph) + new_out = alloc_like(new_out, out, fgraph) return [new_out] else: return False @@ -1994,9 +1994,9 @@ def local_pow_canonicalize(fgraph, node): if node.op == at_pow: cst = get_constant(node.inputs[1]) if cst == 0: - return [broadcast_like(1, node.outputs[0], fgraph)] + 
return [alloc_like(1, node.outputs[0], fgraph)] if cst == 1: - return [broadcast_like(node.inputs[0], node.outputs[0], fgraph)] + return [alloc_like(node.inputs[0], node.outputs[0], fgraph)] else: return False @@ -2033,7 +2033,7 @@ def local_zero_div(fgraph, node): node.op.scalar_op, (aes.IntDiv, aes.TrueDiv) ): if get_constant(node.inputs[0]) == 0: - ret = broadcast_like(0, node.outputs[0], fgraph) + ret = alloc_like(0, node.outputs[0], fgraph) ret.tag.values_eq_approx = values_eq_approx_remove_nan return [ret] @@ -2184,7 +2184,7 @@ def local_mul_specialize(fgraph, node): has_neg ^= True # toggles elif y == 0.0: # if we find any zero, we just return right away - return [broadcast_like(0, node.outputs[0], fgraph)] + return [alloc_like(0, node.outputs[0], fgraph)] else: new_inputs.append(inp) @@ -2209,14 +2209,14 @@ def local_mul_specialize(fgraph, node): new_inputs = [m1] + new_inputs rval = mul(*new_inputs) - return [broadcast_like(rval, node.outputs[0], fgraph)] + return [alloc_like(rval, node.outputs[0], fgraph)] else: # there are no variable inputs to mul # N.B. this could have been constant-folded... 
if has_neg: - return [broadcast_like(-1, node.outputs[0], fgraph)] + return [alloc_like(-1, node.outputs[0], fgraph)] else: - return [broadcast_like(1, node.outputs[0], fgraph)] + return [alloc_like(1, node.outputs[0], fgraph)] @register_specialize From bd918db29a2f3e93dd4811c22ab72aa027f62eb3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 12:50:19 +0200 Subject: [PATCH 5/9] Be consistent about second vs alloc in rewrites --- pytensor/tensor/rewriting/basic.py | 24 +++++++++++++++++++- pytensor/tensor/rewriting/math.py | 35 ++++++++++-------------------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 93a022dec5..acb3ebc84c 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -1,4 +1,26 @@ -""" Tensor optimizations addressing the ops in basic.py.""" +""" Tensor optimizations addressing the ops in basic.py. + +Notes +----- +There are two ways of broadcasting arrays: +second(x, y) == alloc(y, broadcast_shapes(x.shape, y.shape)) + +The second form (`alloc`) can be more efficient because x doesn't usually need to be computed when we only want its shape. +It may also allow other rewrites that don't try to modify x when it has multiple clients (for fear of duplicating computation). + +However, the first form (`second`) is easier to reason about. +Knowing we have such a graph allows us to do certain rewrites such as "sinking" broadcasting operations below Elemwise. +The same rewrites with alloc would be more complicated as we would need to symbolically combine the shapes of each one. 
+ +As an example, contrast rewriting the following two equivalent graphs: + +alloc(x, broadcast_shapes(x.shape, y.shape)) + alloc(y, broadcast_shapes(x.shape, y.shape)) -> x + y +second(y, x) + second(x, y) -> x + y + +Theano developers (mostly) preferred to use the first form during canonicalization and introduce the second form later, +via rewrites like `local_fill_to_alloc`, and using the `alloc_like` helper inside rewrites. +Many stabilize and specialize rewrites refuse to be applied when a variable has multiple clients, so this is important. +""" import logging from typing import TYPE_CHECKING, Optional, Union diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 13f51b2d99..ddfccff912 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -30,7 +30,6 @@ cast, constant, extract_constant, - fill, get_underlying_scalar_constant_value, ones_like, switch, @@ -2041,8 +2040,6 @@ def local_zero_div(fgraph, node): @register_specialize @node_rewriter([at_pow]) def local_pow_specialize(fgraph, node): - # here, we are past the point of canonicalization, so we don't want - # to put in un-necessary fills. if node.op == at_pow: # the idea here is that we have pow(x, y) odtype = node.outputs[0].dtype @@ -2057,7 +2054,7 @@ def local_pow_specialize(fgraph, node): if np.all(y == 1): rval = [xsym] if np.all(y == 0): - rval = [fill(xsym, np.asarray(1, dtype=odtype))] + rval = [alloc_like(1, xsym, fgraph)] if np.all(y == 0.5): rval = [sqrt(xsym)] if np.all(y == -0.5): @@ -2158,9 +2155,7 @@ def local_mul_specialize(fgraph, node): mul(-1, x, y) -/-> neg(mul(x, y)) """ - # here, we are past the point of canonicalization, so we don't - # want to put in un-necessary fills. - # + # at this point [post canonicalize], mul() may have many inputs. 
if node.op == mul: # the idea here is that we have pow(x, y) @@ -2221,16 +2216,7 @@ def local_mul_specialize(fgraph, node): @register_specialize @node_rewriter([add]) -def local_add_specialize(fgraph, node): - """Remove zeros from ``add``s. - - TODO: This should be a canonicalization, no? - """ - # here, we are past the point of canonicalization, so we don't want - # to put in un-necessary fills. - if node.op != add: - return False - +def local_add_remove_zeros(fgraph, node): new_inputs = [] for inp in node.inputs: try: @@ -2253,12 +2239,12 @@ def local_add_specialize(fgraph, node): # Reuse call to constant for cache() cst = constant(np.zeros((1,) * ndim, dtype=dtype)) assert cst.type.broadcastable == (True,) * ndim - return [broadcast_arrays(cst, *node.inputs)[0]] + return [alloc_like(cst, node_output, fgraph)] if len(new_inputs) == 1: - ret = [broadcast_arrays(new_inputs[0], *node.inputs)[0]] + ret = [alloc_like(new_inputs[0], node_output, fgraph)] else: - ret = [broadcast_arrays(add(*new_inputs), *node.inputs)[0]] + ret = [alloc_like(add(*new_inputs), node_output, fgraph)] # The dtype should not be changed. It can happen if the input # that was forcing upcasting was equal to 0. 
@@ -2376,7 +2362,7 @@ def local_log1p(fgraph, node): ninp = nonconsts[0] if ninp.dtype != log_arg.type.dtype: ninp = ninp.astype(node.outputs[0].dtype) - return [broadcast_arrays(log1p(ninp), *scalar_inputs)[0]] + return [alloc_like(log1p(ninp), node.outputs[0], fgraph)] elif log_arg.owner and log_arg.owner.op == sub: one = extract_constant(log_arg.owner.inputs[0], only_process_constants=True) @@ -3572,10 +3558,11 @@ def local_reciprocal_1_plus_exp(fgraph, node): if nonconsts[0].owner and nonconsts[0].owner.op == exp: if scalars_ and np.allclose(np.sum(scalars_), 1): out = [ - broadcast_arrays( + alloc_like( sigmoid(neg(nonconsts[0].owner.inputs[0])), - *scalar_inputs, - )[0] + node.outputs[0], + fgraph, + ) ] # keep combined stack traces of # exp(x): nonconsts[0], From 548c14a2fbad01be92634ffd1ecc984c9077a4da Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 15:31:54 +0200 Subject: [PATCH 6/9] Tag rewrites that make shape assumptions --- pytensor/configdefaults.py | 20 +------------------- pytensor/tensor/rewriting/basic.py | 20 ++++++++++---------- pytensor/tensor/rewriting/math.py | 4 ++-- tests/tensor/rewriting/test_basic.py | 14 ++++++++++++++ 4 files changed, 27 insertions(+), 31 deletions(-) diff --git a/pytensor/configdefaults.py b/pytensor/configdefaults.py index a5958d7f4f..58f2f2faa9 100644 --- a/pytensor/configdefaults.py +++ b/pytensor/configdefaults.py @@ -682,25 +682,7 @@ def add_traceback_configvars(): def add_experimental_configvars(): - config.add( - "experimental__local_alloc_elemwise", - "DEPRECATED: If True, enable the experimental" - " optimization local_alloc_elemwise." - " Generates error if not True. Use" - " optimizer_excluding=local_alloc_elemwise" - " to disable.", - BoolParam(True), - in_c_key=False, - ) - - # False could make the graph faster but not as safe. 
- config.add( - "experimental__local_alloc_elemwise_assert", - "When the local_alloc_elemwise is applied, add" - " an assert to highlight shape errors.", - BoolParam(True), - in_c_key=False, - ) + return def add_error_and_warning_configvars(): diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index acb3ebc84c..23cb429e37 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -256,7 +256,7 @@ def local_scalar_tensor_scalar(fgraph, node): return [s] -@register_specialize("local_alloc_elemwise") +@register_specialize("shape_unsafe") @node_rewriter([Elemwise]) def local_elemwise_alloc(fgraph, node): r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. @@ -377,7 +377,7 @@ def dimshuffled_alloc(i): return ret -@register_canonicalize +@register_canonicalize("shape_unsafe") @node_rewriter([Elemwise]) def local_fill_sink(fgraph, node): """ @@ -428,8 +428,8 @@ def local_fill_sink(fgraph, node): return replacements -@register_specialize -@register_stabilize +@register_specialize("shape_unsafe") +@register_stabilize("shape_unsafe") @node_rewriter([fill]) def local_fill_to_alloc(fgraph, node): r"""Remove `fill`\s or replace them with `Alloc`\s. 
@@ -479,8 +479,8 @@ def local_fill_to_alloc(fgraph, node): ) -@register_canonicalize("fast_compile") -@register_useless +@register_canonicalize("fast_compile", "shape_unsafe") +@register_useless("shape_unsafe") @node_rewriter([fill]) def local_useless_fill(fgraph, node): """fill(s,v) -> v @@ -500,10 +500,10 @@ def local_useless_fill(fgraph, node): return [v] -@register_specialize -@register_stabilize -@register_canonicalize -@register_useless +@register_specialize("shape_unsafe") +@register_stabilize("shape_unsafe") +@register_canonicalize("shape_unsafe") +@register_useless("shape_unsafe") @node_rewriter([Alloc]) def local_useless_alloc(fgraph, node): """ diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index ddfccff912..a5b04d54c8 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1176,7 +1176,7 @@ def mul_calculate(num, denum, aslist=False, out_type=None): local_mul_canonizer = AlgebraicCanonizer( mul, true_div, reciprocal, mul_calculate, False ) -register_canonicalize(local_mul_canonizer, name="local_mul_canonizer") +register_canonicalize(local_mul_canonizer, "shape_unsafe", name="local_mul_canonizer") @register_canonicalize @@ -2493,7 +2493,7 @@ def add_calculate(num, denum, aslist=False, out_type=None): ) -register_canonicalize(local_add_canonizer, name="local_add_canonizer") +register_canonicalize(local_add_canonizer, "shape_unsafe", name="local_add_canonizer") def distribute_greedy(pos_pairs, neg_pairs, num, denum, out_type, minscore=0): diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index af001d6346..eb2aa8eeb1 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1933,3 +1933,17 @@ def test_misc(self): x_val = np.random.random((1, 5)).astype(self.dtype) exp_res = np.broadcast_to(x_val, (5, 5))[..., None] + y_val assert np.array_equal(func(y_val, x_val), exp_res) + + +def test_shape_unsafe_tag(): + 
mode = get_mode("FAST_RUN") + x = vector("x") + y = vector("y") + out = x * y / y + + fn = function([x, y], out, mode=mode) + np.testing.assert_equal(fn([0, 1], [2, 3, 4]), [0, 1]) + + fn = function([x, y], out, mode=mode.excluding("shape_unsafe")) + with pytest.raises(ValueError): + fn([0, 1], [2, 3, 4]), [0, 1] From 2ac87749e6625c744f3a295477cd210e3a6dc968 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 14:56:58 +0200 Subject: [PATCH 7/9] Simplify rewrites by assuming Elemwise / Alloc shapes are correct --- pytensor/tensor/rewriting/basic.py | 165 ++++++++------------------- tests/tensor/rewriting/test_basic.py | 71 ++++++------ 2 files changed, 86 insertions(+), 150 deletions(-) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 23cb429e37..b190a4ea80 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -23,7 +23,7 @@ """ import logging -from typing import TYPE_CHECKING, Optional, Union +from typing import Union import numpy as np @@ -65,21 +65,17 @@ ) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError -from pytensor.tensor.extra_ops import broadcast_shape, broadcast_to +from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.math import Sum, add from pytensor.tensor.math import all as at_all from pytensor.tensor.math import eq -from pytensor.tensor.shape import Shape_i +from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.sort import TopKOp from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.utils import NoDuplicateOptWarningFilter -if TYPE_CHECKING: - from pytensor.tensor.rewriting.shape import ShapeFeature - - _logger = logging.getLogger("pytensor.tensor.rewriting.basic") _logger.addFilter(NoDuplicateOptWarningFilter()) @@ -261,31 +257,16 @@ def 
local_scalar_tensor_scalar(fgraph, node): def local_elemwise_alloc(fgraph, node): r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s. - `Alloc`\s are effectively a type of `Elemwise` operation - (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so - this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to - `Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it - broadcasts). - - In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant - `Alloc`\s. - The rewrite essentially performs the following replacement: - ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``, - when ``y.shape`` for some input ``y`` (or the combined shapes of the - non-`Alloc`\s) is sufficient to maintain the same/correct output shape. + ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)`` - In it's current form, it also explicitly accounts for `DimShuffle`\s of + In its current form, it also explicitly accounts for `DimShuffle`\s of `Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which introduces them as a canonicalization of `Alloc`'s with leading broadcastable dimensions. """ - # Rewrite is only applicable when there are at least two inputs if len(node.inputs) == 1: - return False - - if len(node.outputs) > 1: - return False + return None def dimshuffled_alloc(i): return ( @@ -305,76 +286,40 @@ def dimshuffled_alloc(i): if len(alloc_idxs) == 0: return False - # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a - # baseline for the dimensions. - ref_var_idx = None - for idx, i in enumerate(node.inputs): - if i.type.broadcastable == node.outputs[0].type.broadcastable: - # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an - # `Alloc`, so that all `Alloc`s can be rewritten. 
- if idx not in alloc_idxs: - ref_var_idx = idx - break - - # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one - if ref_var_idx is None: - for idx, i in enumerate(node.inputs): - # XXX: This broadcastable comparison doesn't work - if ( - i.type.broadcastable == node.outputs[0].type.broadcastable - ) and idx in alloc_idxs: - ref_var_idx = idx - break - - if not hasattr(fgraph, "shape_feature"): - return False - - input_shapes = [ - tuple(fgraph.shape_feature.get_shape(i, j) for j in range(i.type.ndim)) - for i in node.inputs - ] - bcasted_shape = broadcast_shape( - *input_shapes, - arrays_are_shapes=True, - ) - new_inputs = list(node.inputs) for idx in alloc_idxs: i = node.inputs[idx] - # Remove `Alloc` + # Remove simple `Alloc` if isinstance(i.owner.op, Alloc): - new_alloc = broadcast_to(i.owner.inputs[0], bcasted_shape) + new_inp = i.owner.inputs[0] - # TODO FIXME: This shouldn't be handled here. - # `DimShuffle`s should be lifted through `Alloc`s - # by other, more general rewrites. - # Remove `Alloc` in `DimShuffle` + # Remove `Dimshuffle(Alloc)` elif isinstance(i.owner.op, DimShuffle): old_alloc = i.owner.inputs[0] - new_alloc = old_alloc.owner.inputs[0] + old_alloc_inp = old_alloc.owner.inputs[0] + missing_ndims = old_alloc.type.ndim - old_alloc_inp.type.ndim + if missing_ndims > 0: + # The `Alloc` added new dimensions to the left. + # We replace those cases with a `DimShuffle` here. + # Nested dimshuffles will be merged later by other rewrites. + old_alloc_inp = shape_padleft(old_alloc_inp, missing_ndims) # We need to keep the old `DimShuffle`. It could swap axes or # add dimensions anywhere. - if new_alloc.ndim != old_alloc.ndim: - # The `Alloc` can add dimensions to the value. - # We replace those cases with a `DimShuffle` here. 
- nb_dim_to_add = old_alloc.ndim - new_alloc.ndim - new_alloc = new_alloc.dimshuffle( - ["x"] * nb_dim_to_add + list(range(new_alloc.ndim)) - ) - new_alloc = broadcast_to(i.owner.op(new_alloc), bcasted_shape) + new_inp = i.owner.op(old_alloc_inp) - copy_stack_trace(i, new_alloc) - new_inputs[idx] = new_alloc + copy_stack_trace(i, new_inp) + new_inputs[idx] = new_inp - # If this assert is triggered, it means we are recreating an equivalent graph - # which would result in cyclical merge rewrites. - if all(new is old for new, old in zip(new_inputs, node.inputs)): - return + new_outs = node.op(*new_inputs, return_list=True) - ret = node.op(*new_inputs, return_list=True) - copy_stack_trace(node.outputs, ret) - return ret + if new_outs[0].type.broadcastable != node.outputs[0].type.broadcastable: + new_outs = [ + alloc_like(new_out, node.outputs[0], fgraph) for new_out in new_outs + ] + + copy_stack_trace(node.outputs, new_outs) + return new_outs @register_canonicalize("shape_unsafe") @@ -406,6 +351,7 @@ def local_fill_sink(fgraph, node): # The newly created node c doesn't has 'clients', # so this iteration is took place with node.outputs[0] + # TODO: This should just be a WalkingGraphRewrite! replacements = {node.outputs[0]: c} for client, cl_idx in fgraph.clients[node.outputs[0]]: if ( @@ -438,9 +384,8 @@ def local_fill_to_alloc(fgraph, node): with their dependencies on those tensors' shapes, and sometimes those shapes can be computed without needing to compute the tensors themselves. - XXX: This rewrite can produce inconsistent results, so do *not* consider - making it a canonicalization until those inconsistencies are - resolved/justified. + Like `local_fill_sink`, this rewrite assumes non-broadcastable shapes are equivalent, + which could mask shape errors. 
""" shape_ref, values_ref = node.inputs out_type = node.outputs[0].type @@ -448,13 +393,6 @@ def local_fill_to_alloc(fgraph, node): if values_ref.type.broadcastable == out_type.broadcastable: # The assumption here is that `values_ref` already has the same shape # as `shape_ref`, so a `fill`/`Alloc` is unnecessary. - - # XXX FIXME TODO: The only way this can be determined is if one - # absolutely knows that the shapes of `shape_ref` and `values_ref` are - # equal. - # This is an old rewrite, and it's only a - # "specialization/stabilization", so we're going to leave it be for - # now. return [values_ref] if shape_ref.type.broadcastable == out_type.broadcastable: @@ -465,6 +403,9 @@ def local_fill_to_alloc(fgraph, node): copy_stack_trace(node.outputs[0], o) return [o] + # The case that is not covered is when `shape_ref` is broadcasted by `values_ref` + # TODO: Return broadcast_to(values_ref, broadcast_shapes(values_ref.shape, shape_ref.shape)) + return @@ -1014,36 +955,30 @@ def local_sum_make_vector(fgraph, node): return [element_sum] -@register_useless("local_remove_switch_const_cond") -@register_canonicalize("fast_compile", "local_remove_switch_const_cond") -@register_specialize -@node_rewriter([Elemwise]) +@register_useless("shape_unsafe") +@register_canonicalize("fast_compile", "shape_unsafe") +@register_specialize("shape_unsafe") +@node_rewriter([switch]) def local_useless_switch(fgraph, node): """ This rewrite makes the following changes in a graph: - at.switch(cond, left, right) -> - if cond is constant and cond == 0: right - if cond is constant and cond != 0: left - if left is right -> left + switch(cond, left, right) -> + if cond is constant and cond == 0: right + if cond is constant and cond != 0: left + if left is right -> left and - at.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) + switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X) """ - if not isinstance(node.op.scalar_op, aes.Switch): - return False - - 
shape_feature: Optional["ShapeFeature"] = getattr(fgraph, "shape_feature", None) - - if shape_feature is None: - return False left = node.inputs[1] right = node.inputs[2] cond_var = node.inputs[0] cond = extract_constant(cond_var, only_process_constants=True) + out_bcast = node.outputs[0].type.broadcastable if (isinstance(cond, np.ndarray) and cond.ndim == 0) or isinstance( cond, (np.number, np.bool_) @@ -1058,14 +993,8 @@ def local_useless_switch(fgraph, node): else: out = correct_out - input_shapes = [ - tuple(shape_feature.get_shape(inp, i) for i in range(inp.type.ndim)) - for inp in node.inputs - ] - - out_shape = broadcast_shape(*input_shapes, arrays_are_shapes=True) - - out = alloc(out, *out_shape) + if out.type.broadcastable != out_bcast: + out = broadcast_arrays(out, *node.inputs)[0] # Copy over stacktrace from selected output to new output copy_stack_trace(node.outputs + correct_out, out) @@ -1075,10 +1004,10 @@ def local_useless_switch(fgraph, node): if left == right: # Note: No need to copy over stacktrace, because the input node # already has its own stacktrace - if cond.type.is_super(left.type): + if left.type.broadcastable == out_bcast: return [left] - ret = fill(cond, left) + ret = broadcast_arrays(left, cond)[0] # Copy over stacktrace from switch output and correct branch copy_stack_trace(node.outputs + left, ret) diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index eb2aa8eeb1..5d364da6fd 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -1013,7 +1013,7 @@ def test_broadcasting_1(self): z = at.switch(1, x, y) f = function([x, y], z, mode=self.mode) - start_var = f.maker.fgraph.outputs[0].owner.inputs[0] + start_var = f.maker.fgraph.outputs[0] assert isinstance(start_var.owner.op, Elemwise) assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast) assert not any(node.op == at.switch for node in f.maker.fgraph.toposort()) @@ -1698,45 +1698,50 @@ def 
verify_op_count(f, count, cls): ) @pytest.mark.parametrize( - "expr, x_shape, y_shape", + "expr, x_shape, y_shape, needs_alloc", [ - (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)), - (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)), - (lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)), + (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2), True), + (lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1), False), + (lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3), False), ( lambda x, y: at.mul( at.alloc(x, 3).dimshuffle("x", 0), y.dimshuffle("x", "x") ), (), (), + True, ), - (lambda x, y: at.mul(y, at.alloc(1, x)), (), ()), - (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), - (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), + (lambda x, y: at.mul(y, at.alloc(1, x)), (), (), True), + (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1), False), + (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2), False), ( lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1), + False, ), ( lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2), + False, ), ( lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15), + False, ), - (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)), + (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2), False), ( lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y), (15, 2), (2, 15), + False, ), ], ) - def test_basic(self, expr, x_shape, y_shape): + def test_basic(self, expr, x_shape, y_shape, needs_alloc): x = at.tensor( dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x" ) @@ -1752,10 +1757,16 @@ def test_basic(self, expr, x_shape, y_shape): on_unused_input="ignore", ) - assert not any( - isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort() - ) + nodes = z_opt.maker.fgraph.toposort() + 
if needs_alloc: + # When the final result needs an Alloc, this should be the last node + # x = scalar; y = vector; mul(x, ones_like(y)) -> alloc(x, y.shape) + assert isinstance(nodes[-1].op, Alloc) + nodes = nodes[:-1] + + assert not any(isinstance(node.op, Alloc) for node in nodes) + # Check results are the same without the optimization z_no_opt = pytensor.function( [x, y], z, @@ -1799,7 +1810,7 @@ def test_remove_alloc_wo_dimshuffle(self): [self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode ) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 2, Assert) + self.verify_op_count(func, 1, SpecifyShape) func = function( [self.vec, self.mat], @@ -1807,7 +1818,7 @@ def test_remove_alloc_wo_dimshuffle(self): mode=self.fast_run_mode, ) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 1, Assert) + self.verify_op_count(func, 1, SpecifyShape) # No optimization on alloc without assert func = function( @@ -1839,7 +1850,10 @@ def test_remove_alloc_wo_dimshuffle(self): self.alloc_w_dep_broad2 + self.mat, mode=self.fast_run_mode, ) - self.verify_op_count(func, 0, Alloc) + # This graph requires one outer Alloc and an Assert + # To make sure `mat` is square since we end up doing + # broadcast_to(x, mat[..., None].shape) + mat[None, ...] + self.verify_op_count(func, 1, Alloc) self.verify_op_count(func, 1, Assert) def test_remove_alloc_w_dimshuffle(self): @@ -1851,16 +1865,13 @@ def test_remove_alloc_w_dimshuffle(self): self.verify_op_count(func, 1, Alloc) self.verify_op_count(func, 0, Assert) - # TODO FIXME: The `BroadcastTo` shapes should use the constants - # provided by the first/`Alloc` term, and not the unknown values from - # the `tens` term. 
func = function( [self.vec, self.tens], self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens, mode=self.fast_run_mode, ) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 2, Assert) + self.verify_op_count(func, 1, SpecifyShape) func = function( [self.vec, self.tens], @@ -1888,16 +1899,13 @@ def test_multi_input_single_alloc(self): self.verify_op_count(func, 2, Alloc) self.verify_op_count(func, 0, Assert) - # Optimization on dimshuffle with assert - # TODO: When we support static shape constraints like `shape[i] != 1`, - # reproduce this with such a constraint on `mat` and make sure the - # `BroadcastTo` is removed. func = function( [self.vec, self.mat], self.tv_wo_dep + self.tm_wo_dep, mode=self.fast_run_mode, ) - self.verify_op_count(func, 0, Alloc) + # It still needs an outer alloc to broadcast final shape + self.verify_op_count(func, 1, Alloc) self.verify_op_count(func, 0, Assert) # No optimization on dimshuffle without assert @@ -1909,25 +1917,24 @@ def test_multi_input_single_alloc(self): self.verify_op_count(func, 2, Alloc) self.verify_op_count(func, 0, Assert) - # Optimization on dimshuffle without assert func = function( [self.vec, self.mat, self.s], self.tv_w_dep + self.tm_w_dep, mode=self.fast_run_mode, ) - self.verify_op_count(func, 0, Alloc) - # The second assert is from the shape check... 
- self.verify_op_count(func, 2, Assert) + # It still needs an outer alloc to broadcast final shape + self.verify_op_count(func, 1, Alloc) + self.verify_op_count(func, 0, Assert) def test_misc(self): - x = row(dtype=self.dtype) - y = tensor(dtype=self.dtype, shape=(None, None, 1)) + x = row("x", dtype=self.dtype) + y = tensor("y", dtype=self.dtype, shape=(None, None, 1)) out = at.alloc(x, 5, 5).dimshuffle(0, 1, "x") + y func = function([y, x], out, mode=self.fast_run_mode) self.verify_op_count(func, 0, Alloc) - self.verify_op_count(func, 2, Assert) + self.verify_op_count(func, 1, SpecifyShape) y_val = np.random.random((5, 5, 1)).astype(self.dtype) x_val = np.random.random((1, 5)).astype(self.dtype) From 84c46f1c77ee58a918ce60dad6742d11f11d2113 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 13:23:07 +0200 Subject: [PATCH 8/9] Remove BroadcastTo --- pytensor/link/jax/dispatch/extra_ops.py | 14 -- pytensor/link/numba/dispatch/extra_ops.py | 25 --- pytensor/tensor/extra_ops.py | 145 +---------------- pytensor/tensor/rewriting/extra_ops.py | 48 +----- tests/link/jax/test_extra_ops.py | 25 +-- tests/link/numba/test_extra_ops.py | 35 ---- tests/tensor/rewriting/test_extra_ops.py | 73 +-------- tests/tensor/test_extra_ops.py | 190 +--------------------- 8 files changed, 14 insertions(+), 541 deletions(-) diff --git a/pytensor/link/jax/dispatch/extra_ops.py b/pytensor/link/jax/dispatch/extra_ops.py index bfce752434..a9e36667ef 100644 --- a/pytensor/link/jax/dispatch/extra_ops.py +++ b/pytensor/link/jax/dispatch/extra_ops.py @@ -3,10 +3,8 @@ import jax.numpy as jnp from pytensor.link.jax.dispatch.basic import jax_funcify -from pytensor.tensor.basic import infer_static_shape from pytensor.tensor.extra_ops import ( Bartlett, - BroadcastTo, CumOp, FillDiagonal, FillDiagonalOffset, @@ -102,18 +100,6 @@ def ravelmultiindex(*inp, mode=mode, order=order): return ravelmultiindex -@jax_funcify.register(BroadcastTo) -def jax_funcify_BroadcastTo(op, node, 
**kwargs): - shape = node.inputs[1:] - static_shape = infer_static_shape(shape)[1] - - def broadcast_to(x, *shape): - shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape)) - return jnp.broadcast_to(x, shape) - - return broadcast_to - - @jax_funcify.register(FillDiagonal) def jax_funcify_FillDiagonal(op, **kwargs): def filldiagonal(value, diagonal): diff --git a/pytensor/link/numba/dispatch/extra_ops.py b/pytensor/link/numba/dispatch/extra_ops.py index ce275fd031..a3a489deaa 100644 --- a/pytensor/link/numba/dispatch/extra_ops.py +++ b/pytensor/link/numba/dispatch/extra_ops.py @@ -2,7 +2,6 @@ import numba import numpy as np -from numba.misc.special import literal_unroll from pytensor import config from pytensor.link.numba.dispatch import basic as numba_basic @@ -10,7 +9,6 @@ from pytensor.raise_op import CheckAndRaise from pytensor.tensor.extra_ops import ( Bartlett, - BroadcastTo, CumOp, FillDiagonal, FillDiagonalOffset, @@ -353,29 +351,6 @@ def searchsorted(a, v): return searchsorted -@numba_funcify.register(BroadcastTo) -def numba_funcify_BroadcastTo(op, node, **kwargs): - create_zeros_tuple = numba_basic.create_tuple_creator( - lambda _: 0, len(node.inputs) - 1 - ) - - # TODO broadcastable checks - @numba_basic.numba_njit - def broadcast_to(x, *shape): - scalars_shape = create_zeros_tuple() - - i = 0 - for s_i in literal_unroll(shape): - scalars_shape = numba_basic.tuple_setitem( - scalars_shape, i, numba_basic.to_scalar(s_i) - ) - i += 1 - - return np.broadcast_to(x, scalars_shape) - - return broadcast_to - - @numba_funcify.register(CheckAndRaise) def numba_funcify_CheckAndRaise(op, node, **kwargs): error = op.exc_type diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 09e8bf5551..6a7e8d38bc 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -23,7 +23,7 @@ from pytensor.scalar import upcast from pytensor.tensor import as_tensor_variable from pytensor.tensor import basic as at 
-from pytensor.tensor.basic import get_vector_length, second +from pytensor.tensor.basic import alloc, second from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import abs as pt_abs from pytensor.tensor.math import all as pt_all @@ -1584,141 +1584,6 @@ def broadcast_shape_iter( return tuple(result_dims) -class BroadcastTo(COp): - """An `Op` for `numpy.broadcast_to`.""" - - _output_type_depends_on_input_value = True - - __props__ = () - - view_map = {0: [0]} - - def __call__(self, a, shape, **kwargs): - return super().__call__(a, *shape, **kwargs) - - def make_node(self, a, *shape): - a = at.as_tensor_variable(a) - - shape, static_shape = at.infer_static_shape(shape) - - if len(shape) < a.ndim: - raise ValueError( - f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims" - ) - - out = TensorType(dtype=a.type.dtype, shape=static_shape)() - - # Attempt to prevent in-place operations on this view-based output - out.tag.indestructible = True - - return Apply(self, [a] + shape, [out]) - - def perform(self, node, inputs, output_storage): - a, *shape = inputs - z = output_storage[0] - z[0] = np.broadcast_to(a, shape) - - def grad(self, inputs, outputs_gradients): - a, *shape = inputs - (dout,) = outputs_gradients - - # Determine the dimensions that were added by broadcasting - new_dims = list(range(dout.ndim - a.ndim)) - - d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims) - - # Determine the dimensions that were broadcast - _, static_shape = at.infer_static_shape(shape) - - # TODO: This needs to be performed at run-time when static shape - # information isn't available. 
- bcast_sums = [ - i - for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :])) - if a_s == 1 and s_s != 1 - ] - - if bcast_sums: - d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True) - - return [d_wrt_a] + [ - grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1) - ] - - def infer_shape(self, fgraph, node, ins_shapes): - return [node.inputs[1:]] - - def c_code(self, node, name, inputs, outputs, sub): - inp_dims = node.inputs[0].ndim - out_dims = node.outputs[0].ndim - new_dims = out_dims - inp_dims - - (x, *shape) = inputs - (out,) = outputs - fail = sub["fail"] - - # TODO: Could just use `PyArray_Return`, no? - dims_array = ", ".join( - [ - f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]" - for i, shape in enumerate(shape) - ] - ) - - src = ( - """ - npy_intp itershape[%(out_dims)s] = {%(dims_array)s}; - - NpyIter *iter; - PyArrayObject *ops[1] = {%(x)s}; - npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK; - npy_uint32 op_flags[1] = {NPY_ITER_READONLY}; - PyArray_Descr *op_dtypes[1] = {NULL}; - int oa_ndim = %(out_dims)s; - int* op_axes[1] = {NULL}; - npy_intp buffersize = 0; - - for(int i = 0; i < %(inp_dims)s; i++) - { - if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s])) - { - PyErr_Format(PyExc_ValueError, - "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.", - i, - (long long int) itershape[i + %(new_dims)s], - (long long int) PyArray_DIMS(%(x)s)[i] - ); - %(fail)s - } - } - - iter = NpyIter_AdvancedNew( - 1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize - ); - %(out)s = NpyIter_GetIterView(iter, 0); - - if(%(out)s == NULL){ - NpyIter_Deallocate(iter); - %(fail)s; - } - - if (NpyIter_Deallocate(iter) != NPY_SUCCEED) { - %(fail)s; - } - - """ - % locals() - ) - - return src - - def c_code_cache_version(self): - return (2,) - - -broadcast_to_ = 
BroadcastTo() - - def geomspace(start, end, steps, base=10.0): from pytensor.tensor.math import log @@ -1762,13 +1627,7 @@ def broadcast_to( broadcasted array may refer to a single memory location. """ - x = at.as_tensor(x) - shape_len = get_vector_length(shape) - - if x.ndim == 0 and shape_len == 0: - return x - - return broadcast_to_(x, shape) + return alloc(x, *shape) def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: diff --git a/pytensor/tensor/rewriting/extra_ops.py b/pytensor/tensor/rewriting/extra_ops.py index aa20334abc..945433f2a4 100644 --- a/pytensor/tensor/rewriting/extra_ops.py +++ b/pytensor/tensor/rewriting/extra_ops.py @@ -2,7 +2,7 @@ from pytensor.graph.rewriting.basic import node_rewriter from pytensor.tensor.basic import Alloc, as_tensor_variable from pytensor.tensor.elemwise import Elemwise -from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique +from pytensor.tensor.extra_ops import Repeat, Unique from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless @@ -60,39 +60,6 @@ def local_Unique_Alloc_lift(fgraph, node): return [new_x] -@register_useless -@register_canonicalize -@node_rewriter([Unique]) -def local_Unique_BroadcastTo_lift(fgraph, node): - """Convert ``unique(broadcast_to(x, ...), axis=None)`` to ``unique(x, axis=None)``. - - This isn't really so much a lift as a "reduction/consumption". 
- """ - if not isinstance(node.op, Unique): - return False - - if ( - node.op.return_index - or node.op.return_inverse - or node.op.return_counts - or node.op.axis is not None - ): - return False - - bcast_var = node.inputs[0] - - if not (bcast_var.owner and isinstance(bcast_var.owner.op, BroadcastTo)): - return False - - bcasted_var, *bcast_shape = bcast_var.owner.inputs - - new_unique, *_ = node.op.make_node(bcasted_var).outputs - - old_out = node.outputs[0] - new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) - return [new_x] - - @register_useless @register_canonicalize @node_rewriter([Unique]) @@ -161,16 +128,3 @@ def local_Unique_second(fgraph, node): old_out = node.outputs[0] new_x = as_tensor_variable(new_unique, ndim=old_out.ndim, dtype=old_out.dtype) return [new_x] - - -@register_useless -@register_canonicalize -@node_rewriter([BroadcastTo]) -def local_remove_scalar_BroadcastTo(fgraph, node): - bcast_shape = node.inputs[1:] - - if not bcast_shape: - bcasted_var = node.inputs[0] - # If this isn't true, the graph is invalid - assert bcasted_var.ndim == 0 - return [bcasted_var] diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index 73c6e4249c..78abd671b8 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -7,7 +7,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.tensor import extra_ops as at_extra_ops -from pytensor.tensor.type import matrix, vector +from pytensor.tensor.type import matrix from tests.link.jax.test_basic import compare_jax_and_py @@ -63,29 +63,6 @@ def test_extra_ops(): ) -@pytest.mark.parametrize( - "x, shape", - [ - ( - set_test_value( - vector("x"), np.random.random(size=(2,)).astype(config.floatX) - ), - [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)], - ), - ( - set_test_value( - vector("x"), np.random.random(size=(2,)).astype(config.floatX) - ), - [at.as_tensor(3, 
dtype=np.int8), at.as_tensor(2, dtype=np.int64)], - ), - ], -) -def test_BroadcastTo(x, shape): - out = at_extra_ops.broadcast_to(x, shape) - fgraph = FunctionGraph(outputs=[out]) - compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) - - @pytest.mark.xfail( version_parse(jax.__version__) >= version_parse("0.2.12"), reason="Omnistaging cannot be disabled", diff --git a/tests/link/numba/test_extra_ops.py b/tests/link/numba/test_extra_ops.py index 30b62ba225..36a67cfff0 100644 --- a/tests/link/numba/test_extra_ops.py +++ b/tests/link/numba/test_extra_ops.py @@ -36,41 +36,6 @@ def test_Bartlett(val): ) -@pytest.mark.parametrize( - "x, shape", - [ - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - [set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]], - ), - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - [at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)], - ), - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - at.as_tensor([set_test_value(at.lscalar(), np.array(v)) for v in [3, 2]]), - ), - ( - set_test_value(at.vector(), rng.random(size=(2,)).astype(config.floatX)), - [at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)], - ), - ], -) -def test_BroadcastTo(x, shape): - g = extra_ops.BroadcastTo()(x, shape) - g_fg = FunctionGraph(outputs=[g]) - - compare_numba_and_py( - g_fg, - [ - i.tag.test_value - for i in g_fg.inputs - if not isinstance(i, (SharedVariable, Constant)) - ], - ) - - @pytest.mark.parametrize( "val, axis, mode", [ diff --git a/tests/tensor/rewriting/test_extra_ops.py b/tests/tensor/rewriting/test_extra_ops.py index d0aac80249..15f5870e5b 100644 --- a/tests/tensor/rewriting/test_extra_ops.py +++ b/tests/tensor/rewriting/test_extra_ops.py @@ -8,7 +8,7 @@ from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.tensor.basic import Alloc, alloc, as_tensor_variable, second from 
pytensor.tensor.elemwise import DimShuffle, Elemwise -from pytensor.tensor.extra_ops import BroadcastTo, Repeat, Unique, repeat, unique +from pytensor.tensor.extra_ops import Repeat, Unique, repeat, unique from pytensor.tensor.type import dscalar @@ -103,64 +103,6 @@ def test_local_Unique_Alloc_lift( assert np.array_equal(y_exp_val, y_val) -@pytest.mark.parametrize( - "x_val, axis, new_shape", - [ - (np.array(-10, dtype=np.int64), None, (2, 3)), - (np.array([[-10, -3], [-10, 2], [-10, 2]], dtype=np.int64), None, (2, 3, 2)), - ], -) -@pytest.mark.parametrize("return_index", [False]) -@pytest.mark.parametrize("return_counts", [False]) -@pytest.mark.parametrize("return_inverse", [False]) -def test_local_Unique_BroadcastTo( - x_val, axis, new_shape, return_index, return_counts, return_inverse -): - x = as_tensor_variable(x_val).type() - y = unique( - BroadcastTo()(x, tuple(new_shape)), - return_index=return_index, - return_counts=return_counts, - return_inverse=return_inverse, - axis=axis, - ) - - if isinstance(y, list): - y, *_ = y - - # This approach allows us to directly confirm that `x` is in the result. - y_fg = FunctionGraph(outputs=[y], copy_inputs=False) - y_rewritten_fg = rewrite_graph( - y_fg, - clone=False, - include=["canonicalize", "local_Unique_BroadcastTo_lift"], - exclude=["local_Unique_scalar"], - ) - y_rewritten = y_rewritten_fg.outputs[0] - y_rewritten_start = y_rewritten - - assert isinstance(y_rewritten_start.owner.op, Unique) - assert y_rewritten_start.owner.inputs[0] == x - assert not any( - isinstance(node.op, BroadcastTo) for node in y_rewritten_fg.apply_nodes - ) - - default_mode = get_default_mode() - # The rewrite has already been applied to `y_rewritten`, so we can--and - # should--exclude it from the compilation of both our reference, `y`, and - # the rewritten result, `y_rewritten`. 
- rewrite_mode = default_mode.excluding("local_Unique_BroadcastTo_lift") - y_fn = function([x], [y, y_rewritten], mode=rewrite_mode) - # Make sure that the original `BroadcastTo` is used to compute the - # reference `y` result - assert any( - isinstance(node.op, BroadcastTo) for node in y_fn.maker.fgraph.apply_nodes - ) - - y_exp_val, y_val = y_fn(x_val) - assert np.array_equal(y_exp_val, y_val) - - @pytest.mark.parametrize( "x_val, unique_axis, repeats, repeat_axis", [ @@ -287,16 +229,3 @@ def test_local_Unique_second( y_exp_val, y_val = y_fn(x_val) assert np.array_equal(y_exp_val, y_val) - - -def test_local_remove_scalar_BroadcastTo(): - x = dscalar() - y = BroadcastTo()(x, ()) - - assert isinstance(y.owner.op, BroadcastTo) - - res = rewrite_graph( - y, clone=False, include=["canonicalize", "local_remove_scalar_BroadcastTo"] - ) - - assert res is x diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py index 4d2c3fec9e..e103567564 100644 --- a/tests/tensor/test_extra_ops.py +++ b/tests/tensor/test_extra_ops.py @@ -8,14 +8,12 @@ from pytensor import tensor as at from pytensor.compile.mode import Mode from pytensor.configdefaults import config -from pytensor.graph.basic import Constant, applys_between -from pytensor.graph.replace import clone_replace -from pytensor.graph.rewriting.db import RewriteDatabaseQuery +from pytensor.graph.basic import Constant, applys_between, equal_computations from pytensor.raise_op import Assert +from pytensor.tensor import alloc from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.extra_ops import ( Bartlett, - BroadcastTo, CpuContiguous, CumOp, FillDiagonal, @@ -47,7 +45,6 @@ to_one_hot, unravel_index, ) -from pytensor.tensor.subtensor import AdvancedIncSubtensor from pytensor.tensor.type import ( TensorType, dmatrix, @@ -61,7 +58,6 @@ lscalar, matrix, scalar, - tensor, tensor3, vector, ) @@ -1246,183 +1242,15 @@ def test_broadcast_shape_symbolic_one_symbolic(): assert res_shape[2].data == 3 
-class TestBroadcastTo(utt.InferShapeTester): - def setup_method(self): - super().setup_method() - self.op_class = BroadcastTo - self.op = broadcast_to - - def test_avoid_useless_scalars(self): - x = scalar() - y = broadcast_to(x, ()) - assert y is x - - def test_avoid_useless_subtensors(self): - x = scalar() - y = broadcast_to(x, (1, 2)) - # There shouldn't be any unnecessary `Subtensor` operations - # (e.g. from `at.as_tensor((1, 2))[0]`) - assert y.owner.inputs[1].owner is None - assert y.owner.inputs[2].owner is None - - @pytest.mark.parametrize("linker", ["cvm", "py"]) - def test_perform(self, linker): - a = pytensor.shared(np.full((3, 1, 1), 5)) - s_0 = iscalar("s_0") - s_1 = iscalar("s_1") - shape = (s_0, s_1, 1) - - bcast_res = broadcast_to(a, shape) - assert bcast_res.broadcastable == (False, False, True) - - bcast_fn = pytensor.function( - [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker) - ) - bcast_fn.vm.allow_gc = False - - bcast_at = bcast_fn(3, 4) - bcast_np = np.broadcast_to(5, (3, 4, 1)) - - assert np.array_equal(bcast_at, bcast_np) - - with pytest.raises(ValueError): - bcast_fn(5, 4) - - if linker != "py": - bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0] - bcast_in = bcast_fn.vm.storage_map[a] - bcast_out = bcast_fn.vm.storage_map[bcast_var] - assert np.shares_memory(bcast_out[0], bcast_in[0]) - - def test_make_node_error_handling(self): - with pytest.raises( - ValueError, - match="Broadcast target shape has 1 dims, which is shorter than input with 2 dims", - ): - broadcast_to(at.zeros((3, 4)), (5,)) +def test_broadcast_to(): + x = vector("x") + y1 = scalar(dtype="int64") + y2 = scalar(dtype="int64") - @pytest.mark.skipif( - not config.cxx, reason="G++ not available, so we need to skip this test." 
+ assert equal_computations( + [broadcast_to(x, (y1, y2))], + [alloc(x, y1, y2)], ) - @pytest.mark.parametrize("valid", (True, False)) - def test_memory_leak(self, valid): - import gc - import tracemalloc - - from pytensor.link.c.cvm import CVM - - n = 100_000 - x = pytensor.shared(np.ones((1, n), dtype=np.float64)) - y = broadcast_to(x, (5, n)) - - f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm")) - assert isinstance(f.vm, CVM) - - assert len(f.maker.fgraph.apply_nodes) == 2 - assert any( - isinstance(node.op, BroadcastTo) for node in f.maker.fgraph.apply_nodes - ) - - tracemalloc.start() - - blocks_last = None - block_diffs = [] - for i in range(1, 50): - if valid: - x.set_value(np.ones((1, n))) - _ = f() - else: - x.set_value(np.ones((2, n))) - try: - _ = f() - except ValueError: - pass - else: - raise RuntimeError("Should have failed") - _ = gc.collect() - blocks_i, _ = tracemalloc.get_traced_memory() - if blocks_last is not None: - blocks_diff = (blocks_i - blocks_last) // 10**3 - block_diffs.append(blocks_diff) - blocks_last = blocks_i - - tracemalloc.stop() - assert np.all(np.array(block_diffs) <= (0 + 1e-8)) - - @pytest.mark.parametrize( - "fn,input_dims", - [ - [lambda x: broadcast_to(x, (1,)), (1,)], - [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)], - [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)], - [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)], - ], - ) - def test_gradient(self, fn, input_dims): - rng = np.random.default_rng(43) - utt.verify_grad( - fn, - [rng.random(input_dims).astype(config.floatX)], - n_tests=1, - rng=rng, - ) - - def test_infer_shape(self): - rng = np.random.default_rng(43) - a = tensor(dtype=config.floatX, shape=(None, 1, None)) - shape = list(a.shape) - out = self.op(a, shape) - - self._compile_and_check( - [a] + shape, - [out], - [rng.random((2, 1, 3)).astype(config.floatX), 2, 1, 3], - self.op_class, - ) - - a = tensor(dtype=config.floatX, shape=(None, 1, None)) - shape = [iscalar() for i in 
range(4)] - self._compile_and_check( - [a] + shape, - [self.op(a, shape)], - [rng.random((2, 1, 3)).astype(config.floatX), 6, 2, 5, 3], - self.op_class, - ) - - def test_inplace(self): - """Make sure that in-place optimizations are *not* performed on the output of a ``BroadcastTo``.""" - a = at.zeros((5,)) - d = at.vector("d") - c = at.set_subtensor(a[np.r_[0, 1, 3]], d) - b = broadcast_to(c, (5,)) - q = b[np.r_[0, 1, 3]] - e = at.set_subtensor(q, np.r_[0, 0, 0]) - - opts = RewriteDatabaseQuery(include=["inplace"]) - py_mode = Mode("py", opts) - e_fn = function([d], e, mode=py_mode) - - advincsub_node = e_fn.maker.fgraph.outputs[0].owner - assert isinstance(advincsub_node.op, AdvancedIncSubtensor) - assert isinstance(advincsub_node.inputs[0].owner.op, BroadcastTo) - - assert advincsub_node.op.inplace is False - - def test_rebuild(self): - x = vector(shape=(50,)) - x_test = np.zeros((50,), dtype=config.floatX) - i = 0 - y = broadcast_to(i, x.shape) - assert y.type.shape == (50,) - assert y.shape.eval({x: x_test}) == (50,) - assert y.eval({x: x_test}).shape == (50,) - - x_new = vector(shape=(100,)) - x_new_test = np.zeros((100,), dtype=config.floatX) - y_new = clone_replace(y, {x: x_new}, rebuild_strict=False) - assert y_new.type.shape == (100,) - assert y_new.shape.eval({x_new: x_new_test}) == (100,) - assert y_new.eval({x_new: x_new_test}).shape == (100,) def test_broadcast_arrays(): From 9f8ed942db9244a0ab8508892ae18e2f23131726 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 13 Jul 2023 16:03:50 +0200 Subject: [PATCH 9/9] Incorporate static shape of Alloc input --- pytensor/tensor/basic.py | 38 +++++++++++++++++++++++----- tests/tensor/rewriting/test_basic.py | 21 --------------- tests/tensor/test_basic.py | 16 ++++++++++++ 3 files changed, 47 insertions(+), 28 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 762e145c8b..1f43d5531d 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -1432,17 +1432,41 
@@ class Alloc(COp): __props__ = () def make_node(self, value, *shape): - v = as_tensor_variable(value) - sh, static_shape = infer_static_shape(shape) - if v.ndim > len(sh): + value = as_tensor_variable(value) + shape, static_shape = infer_static_shape(shape) + if value.ndim > len(shape): raise TypeError( "The Alloc value to use has more dimensions" " than the specified dimensions", - v.ndim, - len(sh), + value.ndim, + len(shape), ) - otype = TensorType(dtype=v.dtype, shape=static_shape) - return Apply(self, [v] + sh, [otype()]) + + # Combine static shape information from value and shape + combined_static_shape = list(static_shape).copy() + new_dims = len(shape) - value.type.ndim + extended_value_static_shape = (None,) * new_dims + value.type.shape + extended_value_broadcastable = (False,) * new_dims + value.type.broadcastable + for i, (v_bc, v_st, sh_st) in enumerate( + zip( + extended_value_broadcastable, + extended_value_static_shape, + static_shape, + ) + ): + # If value is not broadcastable and we don't know the target static shape: use value static shape + if (not v_bc) and (sh_st is None): + combined_static_shape[i] = v_st + # Otherwise check if static shapes are compatible + elif (v_st is not None) and (sh_st is not None): + # They must match or if not, the value must be broadcastable + if v_st != sh_st and not v_bc: + raise ValueError( + f"Alloc static input type and target shape are incompatible: {value.type} vs {static_shape}" + ) + + otype = TensorType(dtype=value.dtype, shape=combined_static_shape) + return Apply(self, [value] + shape, [otype()]) def perform(self, node, inputs, out_): (out,) = out_ diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index 5d364da6fd..0e5c618ba0 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -272,27 +272,6 @@ class TestLocalCanonicalizeAlloc: def setup_method(self): self.rng = np.random.default_rng(utt.fetch_seed()) - def 
test_inconsistent_constant(self): - x = at.as_tensor(self.rng.standard_normal((3, 7))) - a = at.alloc(x, 6, 7) - - assert a.owner and isinstance(a.owner.op, Alloc) - - # `local_useless_alloc` should attempt to replace the `Alloc` with an - # `Assert` and fail when the static shape information conflicts. - with pytest.raises(TypeError): - f = function([], a, mode=rewrite_mode) - - x = at.as_tensor(self.rng.standard_normal((6, 7))) - a = at.alloc(x, 6, 7) - - f = function([], a, mode=rewrite_mode) - - # The rewrite should then be applied, and remove Alloc - assert not any( - isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort() - ) - def test_inconsistent_shared(self): # These shapes don't match! x = shared(self.rng.standard_normal((3, 7))) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 79703fb761..bb2e918278 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -835,6 +835,22 @@ def test_rebuild(self, func): assert y_new.shape.eval({x_new: x_new_test}) == (100,) assert y_new.eval({x_new: x_new_test}).shape == (100,) + def test_static_shape(self): + x = tensor(shape=(None, 1, 5)) + d0 = scalar("d0", dtype=int) + d1 = scalar("d1", dtype=int) + assert at.alloc(x, 3, 1, 5).type.shape == (3, 1, 5) + assert at.alloc(x, 3, 4, 5).type.shape == (3, 4, 5) + assert at.alloc(x, d0, d1, 5).type.shape == (None, None, 5) + assert at.alloc(x, d0, 1, d1).type.shape == (None, 1, 5) + + msg = "Alloc static input type and target shape are incompatible" + with pytest.raises(ValueError, match=msg): + at.alloc(x, 3, 1, 1) + + with pytest.raises(ValueError, match=msg): + at.alloc(x, 3, 1, 6) + def test_infer_shape(): with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):