Skip to content

Commit d1a0ff7

Browse files
committed
Forbid runtime broadcasting in Elemwise
1 parent e20dd0b commit d1a0ff7

File tree

6 files changed

+121
-95
lines changed

6 files changed

+121
-95
lines changed

pytensor/link/jax/dispatch/elemwise.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77

88

99
@jax_funcify.register(Elemwise)
10-
def jax_funcify_Elemwise(op, **kwargs):
10+
def jax_funcify_Elemwise(op, node, **kwargs):
1111
scalar_op = op.scalar_op
12-
return jax_funcify(scalar_op, **kwargs)
12+
base_fn = jax_funcify(scalar_op, node, **kwargs)
13+
14+
def elemwise_fn(*inputs):
15+
Elemwise._check_runtime_broadcast(node, inputs)
16+
return base_fn(*inputs)
17+
18+
return elemwise_fn
1319

1420

1521
@jax_funcify.register(CAReduce)

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
132132
shape_padaxis,
133133
shape_padleft,
134134
shape_padright,
135+
specify_broadcastable,
135136
specify_shape,
136137
)
137138

pytensor/tensor/elemwise.py

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import pytensor.tensor.basic
77
from pytensor.configdefaults import config
88
from pytensor.gradient import DisconnectedType
9-
from pytensor.graph.basic import Apply
9+
from pytensor.graph.basic import Apply, Constant
1010
from pytensor.graph.null_type import NullType
1111
from pytensor.graph.utils import MethodNotDefined
1212
from pytensor.link.c.basic import failure_code
@@ -19,9 +19,9 @@
1919
from pytensor.scalar.basic import bool as scalar_bool
2020
from pytensor.scalar.basic import identity as scalar_identity
2121
from pytensor.scalar.basic import transfer_type, upcast
22-
from pytensor.tensor import _get_vector_length, as_tensor_variable
2322
from pytensor.tensor import elemwise_cgen as cgen
2423
from pytensor.tensor import get_vector_length
24+
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
2525
from pytensor.tensor.type import (
2626
TensorType,
2727
continuous_dtypes,
@@ -737,9 +737,7 @@ def perform(self, node, inputs, output_storage):
737737
# FIXME: This no longer calls the C implementation!
738738
super().perform(node, inputs, output_storage)
739739

740-
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
741-
if len(set(dim_shapes) - {1}) > 1:
742-
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
740+
self._check_runtime_broadcast(node, inputs)
743741

744742
ufunc_args = inputs
745743
ufunc_kwargs = {}
@@ -815,18 +813,41 @@ def perform(self, node, inputs, output_storage):
815813
else:
816814
storage[0] = variable
817815

818-
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
819-
if len(node.outputs) > 1:
820-
from pytensor.tensor.exceptions import ShapeError
821-
822-
raise ShapeError(
823-
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
824-
)
816+
@staticmethod
def _check_runtime_broadcast(node, inputs):
    """Reject inputs that would only line up through runtime broadcasting.

    Parameters
    ----------
    node
        Object whose ``inputs`` are the symbolic inputs; each one exposes
        ``type.broadcastable``, a per-dimension sequence of booleans.
    inputs
        Runtime values, paired positionally with ``node.inputs``; each one
        exposes ``shape``.

    Raises
    ------
    ValueError
        If, for some dimension, at least one input has a length other than 1
        while another input has length 1 without being marked broadcastable.

    Notes
    -----
    Dimensions are paired with ``zip`` starting from the first axis, so the
    comparison stops at the shortest shape.
    NOTE(review): presumably all inputs already have the same number of
    dimensions when this is called -- confirm against the callers.
    """
    # Build, for every input, the sequence of (length, broadcastable) pairs,
    # then transpose so that each iteration sees one dimension of all inputs.
    per_input_dims = [
        zip(value.shape, variable.type.broadcastable)
        for value, variable in zip(inputs, node.inputs)
    ]
    for dims_and_bcast in zip(*per_input_dims):
        # A length-1 dimension is only a problem when it was not declared
        # broadcastable AND some other input is longer in that dimension.
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
            )
825831

826-
out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
832+
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
833+
one = pytensor.tensor.basic.constant(1, dtype="int64")
834+
output_shape = []
835+
for dim, broadcastable in enumerate(node.outputs[0].type.broadcastable):
836+
out_dim_length = one
837+
if not broadcastable:
838+
# There must be some input that is not broadcastable in this dim
839+
for inp_shape, inp_var in zip(i_shapes, node.inputs):
840+
if not inp_var.type.broadcastable[dim]:
841+
# Give preference to constant dims
842+
if isinstance(inp_shape[dim], Constant):
843+
out_dim_length = inp_shape[dim]
844+
break
845+
# If we haven't yet seen a non-broadcastable dim, use this one
846+
if out_dim_length is one:
847+
out_dim_length = inp_shape[dim]
848+
output_shape.append(as_tensor_variable(out_dim_length, dtype="int64"))
827849

828-
# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
829-
return [tuple(as_tensor_variable(s) for s in out_shape)]
850+
return [tuple(output_shape)] * len(node.outputs)
830851

831852
def _c_all(self, node, nodename, inames, onames, sub):
832853
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
@@ -1190,7 +1211,7 @@ def c_support_code_apply(self, node, nodename):
11901211
return support_code
11911212

11921213
def c_code_cache_version_apply(self, node):
1193-
version = [14] # the version corresponding to the c code in this Op
1214+
version = [15] # the version corresponding to the c code in this Op
11941215

11951216
# now we insert versions for the ops on which we depend...
11961217
scalar_node = Apply(

pytensor/tensor/elemwise_cgen.py

Lines changed: 42 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
6666
if index != "x":
6767
# Initialize the variables associated to the jth loop
6868
# jump = stride - adjust
69-
# If the variable has size 1 in that dim, we set the stride to zero to
70-
# emulate broadcasting
7169
jump = f"({var}_stride{index}) - ({adjust})"
7270
init += f"""
7371
{var}_n{index} = PyArray_DIMS({var})[{index}];
74-
{var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype});
72+
{var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
7573
{var}_jump{index}_{j} = {jump};
7674
"""
7775
adjust = f"{var}_n{index}*{var}_stride{index}"
@@ -86,88 +84,73 @@ def make_checks(loop_orders, dtypes, sub):
8684
# This loop builds multiple if conditions to verify that the
8785
# dimensions of the inputs match, and the first one that is true
8886
# raises an informative error message
87+
88+
runtime_broadcast_error_msg = (
89+
"Runtime broadcasting not allowed. "
90+
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
91+
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
92+
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
93+
)
94+
8995
for matches in zip(*loop_orders):
9096
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
9197

9298
# elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
9399
if len(to_compare) < 2:
94100
continue
95101

96-
# Find first dimension size that is != 1
97-
jl, xl = to_compare[-1]
98-
non1size_dim_check = f"""
99-
npy_intp non1size_dim{xl};
100-
non1size_dim{xl} = """
101-
for j, x in to_compare[:-1]:
102-
non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
103-
non1size_dim_check += f"%(lv{jl})s_n{xl};"
104-
check += non1size_dim_check
105-
106-
# Check the nonsize1 dims match
107-
# TODO: This is a bit inefficient because we are comparing one dimension against itself
108-
check += f"""
109-
if (non1size_dim{xl} != 1)
110-
{{
111-
"""
112-
for j, x in to_compare:
102+
j0, x0 = to_compare[0]
103+
for j, x in to_compare[1:]:
113104
check += f"""
114-
if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1))
105+
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
106+
{{
107+
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
115108
{{
116-
PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
117-
{x},
118-
(long long int) non1size_dim{x},
109+
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
110+
{j0},
111+
{x0},
112+
(long long int) %(lv{j0})s_n{x0},
113+
{j},
114+
{x},
115+
(long long int) %(lv{j})s_n{x}
116+
);
117+
}} else {{
118+
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
119+
{j0},
120+
{x0},
121+
(long long int) %(lv{j0})s_n{x0},
119122
{j},
120123
{x},
121124
(long long int) %(lv{j})s_n{x}
122125
);
123-
%(fail)s
124126
}}
125-
"""
126-
check += """
127-
}
127+
%(fail)s
128+
}}
128129
"""
129130

130131
return init % sub + check % sub
131132

132133

133-
def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str:
134-
"""Create c_code to compute broadcasted dimensions of multiple arrays, arising from
135-
Elemwise operations.
134+
def compute_outputs_dims(array_name: str, loop_orders, sub) -> str:
135+
"""Create c_code to compute the output dimensions of an Elemwise operation.
136136
137137
The code returned by this function populates the array `array_name`, but does not
138138
initialize it.
139139
140-
TODO: We can decide to either specialize C code even further given the input types
141-
or make it general, regardless of whether static broadcastable information is given
140+
Note: We could specialize C code even further with the known static output shapes
142141
"""
143142
dims_c_code = ""
144143
for i, candidates in enumerate(zip(*loop_orders)):
145-
# TODO: Are candidates always either "x" or "i"? If that's the case we can
146-
# simplify some logic here (e.g., we don't need to track the `idx`).
147-
nonx_candidates = tuple(
148-
(idx, c) for idx, c in enumerate(candidates) if c != "x"
149-
)
150-
151-
# All inputs are known to be broadcastable
152-
if not nonx_candidates:
144+
# Borrow the length of the first non-broadcastable input dimension
145+
for j, candidate in enumerate(candidates):
146+
if candidate != "x":
147+
var = sub[f"lv{int(j)}"]
148+
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
149+
break
150+
# If none is non-broadcastable, the output dimension has a length of 1
151+
else: # no-break
153152
dims_c_code += f"{array_name}[{i}] = 1;\n"
154-
continue
155-
156-
# There is only one informative source of size
157-
if len(nonx_candidates) == 1:
158-
idx, candidate = nonx_candidates[0]
159-
var = sub[f"lv{int(idx)}"]
160-
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
161-
continue
162153

163-
# In this case any non-size 1 variable will define the right size
164-
dims_c_code += f"{array_name}[{i}] = "
165-
for idx, candidate in nonx_candidates[:-1]:
166-
var = sub[f"lv{int(idx)}"]
167-
dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
168-
idx, candidate = nonx_candidates[-1]
169-
var = sub[f"lv{idx}"]
170-
dims_c_code += f"{var}_n{candidate};\n"
171154
return dims_c_code
172155

173156

@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
186169
if type.startswith("PYTENSOR_COMPLEX"):
187170
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
188171
nd = len(loop_orders[0])
189-
init_dims = compute_broadcast_dimensions("dims", loop_orders, sub)
172+
init_dims = compute_outputs_dims("dims", loop_orders, sub)
190173

191174
# TODO: it would be interesting to allocate the output in such a
192175
# way that its contiguous dimensions match one of the input's
@@ -359,7 +342,7 @@ def make_reordered_loop(
359342

360343
# Get the (sorted) total number of iterations of each loop
361344
declare_totals = f"int init_totals[{nnested}];\n"
362-
declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub)
345+
declare_totals += compute_outputs_dims("init_totals", init_loop_orders, sub)
363346

364347
# Sort totals to match the new order that was computed by sorting
365348
# the loop vector. One integer variable per loop is declared.

tests/link/jax/test_elemwise.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import numpy as np
22
import pytest
33
import scipy.special
4+
from tensor.test_elemwise import TestElemwise
45

56
import pytensor
67
import pytensor.tensor as at
8+
from pytensor.compile import get_mode
79
from pytensor.configdefaults import config
810
from pytensor.graph.fg import FunctionGraph
911
from pytensor.graph.op import get_test_value
@@ -16,6 +18,10 @@
1618
from tests.link.jax.test_basic import compare_jax_and_py
1719

1820

21+
def test_elemwise_runtime_broadcast_error():
22+
TestElemwise.check_runtime_broadcast_error(get_mode("JAX"))
23+
24+
1925
def test_jax_Dimshuffle():
2026
a_at = matrix("a")
2127

tests/tensor/test_elemwise.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from pytensor.tensor import as_tensor_variable
1919
from pytensor.tensor.basic import second
2020
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
21-
from pytensor.tensor.exceptions import ShapeError
2221
from pytensor.tensor.math import all as at_all
2322
from pytensor.tensor.math import any as at_any
2423
from pytensor.tensor.math import exp
@@ -769,10 +768,9 @@ def test_input_dimensions_overflow(self):
769768
g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
770769
g(*[np.zeros(2**11, config.floatX) for i in range(6)])
771770

772-
def check_input_dimensions_match(self, mode):
773-
"""Make sure that our input validation works correctly and doesn't
774-
throw erroneous broadcast-based errors.
775-
"""
771+
@staticmethod
772+
def check_runtime_broadcast_error(mode):
773+
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
776774
x_v = matrix("x")
777775
m_v = vector("m")
778776

@@ -782,19 +780,18 @@ def check_input_dimensions_match(self, mode):
782780
z_v = x_v - m_v
783781
f = pytensor.function([x_v, m_v], z_v, mode=mode)
784782

785-
res = f(x, m)
783+
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
784+
f(x, m)
786785

787-
assert np.array_equal(res, x - m)
788-
789-
def test_input_dimensions_match_python(self):
790-
self.check_input_dimensions_match(Mode(linker="py"))
786+
def test_runtime_broadcast_error_python(self):
787+
self.check_runtime_broadcast_error(Mode(linker="py"))
791788

792789
@pytest.mark.skipif(
793790
not pytensor.config.cxx,
794791
reason="G++ not available, so we need to skip this test.",
795792
)
796-
def test_input_dimensions_match_c(self):
797-
self.check_input_dimensions_match(Mode(linker="c"))
793+
def test_runtime_broadcast_error_c(self):
794+
self.check_runtime_broadcast_error(Mode(linker="c"))
798795

799796
def test_str(self):
800797
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
@@ -819,7 +816,7 @@ def test_partial_static_shape_info(self):
819816
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
820817
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
821818

822-
def test_multi_output(self):
819+
def test_infer_shape_multi_output(self):
823820
class CustomElemwise(Elemwise):
824821
def make_node(self, *args):
825822
res = super().make_node(*args)
@@ -833,14 +830,26 @@ def make_node(self, *args):
833830
],
834831
)
835832

836-
z_1, z_2 = CustomElemwise(aes.add)(
837-
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
838-
)
833+
custom_elemwise = CustomElemwise(aes.add)
839834

835+
z_1, z_2 = custom_elemwise(
836+
as_tensor_variable(np.eye(1)),
837+
as_tensor_variable(np.eye(1)),
838+
)
840839
in_1_shape = (aes.constant(1), aes.constant(1))
840+
outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
841+
for out in outs:
842+
assert out[0].eval() == 1
843+
assert out[1].eval() == 1
841844

842-
with pytest.raises(ShapeError):
843-
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
845+
z_1, z_2 = custom_elemwise(
846+
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3))
847+
)
848+
in_2_shape = (aes.constant(3), aes.constant(3))
849+
outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape])
850+
for out in outs:
851+
assert out[0].eval() == 3
852+
assert out[1].eval() == 3
844853

845854
def test_shape_types(self):
846855
x = tensor(dtype=np.float64, shape=(None, 1))

0 commit comments

Comments
 (0)