
Commit 81e6d66

Forbid runtime broadcasting in Alloc
1 parent 8f5d5fa commit 81e6d66

File tree

6 files changed: +109 -9 lines changed
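
In short: `Alloc` previously followed NumPy broadcasting rules at runtime, silently stretching an input dimension of length 1 even when that dimension was not statically declared broadcastable. This commit makes that case raise a `ValueError` across the Python, C, Numba, and JAX implementations. A minimal sketch of the new behavior (illustrative only; assumes a standard PyTensor install and uses the `specify_broadcastable` helper that the new error message points to):

    import numpy as np
    import pytensor
    import pytensor.tensor as at

    x = at.vector("x", shape=(None,))  # length unknown at compile time
    f = pytensor.function([x], at.alloc(x, 5, 3))

    f(np.zeros(3, dtype=pytensor.config.floatX))    # fine: trailing dimensions match
    # f(np.zeros(1, dtype=pytensor.config.floatX))  # now raises ValueError: Runtime broadcasting not allowed. ...

    # If broadcasting the length-1 input is intended, declare the dimension broadcastable:
    xb = at.specify_broadcastable(x, 0)
    g = pytensor.function([x], at.alloc(xb, 5, 3))
    g(np.zeros(1, dtype=pytensor.config.floatX))    # allowed: dimension 0 is statically 1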

pytensor/link/jax/dispatch/tensor_basic.py

Lines changed: 2 additions & 1 deletion
@@ -51,9 +51,10 @@ def allocempty(*shape):


 @jax_funcify.register(Alloc)
-def jax_funcify_Alloc(op, **kwargs):
+def jax_funcify_Alloc(op, node, **kwargs):
     def alloc(x, *shape):
         res = jnp.broadcast_to(x, shape)
+        Alloc._check_runtime_broadcast(node, jnp.asarray(x), res.shape)
         return res

     return alloc
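
For context, `jnp.broadcast_to` itself performs the size-1 stretch without complaint, which is why the guard is added around it. A standalone sketch in plain JAX (no PyTensor involved):

    import jax.numpy as jnp

    x = jnp.zeros((1,))
    print(jnp.broadcast_to(x, (5, 3)).shape)  # (5, 3): JAX silently broadcasts the size-1 dimension
    # The added Alloc._check_runtime_broadcast call rejects this case whenever the matching
    # PyTensor input dimension was not statically marked as broadcastable.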

pytensor/link/numba/dispatch/tensor_basic.py

Lines changed: 9 additions & 1 deletion
@@ -78,16 +78,24 @@ def numba_funcify_Alloc(op, node, **kwargs):
         " " * 4,
     )

+    check_runtime_broadcast = []
+    for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]):
+        if val_static_dim is None:
+            check_runtime_broadcast.append(
+                f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
+            )
+    check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
+
     alloc_def_src = f"""
 def alloc(val, {", ".join(shape_var_names)}):
     val_np = np.asarray(val)
 {shapes_to_items_src}
     scalar_shape = {create_tuple_string(shape_var_item_names)}
+{check_runtime_broadcast_src}
     res = np.empty(scalar_shape, dtype=val_np.dtype)
     res[...] = val_np
     return res
     """
-
     alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})

     return numba_basic.numba_njit(alloc_fn)
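
For a concrete case, say a vector of unknown length allocated to a 2-D shape, the source assembled above would expand to roughly the following (a hand-written sketch: the real argument and item names come from `unique_name_generator` and `shapes_to_items_src`, and `np`/`to_scalar` are supplied through the `global_env` the function is compiled with):

    def alloc(val, shape_0, shape_1):  # illustrative names; the generated ones differ
        val_np = np.asarray(val)
        shape_0_item = to_scalar(shape_0)
        shape_1_item = to_scalar(shape_1)
        scalar_shape = (shape_0_item, shape_1_item)
        # Added by this commit: guard the only dimension whose static length is unknown
        if val.shape[-1] == 1 and scalar_shape[-1] != 1: raise ValueError("Runtime broadcasting not allowed. ...")
        res = np.empty(scalar_shape, dtype=val_np.dtype)
        res[...] = val_np
        return res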

pytensor/tensor/basic.py

Lines changed: 43 additions & 6 deletions
@@ -1431,6 +1431,12 @@ class Alloc(COp):

     __props__ = ()

+    _runtime_broadcast_error_msg = (
+        "Runtime broadcasting not allowed. "
+        "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
+        "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+    )
+
     def make_node(self, value, *shape):
         value = as_tensor_variable(value)
         shape, static_shape = infer_static_shape(shape)
@@ -1468,10 +1474,21 @@ def make_node(self, value, *shape):
         otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
         return Apply(self, [value] + shape, [otype()])

+    @staticmethod
+    def _check_runtime_broadcast(node, value, shape):
+        value_static_shape = node.inputs[0].type.shape
+        for v_static_dim, value_dim, out_dim in zip(
+            value_static_shape[::-1], value.shape[::-1], shape[::-1]
+        ):
+            if v_static_dim is None and value_dim == 1 and out_dim != 1:
+                raise ValueError(Alloc._runtime_broadcast_error_msg)
+
     def perform(self, node, inputs, out_):
         (out,) = out_
         v = inputs[0]
         sh = tuple([int(i) for i in inputs[1:]])
+        self._check_runtime_broadcast(node, v, sh)
+
         if out[0] is None or out[0].shape != sh:
             if v.size == 1 and v.item() == 0:
                 out[0] = np.zeros(sh, dtype=v.dtype)
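
The rule `_check_runtime_broadcast` implements (and the C and Numba backends mirror) is: for every dimension whose static length is unknown (`None`), a runtime length of 1 must not be stretched to a larger output length; dimensions are compared right-aligned. A small illustrative restatement, not part of the commit:

    def would_runtime_broadcast(value_static_shape, value_shape, out_shape):
        """Hypothetical helper restating the check above."""
        return any(
            static is None and v == 1 and o != 1
            for static, v, o in zip(
                value_static_shape[::-1], value_shape[::-1], out_shape[::-1]
            )
        )

    assert would_runtime_broadcast((None,), (1,), (5,))      # rejected: unknown dim stretched
    assert not would_runtime_broadcast((1,), (1,), (5,))     # allowed: statically broadcastable
    assert not would_runtime_broadcast((None,), (3,), (3,))  # allowed: no broadcasting happens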
@@ -1484,12 +1501,19 @@ def perform(self, node, inputs, out_):

     def c_code(self, node, name, inp, out, sub):
         vv = inp[0]
-        ndim = len(inp[1:])
         (zz,) = out
         fail = sub["fail"]

+        v_static_shape = node.inputs[0].type.shape
+        o_static_shape = node.outputs[0].type.shape
+        v_ndim = len(v_static_shape)
+        o_ndim = len(o_static_shape)
+        assert o_ndim == len(inp[1:])
+
+        # Declare variables
         code = f"""
-        npy_intp shape[{ndim}];
+        npy_intp shape[{o_ndim}];
+        int need_new_out;
         """

         # Initialize shape
@@ -1498,15 +1522,28 @@ def c_code(self, node, name, inp, out, sub):
             shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
             """

+        # Add checks for runtime broadcasting
+        for i, (v_static_dim, out_static_dim) in enumerate(
+            zip(v_static_shape[::-1], o_static_shape[::-1])
+        ):
+            if v_static_dim is None:
+                code += f"""
+                if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)
+                {{
+                    PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
+                    {fail}
+                }}
+                """
+
         code += f"""
-        int need_new_out = (NULL == {zz});
-        for (int i = 0; i < {ndim}; i++)
+        need_new_out = (NULL == {zz});
+        for (int i = 0; i < {o_ndim}; i++)
             need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));

         if (need_new_out)
         {{
             Py_XDECREF({zz});
-            {zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv}));
+            {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
             if (!{zz})
             {{
                 PyErr_SetString(PyExc_MemoryError, "alloc failed");
@@ -1522,7 +1559,7 @@ def c_code(self, node, name, inp, out, sub):
         return code

     def c_code_cache_version(self):
-        return (3,)
+        return (4,)

     def infer_shape(self, fgraph, node, input_shapes):
         return [node.inputs[1:]]

tests/link/jax/test_tensor_basic.py

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,8 @@
 import numpy as np
 import pytest

+from pytensor.compile import get_mode
+

 jax = pytest.importorskip("jax")
 import jax.errors

@@ -12,6 +14,7 @@
 from pytensor.graph.op import get_test_value
 from pytensor.tensor.type import iscalar, matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
+from tests.tensor.test_basic import TestAlloc


 def test_jax_Alloc():

@@ -50,6 +53,10 @@ def compare_shape_dtype(x, y):
     compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])


+def test_alloc_runtime_broadcast():
+    TestAlloc.check_runtime_broadcast(get_mode("JAX"))
+
+
 def test_jax_MakeVector():
     x = at.make_vector(1, 2, 3)
     x_fg = FunctionGraph([], [x])

tests/link/numba/test_tensor_basic.py

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@
 import pytensor.tensor as at
 import pytensor.tensor.basic as atb
 from pytensor import config, function
+from pytensor.compile import get_mode
 from pytensor.compile.sharedvalue import SharedVariable
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph

@@ -15,6 +16,7 @@
     compare_shape_dtype,
     set_test_value,
 )
+from tests.tensor.test_basic import TestAlloc


 rng = np.random.default_rng(42849)

@@ -45,6 +47,10 @@ def test_Alloc(v, shape):
     assert numba_res.shape == shape


+def test_alloc_runtime_broadcast():
+    TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
+
+
 def test_AllocEmpty():
     x = at.empty((2, 3), dtype="float32")
     x_fg = FunctionGraph([], [x])

tests/tensor/test_basic.py

Lines changed: 42 additions & 1 deletion
@@ -720,6 +720,38 @@ class TestAlloc:
     shared = staticmethod(pytensor.shared)
     allocs = [Alloc()] * 3

+    @staticmethod
+    def check_allocs_in_fgraph(fgraph, n):
+        assert (
+            len([node for node in fgraph.apply_nodes if isinstance(node.op, Alloc)])
+            == n
+        )
+
+    @staticmethod
+    def check_runtime_broadcast(mode):
+        """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
+        floatX = config.floatX
+        x_v = vector("x", shape=(None,))
+
+        out = alloc(x_v, 5, 3)
+        f = pytensor.function([x_v], out, mode=mode)
+        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+        np.testing.assert_array_equal(
+            f(x=np.zeros((3,), dtype=floatX)),
+            np.zeros((5, 3), dtype=floatX),
+        )
+        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
+            f(x=np.zeros((1,), dtype=floatX))
+
+        out = alloc(specify_shape(x_v, (1,)), 5, 3)
+        f = pytensor.function([x_v], out, mode=mode)
+        TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
+
+        np.testing.assert_array_equal(
+            f(x=np.zeros((1,), dtype=floatX)),
+            np.zeros((5, 3), dtype=floatX),
+        )

     def setup_method(self):
         self.rng = np.random.default_rng(seed=utt.fetch_seed())

@@ -854,6 +886,8 @@ def test_static_shape(self):

     def test_alloc_of_view_linker(self):
         """Check we can allocate a new array properly in the C linker when input is a view."""
+        floatX = config.floatX
+
         x_v = vector("x", shape=(None,))
         dim_len = scalar("dim_len", dtype=int)
         out = alloc(specify_shape(x_v, (1,)), 5, dim_len)

@@ -863,7 +897,14 @@ def test_alloc_of_view_linker(self):
             f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)]
         )

-        np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3)))
+        np.testing.assert_array_equal(
+            f(x=np.zeros((1,), dtype=floatX), dim_len=3),
+            np.zeros((5, 3), dtype=floatX),
+        )
+
+    @pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
+    def test_runtime_broadcast(self, mode):
+        self.check_runtime_broadcast(mode)


 def test_infer_shape():