Commit a21ae05

Forbid runtime broadcasting in Elemwise
1 parent 5c87d74 commit a21ae05

12 files changed, +222 -217 lines changed

pytensor/link/jax/dispatch/elemwise.py

Lines changed: 10 additions & 2 deletions
@@ -7,9 +7,17 @@
 @jax_funcify.register(Elemwise)
-def jax_funcify_Elemwise(op, **kwargs):
+def jax_funcify_Elemwise(op, node, **kwargs):
     scalar_op = op.scalar_op
-    return jax_funcify(scalar_op, **kwargs)
+    base_fn = jax_funcify(scalar_op, node=node, **kwargs)
+
+    def elemwise_fn(*inputs):
+        # ScalarVariables in JAX are passed as int/float.
+        # We wrap them in arrays just for the broadcast check
+        Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
+        return base_fn(*inputs)
+
+    return elemwise_fn
 
 
 @jax_funcify.register(CAReduce)
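
With `node` now passed to the dispatcher, a graph compiled in JAX mode performs the same runtime check as the Python and C backends. A minimal sketch of the user-visible effect (not part of the commit), assuming JAX is installed and the built-in "JAX" compilation mode:

import numpy as np
import pytensor
import pytensor.tensor as at

x = at.matrix("x")  # static shape (None, None): no dimension marked broadcastable
y = at.matrix("y")
f = pytensor.function([x, y], x + y, mode="JAX")

f(np.zeros((5, 3)), np.zeros((5, 3)))  # fine: shapes match exactly
f(np.zeros((1, 3)), np.zeros((5, 3)))  # expected to raise ValueError:
                                       # "Runtime broadcasting not allowed. ..."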

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 3 additions & 3 deletions
@@ -41,9 +41,9 @@ def compute_itershape(
         ):
             with then:
                 msg = (
-                    f"Incompatible shapes for input {j} and axis {i} of "
-                    f"elemwise. Input {j} has shape 1, but is not statically "
-                    "known to have shape 1, and thus not broadcastable."
+                    "Runtime broadcasting not allowed. "
+                    f"Input {j} had a distinct dimension length of 1 at axis {i}, but was not marked as broadcastable.\n"
+                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
                 )
                 ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
             with otherwise:

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
     shape_padaxis,
     shape_padleft,
     shape_padright,
+    specify_broadcastable,
     specify_shape,
 )

pytensor/tensor/elemwise.py

Lines changed: 36 additions & 16 deletions
@@ -6,7 +6,7 @@
 import pytensor.tensor.basic
 from pytensor.configdefaults import config
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.null_type import NullType
 from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.c.basic import failure_code
@@ -19,9 +19,9 @@
 from pytensor.scalar.basic import bool as scalar_bool
 from pytensor.scalar.basic import identity as scalar_identity
 from pytensor.scalar.basic import transfer_type, upcast
-from pytensor.tensor import _get_vector_length, as_tensor_variable
 from pytensor.tensor import elemwise_cgen as cgen
 from pytensor.tensor import get_vector_length
+from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
 from pytensor.tensor.type import (
     TensorType,
     continuous_dtypes,
@@ -740,9 +740,7 @@ def perform(self, node, inputs, output_storage):
         # FIXME: This no longer calls the C implementation!
         super().perform(node, inputs, output_storage)
 
-        for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
-            if len(set(dim_shapes) - {1}) > 1:
-                raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
+        self._check_runtime_broadcast(node, inputs)
 
         ufunc_args = inputs
         ufunc_kwargs = {}
@@ -818,18 +816,40 @@ def perform(self, node, inputs, output_storage):
         else:
             storage[0] = variable
 
-    def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
-        if len(node.outputs) > 1:
-            from pytensor.tensor.exceptions import ShapeError
-
-            raise ShapeError(
-                "Multiple outputs are not supported by the default `Elemwise.infer_shape`"
-            )
+    @staticmethod
+    def _check_runtime_broadcast(node, inputs):
+        for dims_and_bcast in zip(
+            *[
+                zip(input.shape, sinput.type.broadcastable)
+                for input, sinput in zip(inputs, node.inputs)
+            ]
+        ):
+            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
+                raise ValueError(
+                    "Runtime broadcasting not allowed. "
+                    "At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
+                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+                )
 
-        out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
+    def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
+        one = pytensor.tensor.basic.constant(1, dtype="int64")
+        output_shape = []
+        for dim, broadcastable in enumerate(node.outputs[0].type.broadcastable):
+            out_dim_length = one
+            if not broadcastable:
+                # There must be some input that is not broadcastable in this dim
+                for inp_shape, inp_var in zip(i_shapes, node.inputs):
+                    if not inp_var.type.broadcastable[dim]:
+                        # Give preference to constant dims
+                        if isinstance(inp_shape[dim], Constant):
+                            out_dim_length = inp_shape[dim]
+                            break
+                        # If we haven't yet seen a non-broadcastable dim, use this one
+                        if out_dim_length is one:
+                            out_dim_length = inp_shape[dim]
+            output_shape.append(as_tensor_variable(out_dim_length, dtype="int64"))
 
-        # The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
-        return [tuple(as_tensor_variable(s) for s in out_shape)]
+        return [tuple(output_shape)] * len(node.outputs)
 
     def _c_all(self, node, nodename, inames, onames, sub):
         # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
@@ -1193,7 +1213,7 @@ def c_support_code_apply(self, node, nodename):
         return support_code
 
     def c_code_cache_version_apply(self, node):
-        version = [14]  # the version corresponding to the c code in this Op
+        version = [15]  # the version corresponding to the c code in this Op
 
         # now we insert versions for the ops on which we depend...
         scalar_node = Apply(
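
The new error text points users at `specify_broadcastable`, which this commit also re-exports from `pytensor.tensor`. A hedged sketch (not part of the commit) of how intentional broadcasting is declared so that `_check_runtime_broadcast` stays silent:

import numpy as np
import pytensor
import pytensor.tensor as at

x = at.matrix("x")
y = at.matrix("y")

# Declare that dimension 0 of `x` always has length 1, making broadcasting legal there.
xb = at.specify_broadcastable(x, 0)
f = pytensor.function([x, y], xb + y)
print(f(np.zeros((1, 3)), np.ones((5, 3))).shape)  # (5, 3), broadcast as before

# Without the declaration, the same inputs are now expected to raise
# "Runtime broadcasting not allowed. ..." from perform/C/Numba/JAX.
g = pytensor.function([x, y], x + y)
# g(np.zeros((1, 3)), np.ones((5, 3)))  # ValueError under this commit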

pytensor/tensor/elemwise_cgen.py

Lines changed: 41 additions & 59 deletions
@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
         if index != "x":
             # Initialize the variables associated to the jth loop
             # jump = stride - adjust
-            # If the variable has size 1 in that dim, we set the stride to zero to
-            # emulate broadcasting
             jump = f"({var}_stride{index}) - ({adjust})"
             init += f"""
             {var}_n{index} = PyArray_DIMS({var})[{index}];
-            {var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype});
+            {var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
             {var}_jump{index}_{j} = {jump};
             """
             adjust = f"{var}_n{index}*{var}_stride{index}"
@@ -86,88 +84,72 @@ def make_checks(loop_orders, dtypes, sub):
     # This loop builds multiple if conditions to verify that the
     # dimensions of the inputs match, and the first one that is true
     # raises an informative error message
+
+    runtime_broadcast_error_msg = (
+        "Runtime broadcasting not allowed. "
+        "Input %%i had a distinct dimension length of 1 at axis %%i, but was not marked as broadcastable. "
+        "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+    )
+
     for matches in zip(*loop_orders):
         to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
 
         # elements of to_compare are pairs ( input_variable_idx, input_variable_dim_idx )
         if len(to_compare) < 2:
             continue
 
-        # Find first dimension size that is != 1
-        jl, xl = to_compare[-1]
-        non1size_dim_check = f"""
-        npy_intp non1size_dim{xl};
-        non1size_dim{xl} = """
-        for j, x in to_compare[:-1]:
-            non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
-        non1size_dim_check += f"%(lv{jl})s_n{xl};"
-        check += non1size_dim_check
-
-        # Check the nonsize1 dims match
-        # TODO: This is a bit inefficient because we are comparing one dimension against itself
-        check += f"""
-        if (non1size_dim{xl} != 1)
-        {{
-        """
-        for j, x in to_compare:
+        j0, x0 = to_compare[0]
+        for j, x in to_compare[1:]:
             check += f"""
-            if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1))
+            if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
+            {{
+                if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
             {{
-                PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
-                    {x},
-                    (long long int) non1size_dim{x},
+                PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
+                    {j0},
+                    {x0},
+                    (long long int) %(lv{j0})s_n{x0},
+                    {j},
+                    {x},
+                    (long long int) %(lv{j})s_n{x}
+                );
+                }} else {{
+                PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
+                    {j0},
+                    {x0},
+                    (long long int) %(lv{j0})s_n{x0},
                     {j},
                     {x},
                     (long long int) %(lv{j})s_n{x}
                 );
-                %(fail)s
             }}
-            """
-        check += """
-            }
+                %(fail)s
+            }}
         """
 
     return init % sub + check % sub
 
 
-def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str:
-    """Create c_code to compute broadcasted dimensions of multiple arrays, arising from
-    Elemwise operations.
+def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
+    """Create c_code to compute the output dimensions of an Elemwise operation.
 
     The code returned by this function populates the array `array_name`, but does not
     initialize it.
 
-    TODO: We can decide to either specialize C code even further given the input types
-    or make it general, regardless of whether static broadcastable information is given
+    Note: We could specialize C code even further with the known static output shapes
    """
     dims_c_code = ""
     for i, candidates in enumerate(zip(*loop_orders)):
-        # TODO: Are candidates always either "x" or "i"? If that's the case we can
-        # simplify some logic here (e.g., we don't need to track the `idx`).
-        nonx_candidates = tuple(
-            (idx, c) for idx, c in enumerate(candidates) if c != "x"
-        )
-
-        # All inputs are known to be broadcastable
-        if not nonx_candidates:
+        # Borrow the length of the first non-broadcastable input dimension
+        for j, candidate in enumerate(candidates):
+            if candidate != "x":
+                var = sub[f"lv{int(j)}"]
+                dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
+                break
+        # If none is non-broadcastable, the output dimension has a length of 1
+        else:  # no-break
             dims_c_code += f"{array_name}[{i}] = 1;\n"
-            continue
-
-        # There is only one informative source of size
-        if len(nonx_candidates) == 1:
-            idx, candidate = nonx_candidates[0]
-            var = sub[f"lv{int(idx)}"]
-            dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
-            continue
 
-        # In this case any non-size 1 variable will define the right size
-        dims_c_code += f"{array_name}[{i}] = "
-        for idx, candidate in nonx_candidates[:-1]:
-            var = sub[f"lv{int(idx)}"]
-            dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
-        idx, candidate = nonx_candidates[-1]
-        var = sub[f"lv{idx}"]
-        dims_c_code += f"{var}_n{candidate};\n"
     return dims_c_code
@@ -186,7 +168,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
     if type.startswith("PYTENSOR_COMPLEX"):
         type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
     nd = len(loop_orders[0])
-    init_dims = compute_broadcast_dimensions("dims", loop_orders, sub)
+    init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
 
     # TODO: it would be interesting to allocate the output in such a
     # way that its contiguous dimensions match one of the input's
@@ -359,7 +341,7 @@ def make_reordered_loop(
 
     # Get the (sorted) total number of iterations of each loop
     declare_totals = f"int init_totals[{nnested}];\n"
-    declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub)
+    declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub)
 
     # Sort totals to match the new order that was computed by sorting
     # the loop vector. One integer variable per loop is declared.
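
For readers who do not want to trace the generated C, here is an illustrative pure-Python analogue (not part of the commit) of the rule `compute_output_dims_lengths` now encodes: each output dimension borrows its length from the first input that is not statically broadcastable there, and falls back to 1 only when every input is broadcastable in that dimension.

def output_dim_lengths(shapes, broadcastable):
    """Mirror of the codegen rule, for illustration only."""
    out = []
    for dim_entries in zip(*(zip(shp, bc) for shp, bc in zip(shapes, broadcastable))):
        for length, is_bcast in dim_entries:
            if not is_bcast:
                out.append(length)  # first non-broadcastable input wins
                break
        else:  # every input is broadcastable in this dim
            out.append(1)
    return out

# A (5, 1) input marked broadcastable in dim 1, combined with a (5, 4) input:
print(output_dim_lengths([(5, 1), (5, 4)], [(False, True), (False, False)]))  # [5, 4]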

pytensor/tensor/extra_ops.py

Lines changed: 1 addition & 1 deletion
@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
 
 _broadcast_assert = Assert(
     "Could not broadcast dimensions. Broadcasting is only allowed along "
-    "axes that have a statically known length 1. Use `specify_shape` to "
+    "axes that have a statically known length 1. Use `specify_broadcastable` to "
     "inform PyTensor of a known shape."
 )

tests/link/jax/test_elemwise.py

Lines changed: 6 additions & 0 deletions
@@ -4,6 +4,7 @@
 
 import pytensor
 import pytensor.tensor as at
+from pytensor.compile import get_mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import get_test_value
@@ -14,6 +15,11 @@
 from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
 from pytensor.tensor.type import matrix, tensor, vector
 from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_elemwise import TestElemwise
+
+
+def test_elemwise_runtime_shape_error():
+    TestElemwise.check_runtime_shapes_error(get_mode("JAX"))
 
 
 def test_jax_Dimshuffle():

tests/link/numba/test_elemwise.py

Lines changed: 7 additions & 0 deletions
@@ -9,6 +9,7 @@
 import pytensor.tensor.inplace as ati
 import pytensor.tensor.math as aem
 from pytensor import config, function
+from pytensor.compile import get_mode
 from pytensor.compile.ops import deep_copy_op
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.gradient import grad
@@ -17,6 +18,7 @@
 from pytensor.tensor import elemwise as at_elemwise
 from pytensor.tensor.math import All, Any, Max, Mean, Min, Prod, ProdWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
+from tensor.test_elemwise import TestElemwise
 from tests.link.numba.test_basic import (
     compare_numba_and_py,
     scalar_my_multi_out,
@@ -119,6 +121,11 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
     compare_numba_and_py(out_fg, input_vals)
 
 
+@pytest.mark.xfail(reason="Error message not triggered")
+def test_elemwise_runtime_shape_error():
+    TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
+
+
 def test_elemwise_speed(benchmark):
     x = at.dmatrix("y")
     y = at.dvector("z")

tests/tensor/rewriting/test_basic.py

Lines changed: 1 addition & 6 deletions
@@ -1671,12 +1671,7 @@ def verify_op_count(f, count, cls):
             (),
             (),
         ),
-        pytest.param(
-            lambda x, y: at.mul(y, at.alloc(1, x)),
-            (),
-            (),
-            marks=pytest.mark.xfail(reason="Not implemented"),
-        ),
+        (lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
         (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
         (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
         (

tests/tensor/rewriting/test_math.py

Lines changed: 3 additions & 4 deletions
@@ -607,11 +607,10 @@ def test_mul_div_cases(self):
                 ((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"),
                 ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
                 ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
-                # must broadcast as their is a dimshuffle in the computation
-                # The broadcast leads to an extra elemwise to check compatibility
-                ((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
+                # must broadcast as there is a dimshuffle in the computation
+                ((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
                 # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
-                ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
+                ((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
                 # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
             ]
         ):
