Add shape_unsafe tag to rewrites that can hide shape errors #381
Changes from all commits
@@ -23,7 +23,7 @@
 from pytensor.scalar import upcast
 from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
-from pytensor.tensor import get_vector_length
+from pytensor.tensor.basic import alloc, second
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
     return tuple(result_dims)
 
 
-class BroadcastTo(COp):
-    """An `Op` for `numpy.broadcast_to`."""
-
-    _output_type_depends_on_input_value = True
-
-    __props__ = ()
-
-    view_map = {0: [0]}
-
-    def __call__(self, a, shape, **kwargs):
-        return super().__call__(a, *shape, **kwargs)
-
-    def make_node(self, a, *shape):
-        a = at.as_tensor_variable(a)
-
-        shape, static_shape = at.infer_static_shape(shape)
-
-        if len(shape) < a.ndim:
-            raise ValueError(
-                f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
-            )
-
-        out = TensorType(dtype=a.type.dtype, shape=static_shape)()
-
-        # Attempt to prevent in-place operations on this view-based output
-        out.tag.indestructible = True
-
-        return Apply(self, [a] + shape, [out])
-
-    def perform(self, node, inputs, output_storage):
-        a, *shape = inputs
-        z = output_storage[0]
-        z[0] = np.broadcast_to(a, shape)
-
-    def grad(self, inputs, outputs_gradients):
-        a, *shape = inputs
-        (dout,) = outputs_gradients
-
-        # Determine the dimensions that were added by broadcasting
-        new_dims = list(range(dout.ndim - a.ndim))
-
-        d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
-
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
-
-        return [d_wrt_a] + [
-            grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
-        ]
-
-    def infer_shape(self, fgraph, node, ins_shapes):
-        return [node.inputs[1:]]
-
-    def c_code(self, node, name, inputs, outputs, sub):
-        inp_dims = node.inputs[0].ndim
-        out_dims = node.outputs[0].ndim
-        new_dims = out_dims - inp_dims
-
-        (x, *shape) = inputs
-        (out,) = outputs
-        fail = sub["fail"]
-
-        # TODO: Could just use `PyArray_Return`, no?
-        dims_array = ", ".join(
-            [
-                f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
-                for i, shape in enumerate(shape)
-            ]
-        )
-
-        src = (
-            """
-            npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
-
-            NpyIter *iter;
-            PyArrayObject *ops[1] = {%(x)s};
-            npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
-            npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
-            PyArray_Descr *op_dtypes[1] = {NULL};
-            int oa_ndim = %(out_dims)s;
-            int* op_axes[1] = {NULL};
-            npy_intp buffersize = 0;
-
-            for(int i = 0; i < %(inp_dims)s; i++)
-            {
-                if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
-                {
-                    PyErr_Format(PyExc_ValueError,
-                                 "Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
-                                 i,
-                                 (long long int) itershape[i + %(new_dims)s],
-                                 (long long int) PyArray_DIMS(%(x)s)[i]
-                    );
-                    %(fail)s
-                }
-            }
-
-            iter = NpyIter_AdvancedNew(
-                1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
-            );
-            %(out)s = NpyIter_GetIterView(iter, 0);
-
-            if(%(out)s == NULL){
-                NpyIter_Deallocate(iter);
-                %(fail)s;
-            }
-
-            if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
-                %(fail)s;
-            }
-
-            """
-            % locals()
-        )
-
-        return src
-
-    def c_code_cache_version(self):
-        return (2,)
-
-
-broadcast_to_ = BroadcastTo()
-
-
 def geomspace(start, end, steps, base=10.0):
     from pytensor.tensor.math import log
 
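The `grad` method removed above sums the output gradient over both the dimensions added by broadcasting and the dimensions that were broadcast from size 1. `Alloc` implements the equivalent gradient, so the `Alloc`-based `broadcast_to` in the next hunk should remain differentiable. A quick sanity check (a sketch only; the shapes are illustrative, not from the PR):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.extra_ops import broadcast_to

# Broadcasting a length-4 vector to (3, 4) repeats each element 3 times,
# so each input element should receive a gradient of 3 from a plain sum.
x = pt.dvector("x")
loss = broadcast_to(x, (3, 4)).sum()
grad_x = pytensor.grad(loss, x)

f = pytensor.function([x], grad_x)
print(f(np.arange(4.0)))  # expected: [3. 3. 3. 3.]
```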
@@ -1762,13 +1627,7 @@ def broadcast_to(
         broadcasted array may refer to a single memory location.
 
     """
-    x = at.as_tensor(x)
-    shape_len = get_vector_length(shape)
-
-    if x.ndim == 0 and shape_len == 0:
-        return x
-
-    return broadcast_to_(x, shape)
+    return alloc(x, *shape)
 
 
 def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
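For reference, a minimal sketch of what the new lowering means in practice (not part of the diff; the shapes and the graph check below are assumptions about a build that includes this change): `broadcast_to` now builds an `Alloc` node and matches `numpy.broadcast_to` numerically.

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import Alloc
from pytensor.tensor.extra_ops import broadcast_to

x = pt.dvector("x")
y = broadcast_to(x, (3, 4))

# The graph is now expressed with Alloc instead of the removed BroadcastTo.
assert isinstance(y.owner.op, Alloc)

f = pytensor.function([x], y)
np.testing.assert_allclose(
    f(np.arange(4.0)),
    np.broadcast_to(np.arange(4.0), (3, 4)),
)
```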
@@ -1780,7 +1639,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
         The arrays to broadcast.
 
     """
-    return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
+
+    def broadcast_with_others(a, others):
+        for other in others:
+            a = second(other, a)
+        return a
+
+    broadcasted_vars = []
+    for i, a in enumerate(args):
+        # We use indexing and not identity in case there are duplicated variables
+        others = [a for j, a in enumerate(args) if j != i]
+        broadcasted_vars.append(broadcast_with_others(a, others))
+
+    return broadcasted_vars
 
 
 __all__ = [

Review comment on `broadcast_with_others`: We discussed with @aseyboldt that it may make sense to generalize
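As a rough illustration of the `second`-based approach (the variables and shapes here are made up, not from the PR): `second(x, y)` is an elementwise "return the second argument" op, so its output carries the broadcast shape of both inputs while containing only `y`'s values.

```python
import pytensor.tensor as pt
from pytensor.tensor.basic import second

a = pt.dvector("a")  # e.g. shape (4,)
b = pt.dmatrix("b")  # e.g. shape (3, 4)

# `a` broadcast against `b`: the result has b's rank but only a's values.
a_bcast = second(b, a)
# `b` broadcast against `a`: already the full shape, so effectively unchanged.
b_bcast = second(a, b)

print(a_bcast.ndim, b_bcast.ndim)  # 2 2
```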
Review comment: `BroadcastTo` is imported in pymc a couple of times. Maybe we should leave an empty Op here that is deprecated and doesn't do anything?
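One possible shape for the deprecation stub floated above (purely illustrative; the warning text and the choice to raise on use are assumptions, not something this PR adds):

```python
import warnings


class BroadcastTo:
    """Deprecated placeholder kept so that old `BroadcastTo` imports still resolve."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "BroadcastTo has been removed; use the Alloc-based "
            "pytensor.tensor.extra_ops.broadcast_to instead.",
            DeprecationWarning,
            stacklevel=2,
        )

    def __call__(self, *args, **kwargs):
        raise NotImplementedError("BroadcastTo was removed in favor of Alloc.")
```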
Reply: Some of the removed rewrites are also directly imported. This shouldn't be a problem, however. I marked this PR as a major release, so we will bump the version above the upper bound pinned by PyMC. When we update the pin on PyMC I'll address the changes. They require some manual review anyway to see whether the logic that depended on `BroadcastTo` was valid per our new rules and can be transferred to `Alloc`. This was all in the logprob inference module AFAICT, so the impact should be pretty contained.
Reply: Sounds good.