Forbid runtime broadcasting by Alloc #390
@@ -1431,6 +1431,12 @@ class Alloc(COp):

    __props__ = ()

    _runtime_broadcast_error_msg = (
        "Runtime broadcasting not allowed. "
        "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
        "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
    )

    def make_node(self, value, *shape):
        value = as_tensor_variable(value)
        shape, static_shape = infer_static_shape(shape)
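Not part of the diff: a minimal sketch of the user-facing behaviour the error message describes, assuming the public `pytensor.tensor.alloc` and `specify_broadcastable` entry points and that graph rewrites leave an op with this runtime check in the compiled function.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")                      # static shape (None,): dim 0 not declared broadcastable
f = pytensor.function([x], pt.alloc(x, 10))

print(f(np.zeros(10, dtype=x.dtype)).shape)   # (10,): no broadcasting needed

try:
    f(np.zeros(1, dtype=x.dtype))             # a length-1 value would be broadcast at runtime
except ValueError as err:
    print(err)                                # "Runtime broadcasting not allowed. ..."

# Declaring the dimension broadcastable up front keeps the broadcasting behaviour:
x1 = pt.vector("x1")
xb = pt.specify_broadcastable(x1, 0)          # static shape becomes (1,)
fb = pytensor.function([x1], pt.alloc(xb, 10))
print(fb(np.zeros(1, dtype=x1.dtype)))        # ten zeros
```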
@@ -1468,10 +1474,21 @@ def make_node(self, value, *shape):
        otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
        return Apply(self, [value] + shape, [otype()])

    @staticmethod
    def _check_runtime_broadcast(node, value, shape):
        value_static_shape = node.inputs[0].type.shape
        for v_static_dim, value_dim, out_dim in zip(
            value_static_shape[::-1], value.shape[::-1], shape[::-1]
        ):
            if v_static_dim is None and value_dim == 1 and out_dim != 1:
                raise ValueError(Alloc._runtime_broadcast_error_msg)

    def perform(self, node, inputs, out_):
        (out,) = out_
        v = inputs[0]
        sh = tuple([int(i) for i in inputs[1:]])
        self._check_runtime_broadcast(node, v, sh)

        if out[0] is None or out[0].shape != sh:
            if v.size == 1 and v.item() == 0:
                out[0] = np.zeros(sh, dtype=v.dtype)
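For intuition, a standalone sketch of what `_check_runtime_broadcast` does (the helper name below is illustrative, not part of the change): the reversed `zip` aligns trailing dimensions, and leading output dimensions the value does not have are skipped because `zip` stops at the shorter sequence.

```python
# Mirrors the reversed-zip alignment used by Alloc._check_runtime_broadcast.
def would_runtime_broadcast(value_static_shape, value_shape, out_shape):
    for static_dim, value_dim, out_dim in zip(
        value_static_shape[::-1], value_shape[::-1], out_shape[::-1]
    ):
        # Only dims whose static size is unknown (None) and that are 1 at
        # runtime while the output wants something larger count as runtime
        # broadcasting.
        if static_dim is None and value_dim == 1 and out_dim != 1:
            return True
    return False

assert would_runtime_broadcast((None, 3), (1, 3), (5, 3))      # dim 0 broadcast at runtime
assert not would_runtime_broadcast((1, 3), (1, 3), (5, 3))     # dim 0 statically known to be 1
assert not would_runtime_broadcast((None,), (3,), (5, 3))      # adding leading dims is fine
```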
@@ -1484,51 +1501,63 @@ def perform(self, node, inputs, out_):

    def c_code(self, node, name, inp, out, sub):
        vv = inp[0]
        ndim = len(inp[1:])
        (zz,) = out
        fail = sub["fail"]

        v_static_shape = node.inputs[0].type.shape
        o_static_shape = node.outputs[0].type.shape
        v_ndim = len(v_static_shape)
        o_ndim = len(o_static_shape)
        assert o_ndim == len(inp[1:])

        # Declare variables
        code = f"""
        npy_intp shape[{ndim}];
        npy_intp shape[{o_ndim}];
        int need_new_out;
        """

        # Initialize shape
        for i, shp_i in enumerate(inp[1:]):
            code += """
            shape[%(i)s] = ((dtype_%(shp_i)s*) PyArray_DATA(%(shp_i)s))[0];
            """ % dict(
                i=i, shp_i=shp_i
            )
            code += f"""
            shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
            """

        # Add checks for runtime broadcasting
        for i, v_static_dim in enumerate(v_static_shape[::-1]):
            if v_static_dim is None:
                code += f"""
                if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)

> Can we be sure that the arrays are long enough for the indices? Unless this is guaranteed for some reason, even for invalid inputs, I think an explicit check would be good.

> The docs say this is guaranteed: https://pytensor.readthedocs.io/en/latest/extending/creating_a_c_op.html#simple-cop-example

> This is not true if

> This seems to align with the pre-existing check for the output having the right shape (they index in a loop without checking whether ndims are enough).

||
{{ | ||
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}"); | ||
{fail} | ||
}} | ||
""" | ||

        code += """
        int need_new_out = (NULL == %(zz)s);
        for (int i = 0; i < %(ndim)s; i++)
            need_new_out = (need_new_out
                    || (PyArray_DIMS(%(zz)s)[i] != shape[i]));
        code += f"""
        need_new_out = (NULL == {zz});
        for (int i = 0; i < {o_ndim}; i++)
            need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));

        if (need_new_out)
        {
            Py_XDECREF(%(zz)s);
            %(zz)s = (PyArrayObject*) PyArray_SimpleNew(%(ndim)s,
                    shape, PyArray_TYPE((PyArrayObject*) py_%(vv)s));
            if (!%(zz)s)
            {
        {{
            Py_XDECREF({zz});
            {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
            if (!{zz})
            {{
                PyErr_SetString(PyExc_MemoryError, "alloc failed");
                %(fail)s
            }
        }
                {fail}
            }}
        }}

        // This function takes care of broadcasting
        if (PyArray_CopyInto(%(zz)s, %(vv)s) == -1)
            %(fail)s
        """ % dict(
            vv=vv, ndim=ndim, zz=zz, fail=fail
        )
        if (PyArray_CopyInto({zz}, {vv}) == -1)
            {fail}
        """

        return code

    def c_code_cache_version(self):
        return (2,)
        return (4,)

    def infer_shape(self, fgraph, node, input_shapes):
        return [node.inputs[1:]]
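To inspect the C source that `c_code` emits for a concrete node after this change, something along these lines should work; the strings `"x"`, `"s0"`, `"s1"`, `"z"` and the `fail` snippet are placeholders we supply ourselves (normally the C compilation machinery provides them).

```python
import pytensor.tensor as pt

x = pt.matrix("x")               # static shape (None, None): both dims get a runtime guard
node = pt.alloc(x, 5, 3).owner   # the Apply node wrapping the Alloc op

c_src = node.op.c_code(node, "alloc", ["x", "s0", "s1"], ["z"], {"fail": "FAIL;"})
print(c_src)                     # contains the PyErr_Format runtime-broadcast checks
```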
@@ -1568,7 +1597,7 @@ def grad(self, inputs, grads):
        for idx, axis in enumerate(axis_kept):
            new_order[axis] = idx
        gx = gx.dimshuffle(new_order)
        # Dimshuffle to add back the broadcasted dims
        # Dimshuffle to add back the broadcasted dims
        # The *elements* of the output are not connected to
        # the inputs that specify the shape. If you grow the
        # shape by epsilon, the existing elements do not