
Commit 34eaaa5

Forbid runtime broadcasting in Alloc
1 parent 6941820 commit 34eaaa5

File tree

6 files changed: +107 -9 lines changed


pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 2 additions & 1 deletion

@@ -41,9 +41,10 @@ def allocempty(*shape):
 
 
 @jax_funcify.register(Alloc)
-def jax_funcify_Alloc(op, **kwargs):
+def jax_funcify_Alloc(op, node, **kwargs):
     def alloc(x, *shape):
         res = jnp.broadcast_to(x, shape)
+        Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
         return res
 
     return alloc
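
Note on the JAX change: `jnp.broadcast_to` silently stretches any size-1 dimension, so without the added call the JAX backend would accept inputs that the other backends now reject. A minimal sketch of the behavior being guarded against (plain JAX, no pytensor):

import jax.numpy as jnp

x = jnp.zeros((1,))
# broadcast_to expands the size-1 dimension without complaint:
res = jnp.broadcast_to(x, (5, 3))
assert res.shape == (5, 3)

The dispatch therefore delegates to `Alloc._check_runtime_broadcast`, which raises when a dimension whose static shape is unknown turns out to have runtime length 1 while the requested output length is larger.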

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 9 additions & 1 deletion

@@ -77,16 +77,24 @@ def numba_funcify_Alloc(op, node, **kwargs):
         " " * 4,
     )
 
+    check_runtime_broadcast = []
+    for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]):
+        if val_static_dim is None:
+            check_runtime_broadcast.append(
+                f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
+            )
+    check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
+
     alloc_def_src = f"""
     def alloc(val, {", ".join(shape_var_names)}):
         val_np = np.asarray(val)
         {shapes_to_items_src}
         scalar_shape = {create_tuple_string(shape_var_item_names)}
+        {check_runtime_broadcast_src}
         res = np.empty(scalar_shape, dtype=val_np.dtype)
         res[...] = val_np
         return res
     """
-
     alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
 
     return numba_basic.numba_njit(alloc_fn)
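
Because the Numba backend compiles a generated source string, the check is spliced in as text, one `if` per statically unknown input dimension. For intuition only, here is roughly what the generated `alloc` could look like for a matrix input with both dimensions statically unknown — a sketch with hypothetical variable names (`i0`, `i1`, `i0_item`, `i1_item`), not the literal template output:

def alloc(val, i0, i1):
    val_np = np.asarray(val)
    i0_item = np.intp(i0)
    i1_item = np.intp(i1)
    scalar_shape = (i0_item, i1_item)
    if val.shape[-1] == 1 and scalar_shape[-1] != 1: raise ValueError("Runtime broadcasting not allowed. ...")
    if val.shape[-2] == 1 and scalar_shape[-2] != 1: raise ValueError("Runtime broadcasting not allowed. ...")
    res = np.empty(scalar_shape, dtype=val_np.dtype)
    res[...] = val_np
    return res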

pytensor/tensor/basic.py

Lines changed: 41 additions & 6 deletions

@@ -1431,6 +1431,12 @@ class Alloc(COp):
 
     __props__ = ()
 
+    _runtime_broadcast_error_msg = (
+        "Runtime broadcasting not allowed. "
+        "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
+        "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+    )
+
     def make_node(self, value, *shape):
         value = as_tensor_variable(value)
         shape, static_shape = infer_static_shape(shape)

@@ -1468,10 +1474,21 @@ def make_node(self, value, *shape):
         otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
         return Apply(self, [value] + shape, [otype()])
 
+    @staticmethod
+    def _check_runtime_broadcast(node, value, shape):
+        value_static_shape = node.inputs[0].type.shape
+        for v_static_dim, value_dim, out_dim in zip(
+            value_static_shape[::-1], value.shape[::-1], shape[::-1]
+        ):
+            if v_static_dim is None and value_dim == 1 and out_dim != 1:
+                raise ValueError(Alloc._runtime_broadcast_error_msg)
+
     def perform(self, node, inputs, out_):
         (out,) = out_
         v = inputs[0]
         sh = tuple([int(i) for i in inputs[1:]])
+        self._check_runtime_broadcast(node, v, sh)
+
         if out[0] is None or out[0].shape != sh:
             if v.size == 1 and v.item() == 0:
                 out[0] = np.zeros(sh, dtype=v.dtype)

@@ -1484,12 +1501,19 @@ def perform(self, node, inputs, out_):
 
     def c_code(self, node, name, inp, out, sub):
         vv = inp[0]
-        ndim = len(inp[1:])
         (zz,) = out
         fail = sub["fail"]
 
+        v_static_shape = node.inputs[0].type.shape
+        o_static_shape = node.outputs[0].type.shape
+        v_ndim = len(v_static_shape)
+        o_ndim = len(o_static_shape)
+        assert o_ndim == len(inp[1:])
+
+        # Declare variables
         code = f"""
-        npy_intp shape[{ndim}];
+        npy_intp shape[{o_ndim}];
+        int need_new_out;
         """
 
         # Initialize shape

@@ -1498,15 +1522,26 @@ def c_code(self, node, name, inp, out, sub):
             shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
             """
 
+        # Add checks for runtime broadcasting
+        for i, v_static_dim in enumerate(v_static_shape[::-1]):
+            if v_static_dim is None:
+                code += f"""
+                if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)
+                {{
+                    PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
+                    {fail}
+                }}
+                """
+
         code += f"""
-        int need_new_out = (NULL == {zz});
-        for (int i = 0; i < {ndim}; i++)
+        need_new_out = (NULL == {zz});
+        for (int i = 0; i < {o_ndim}; i++)
             need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));
 
         if (need_new_out)
         {{
             Py_XDECREF({zz});
-            {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv}));
+            {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
             if (!{zz})
             {{
                 PyErr_SetString(PyExc_MemoryError, "alloc failed");

@@ -1522,7 +1557,7 @@ def c_code(self, node, name, inp, out, sub):
         return code
 
     def c_code_cache_version(self):
-        return (3,)
+        return (4,)
 
     def infer_shape(self, fgraph, node, input_shapes):
         return [node.inputs[1:]]
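
The rule implemented by `_check_runtime_broadcast` (and mirrored in the generated C code above) is: walking dimensions from the trailing end, raise only when the static shape is unknown (`None`), the runtime length is 1, and the requested output length differs. A minimal standalone sketch of that logic:

def check_runtime_broadcast(value_static_shape, value_shape, out_shape):
    # Compare from the trailing dimensions, as NumPy broadcasting does.
    for static_dim, value_dim, out_dim in zip(
        value_static_shape[::-1], value_shape[::-1], out_shape[::-1]
    ):
        if static_dim is None and value_dim == 1 and out_dim != 1:
            raise ValueError("Runtime broadcasting not allowed. ...")

check_runtime_broadcast((None,), (3,), (5, 3))  # ok: dimension is not broadcast
check_runtime_broadcast((1,), (1,), (5, 3))     # ok: statically broadcastable
# check_runtime_broadcast((None,), (1,), (5, 3))  # would raise ValueError

A statically known length-1 dimension is still broadcast as before; only broadcasting discovered at runtime is forbidden.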

tests/link/jax/test_tensor_basic.py

Lines changed: 7 additions & 0 deletions

@@ -1,6 +1,8 @@
 import numpy as np
 import pytest
 
+from pytensor.compile import get_mode
+
 
 jax = pytest.importorskip("jax")
 import jax.errors

@@ -12,6 +14,7 @@
 from pytensor.graph.op import get_test_value
 from pytensor.tensor.type import iscalar, matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_basic import TestAlloc
 
 
 def test_jax_Alloc():

@@ -50,6 +53,10 @@ def compare_shape_dtype(x, y):
     compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])
 
 
+def test_alloc_runtime_broadcast():
+    TestAlloc.check_runtime_broadcast(get_mode("JAX"))
+
+
 def test_jax_MakeVector():
     x = at.make_vector(1, 2, 3)
     x_fg = FunctionGraph([], [x])

tests/link/numba/test_tensor_basic.py

Lines changed: 6 additions & 0 deletions

@@ -5,6 +5,7 @@
 import pytensor.tensor as at
 import pytensor.tensor.basic as atb
 from pytensor import config, function
+from pytensor.compile import get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph

@@ -15,6 +16,7 @@
     compare_shape_dtype,
     set_test_value,
 )
+from tests.tensor.test_basic import TestAlloc
 
 
 pytest.importorskip("numba")

@@ -49,6 +51,10 @@ def test_Alloc(v, shape):
     assert numba_res.shape == shape
 
 
+def test_alloc_runtime_broadcast():
+    TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
+
+
 def test_AllocEmpty():
     x = at.empty((2, 3), dtype="float32")
     x_fg = FunctionGraph([], [x])

tests/tensor/test_basic.py

Lines changed: 42 additions & 1 deletion

@@ -719,6 +719,38 @@ class TestAlloc:
     shared = staticmethod(pytensor.shared)
     allocs = [Alloc()] * 3
 
+    @staticmethod
+    def check_allocs_in_fgraph(fgraph, n):
+        assert (
+            len([node for node in fgraph.apply_nodes if isinstance(node.op, Alloc)])
+            == n
+        )
+
+    @staticmethod
+    def check_runtime_broadcast(mode):
+        """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
+        floatX = config.floatX
+        x_v = vector("x", shape=(None,))
+
+        out = alloc(x_v, 5, 3)
+        f = pytensor.function([x_v], out, mode=mode)
+        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+        np.testing.assert_array_equal(
+            f(x=np.zeros((3,), dtype=floatX)),
+            np.zeros((5, 3), dtype=floatX),
+        )
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            f(x=np.zeros((1,), dtype=floatX))
+
+        out = alloc(specify_shape(x_v, (1,)), 5, 3)
+        f = pytensor.function([x_v], out, mode=mode)
+        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+        np.testing.assert_array_equal(
+            f(x=np.zeros((1,), dtype=floatX)),
+            np.zeros((5, 3), dtype=floatX),
+        )
 
     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())

@@ -853,6 +885,8 @@ def test_static_shape(self):
 
     def test_alloc_of_view_linker(self):
         """Check we can allocate a new array properly in the C linker when input is a view."""
+        floatX = config.floatX
+
         x_v = vector("x", shape=(None,))
         dim_len = scalar("dim_len", dtype=int)
         out = alloc(specify_shape(x_v, (1,)), 5, dim_len)

@@ -862,7 +896,14 @@ def test_alloc_of_view_linker(self):
             f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)]
         )
 
-        np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3)))
+        np.testing.assert_array_equal(
+            f(x=np.zeros((1,), dtype=floatX), dim_len=3),
+            np.zeros((5, 3), dtype=floatX),
+        )
+
+    @pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
+    def test_runtime_broadcast(self, mode):
+        self.check_runtime_broadcast(mode)
 
 
 def test_infer_shape():
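
For end users, the practical consequence exercised by `check_runtime_broadcast` above is that broadcasting intent must now be declared statically. A usage sketch, assuming the usual pytensor entry points (`vector`, `alloc`, and `specify_broadcastable` from `pytensor.tensor`):

import numpy as np
import pytensor
import pytensor.tensor as at

x = at.vector("x", shape=(None,))
f = pytensor.function([x], at.alloc(x, 5, 3))
# f(np.zeros(1, dtype=pytensor.config.floatX))  # now raises "Runtime broadcasting not allowed. ..."

# Declaring the dimension broadcastable keeps the old behavior:
out = at.alloc(at.specify_broadcastable(x, 0), 5, 3)
f_b = pytensor.function([x], out)
assert f_b(np.zeros(1, dtype=pytensor.config.floatX)).shape == (5, 3)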
