Commit 84747c9

Forbid runtime broadcasting in Alloc
1 parent e6eeb0c

File tree

6 files changed: 90 additions, 4 deletions

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 2 additions & 1 deletion
@@ -51,9 +51,10 @@ def allocempty(*shape):
 
 
 @jax_funcify.register(Alloc)
-def jax_funcify_Alloc(op, **kwargs):
+def jax_funcify_Alloc(op, node, **kwargs):
     def alloc(x, *shape):
         res = jnp.broadcast_to(x, shape)
+        Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
         return res
 
     return alloc
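
The explicit call into `Alloc._check_runtime_broadcast` is what raises the error on this backend, since `jnp.broadcast_to` itself silently expands length-1 dimensions. A minimal illustration of that default behavior (not part of the commit):

```python
import jax.numpy as jnp

x = jnp.zeros((1,))
# broadcast_to never complains about expanding a length-1 dimension,
# so without the added check Alloc would broadcast silently under JAX.
print(jnp.broadcast_to(x, (5, 3)).shape)  # (5, 3), no error raised
```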

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 9 additions & 1 deletion
@@ -78,16 +78,24 @@ def numba_funcify_Alloc(op, node, **kwargs):
         " " * 4,
     )
 
+    check_runtime_broadcast = []
+    for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]):
+        if val_static_dim is None:
+            check_runtime_broadcast.append(
+                f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
+            )
+    check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
+
     alloc_def_src = f"""
 def alloc(val, {", ".join(shape_var_names)}):
     val_np = np.asarray(val)
 {shapes_to_items_src}
     scalar_shape = {create_tuple_string(shape_var_item_names)}
+{check_runtime_broadcast_src}
     res = np.empty(scalar_shape, dtype=val_np.dtype)
     res[...] = val_np
     return res
     """
-
     alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
 
     return numba_basic.numba_njit(alloc_fn)
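
Because the Numba backend builds `alloc` as source text, the broadcast checks are injected as literal `if` lines, one per input dimension whose static size is unknown. As a rough sketch, for a 1-d array input of static shape `(None,)` and a two-element target shape, the compiled source would look something like this (variable names and the shape-unpacking line are illustrative, not the exact generated code):

```python
import numpy as np

def alloc(val, sh_0, sh_1):
    # val is assumed to already be an ndarray here, so val.shape is valid
    val_np = np.asarray(val)
    scalar_shape = (int(sh_0), int(sh_1))  # stands in for the generated shape unpacking
    # Single injected check: the input's only dimension has unknown static size
    if val.shape[-1] == 1 and scalar_shape[-1] != 1: raise ValueError("Runtime broadcasting not allowed. ...")
    res = np.empty(scalar_shape, dtype=val_np.dtype)
    res[...] = val_np
    return res
```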

pytensor/tensor/basic.py

Lines changed: 37 additions & 2 deletions
@@ -1426,6 +1426,12 @@ class Alloc(COp):
 
     __props__ = ()
 
+    _runtime_broadcast_error_msg = (
+        "Runtime broadcasting not allowed. "
+        "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
+        "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+    )
+
     def make_node(self, value, *shape):
         value = as_tensor_variable(value)
         shape, static_shape = infer_static_shape(shape)
@@ -1455,10 +1461,21 @@ def make_node(self, value, *shape):
         otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
         return Apply(self, [value] + shape, [otype()])
 
+    @staticmethod
+    def _check_runtime_broadcast(node, value, shape):
+        value_static_shape = node.inputs[0].type.shape
+        for v_static_dim, value_dim, out_dim in zip(
+            value_static_shape[::-1], value.shape[::-1], shape[::-1]
+        ):
+            if v_static_dim is None and value_dim == 1 and out_dim != 1:
+                raise ValueError(Alloc._runtime_broadcast_error_msg)
+
     def perform(self, node, inputs, out_):
         (out,) = out_
         v = inputs[0]
         sh = tuple([int(i) for i in inputs[1:]])
+        self._check_runtime_broadcast(node, v, sh)
+
         if out[0] is None or out[0].shape != sh:
             if v.size == 1 and v.item() == 0:
                 out[0] = np.zeros(sh, dtype=v.dtype)
@@ -1477,6 +1494,7 @@ def c_code(self, node, name, inp, out, sub):
 
         code = f"""
         npy_intp shape[{ndim}];
+        int need_new_out;
         """
 
         # Initialize shape
@@ -1485,8 +1503,25 @@ def c_code(self, node, name, inp, out, sub):
             shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
         """
 
+        # Add checks for runtime broadcasting
+        v_static_shape = node.inputs[0].type.shape
+        o_static_shape = node.outputs[0].type.shape
+        v_ndim = len(v_static_shape)
+        o_ndim = len(o_static_shape)
+        for i, (v_static_dim, out_static_dim) in enumerate(
+            zip(v_static_shape[::-1], o_static_shape[::-1])
+        ):
+            if v_static_dim is None:
+                code += f"""
+        if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)
+        {{
+            PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
+            {fail}
+        }}
+        """
+
         code += f"""
-        int need_new_out = (NULL == {zz});
+        need_new_out = (NULL == {zz});
         for (int i = 0; i < {ndim}; i++)
             need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));
 
@@ -1509,7 +1544,7 @@ def c_code(self, node, name, inp, out, sub):
         return code
 
     def c_code_cache_version(self):
-        return (3,)
+        return (4,)
 
     def infer_shape(self, fgraph, node, input_shapes):
         return [node.inputs[1:]]
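
At the user level, the new check only fires when a dimension whose static size is unknown turns out to be 1 at runtime while the target dimension is larger. Shapes are compared right-aligned, NumPy-style, and because `zip` stops at the shortest sequence, leading output dimensions the value lacks are never (and need never be) checked. A quick sketch of the resulting behavior, mirroring the new test added below:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(None,))  # length unknown until runtime
f = pytensor.function([x], pt.alloc(x, 5, 3))

f(np.zeros(3))  # fine: the trailing dimension already matches
f(np.zeros(1))  # ValueError: Runtime broadcasting not allowed. ...
```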

tests/link/jax/test_tensor_basic.py

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,8 @@
 import numpy as np
 import pytest
 
+from pytensor.compile import get_mode
+
 
 jax = pytest.importorskip("jax")
 import jax.errors
@@ -12,6 +14,7 @@
 from pytensor.graph.op import get_test_value
 from pytensor.tensor.type import iscalar, matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_basic import TestAlloc
 
 
 def test_jax_Alloc():
@@ -50,6 +53,10 @@ def compare_shape_dtype(x, y):
     compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])
 
 
+def test_alloc_runtime_broadcast():
+    TestAlloc.check_runtime_broadcast(get_mode("JAX"))
+
+
 def test_jax_MakeVector():
     x = at.make_vector(1, 2, 3)
     x_fg = FunctionGraph([], [x])

tests/link/numba/test_tensor_basic.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 import pytensor.tensor as at
 import pytensor.tensor.basic as atb
 from pytensor import config, function
+from pytensor.compile import get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
@@ -15,6 +16,7 @@
     compare_shape_dtype,
     set_test_value,
 )
+from tests.tensor.test_basic import TestAlloc
 
 
 rng = np.random.default_rng(42849)
@@ -45,6 +47,10 @@ def test_Alloc(v, shape):
     assert numba_res.shape == shape
 
 
+def test_alloc_runtime_broadcast():
+    TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
+
+
 def test_AllocEmpty():
     x = at.empty((2, 3), dtype="float32")
     x_fg = FunctionGraph([], [x])

tests/tensor/test_basic.py

Lines changed: 29 additions & 0 deletions
@@ -720,6 +720,31 @@ class TestAlloc:
     shared = staticmethod(pytensor.shared)
     allocs = [Alloc()] * 3
 
+    @staticmethod
+    def check_allocs_in_fgraph(fgraph, n):
+        assert (
+            len([node for node in fgraph.apply_nodes if isinstance(node.op, Alloc)])
+            == n
+        )
+
+    @staticmethod
+    def check_runtime_broadcast(mode):
+        """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
+        x_v = vector("x", shape=(None,))
+
+        out = alloc(x_v, 5, 3)
+        f = pytensor.function([x_v], out, mode=mode)
+        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+        np.testing.assert_array_equal(f(x=np.zeros((3,))), np.zeros((5, 3)))
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            f(x=np.zeros((1,)))
+
+        out = alloc(specify_shape(x_v, (1,)), 5, 3)
+        f = pytensor.function([x_v], out, mode=mode)
+        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+        np.testing.assert_array_equal(f(x=np.zeros((1,))), np.zeros((5, 3)))
 
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())
@@ -865,6 +890,10 @@ def test_alloc_of_view_linker(self):
 
         np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3)))
 
+    @pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
+    def test_runtime_broadcast(self, mode):
+        self.check_runtime_broadcast(mode)
+
 
 def test_infer_shape():
     with pytest.raises(TypeError, match="^Shapes must be scalar integers.*"):
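
As the error message suggests, the opt-in is to declare the dimension statically broadcastable; the test does this with `specify_shape`, and `specify_broadcastable` works the same way. A self-contained sketch of that escape hatch (not part of the commit):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", shape=(None,))
# Declaring the dimension to be statically 1 makes the broadcast legal again:
out = pt.alloc(pt.specify_shape(x, (1,)), 5, 3)
f = pytensor.function([x], out)
np.testing.assert_array_equal(f(np.zeros(1)), np.zeros((5, 3)))
```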
