
Add shape_unsafe tag to rewrites that can hide shape errors #381


Merged (9 commits, Aug 7, 2023)
20 changes: 1 addition & 19 deletions pytensor/configdefaults.py
@@ -682,25 +682,7 @@ def add_traceback_configvars():


def add_experimental_configvars():
config.add(
"experimental__local_alloc_elemwise",
"DEPRECATED: If True, enable the experimental"
" optimization local_alloc_elemwise."
" Generates error if not True. Use"
" optimizer_excluding=local_alloc_elemwise"
" to disable.",
BoolParam(True),
in_c_key=False,
)

# False could make the graph faster but not as safe.
config.add(
"experimental__local_alloc_elemwise_assert",
"When the local_alloc_elemwise is applied, add"
" an assert to highlight shape errors.",
BoolParam(True),
in_c_key=False,
)
return


def add_error_and_warning_configvars():
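With these flags gone, opting out of unsafe rewrites goes through the optimizer tag machinery instead. Per the PR title, rewrites that can hide shape errors are now tagged shape_unsafe, so they can be excluded wholesale. A minimal sketch (not part of the diff; assumes the standard Mode.excluding spelling):

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")

# Exclude every rewrite tagged "shape_unsafe" so that shape mismatches
# surface at runtime instead of being rewritten away.
mode = pytensor.compile.mode.get_default_mode().excluding("shape_unsafe")
f = pytensor.function([x, y], x + y, mode=mode)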
14 changes: 0 additions & 14 deletions pytensor/link/jax/dispatch/extra_ops.py
@@ -3,10 +3,8 @@
import jax.numpy as jnp

from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.basic import infer_static_shape
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
@@ -102,18 +100,6 @@ def ravelmultiindex(*inp, mode=mode, order=order):
return ravelmultiindex


@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, node, **kwargs):
shape = node.inputs[1:]
static_shape = infer_static_shape(shape)[1]

def broadcast_to(x, *shape):
shape = tuple(st if st is not None else s for s, st in zip(shape, static_shape))
return jnp.broadcast_to(x, shape)

return broadcast_to


@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs):
def filldiagonal(value, diagonal):
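For context on the removed JAX dispatcher (illustrative, not part of the diff): under jax.jit the target shape passed to jnp.broadcast_to must be concrete at trace time, which is why the removed function substituted statically inferred dimensions for traced scalar shape inputs where possible.

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # The target shape must be a static Python tuple at trace time;
    # traced array values cannot be used as shape entries.
    return jnp.broadcast_to(x, (2, 3))

print(f(jnp.ones(3)).shape)  # (2, 3)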
25 changes: 0 additions & 25 deletions pytensor/link/numba/dispatch/extra_ops.py
@@ -2,15 +2,13 @@

import numba
import numpy as np
from numba.misc.special import literal_unroll

from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import get_numba_type, numba_funcify
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
@@ -353,29 +351,6 @@ def searchsorted(a, v):
return searchsorted


@numba_funcify.register(BroadcastTo)
def numba_funcify_BroadcastTo(op, node, **kwargs):
create_zeros_tuple = numba_basic.create_tuple_creator(
lambda _: 0, len(node.inputs) - 1
)

# TODO broadcastable checks
@numba_basic.numba_njit
def broadcast_to(x, *shape):
scalars_shape = create_zeros_tuple()

i = 0
for s_i in literal_unroll(shape):
scalars_shape = numba_basic.tuple_setitem(
scalars_shape, i, numba_basic.to_scalar(s_i)
)
i += 1

return np.broadcast_to(x, scalars_shape)

return broadcast_to


@numba_funcify.register(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
error = op.exc_type
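Similarly, for the removed Numba dispatcher: np.broadcast_to is supported in nopython mode, but the target shape must be a tuple rather than an array, which is what the removed literal_unroll loop was assembling from the individual scalar shape inputs. A sketch, assuming a numba version that implements np.broadcast_to:

import numba
import numpy as np

@numba.njit
def bcast(x, n_rows):
    # Inside nopython mode the shape argument must be a tuple of integers.
    return np.broadcast_to(x, (n_rows, x.shape[0]))

print(bcast(np.ones(3), 2).shape)  # (2, 3)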
45 changes: 37 additions & 8 deletions pytensor/tensor/basic.py
@@ -765,7 +765,12 @@ def switch(cond, ift, iff):

@scalar_elemwise
def second(a, b):
"""Create a matrix by filling the shape of a with b"""
"""Create a matrix by filling the broadcasted shapes of a and b with the values of b

Equivalent to `np.broadcast_arrays(a, b)[1]`
Equivalent to `np.array(a).fill(b)` when b is a scalar value.

"""


fill = second
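To illustrate the documented equivalences with plain NumPy (not part of the diff):

import numpy as np

a = np.zeros((2, 3))
b = np.array([1.0, 2.0, 3.0])

# second(a, b): b broadcast against a, i.e. the second broadcasted array.
print(np.broadcast_arrays(a, b)[1])  # two rows of [1., 2., 3.]

# fill(a, 5.0) with a scalar: a's shape filled with that scalar.
print(np.full_like(a, 5.0))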
@@ -1427,17 +1432,41 @@ class Alloc(COp):
__props__ = ()

def make_node(self, value, *shape):
v = as_tensor_variable(value)
sh, static_shape = infer_static_shape(shape)
if v.ndim > len(sh):
value = as_tensor_variable(value)
shape, static_shape = infer_static_shape(shape)
if value.ndim > len(shape):
raise TypeError(
"The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim,
len(sh),
value.ndim,
len(shape),
)

# Combine static shape information from value and shape
combined_static_shape = list(static_shape)
new_dims = len(shape) - value.type.ndim
extended_value_static_shape = (None,) * new_dims + value.type.shape
extended_value_broadcastable = (False,) * new_dims + value.type.broadcastable
for i, (v_bc, v_st, sh_st) in enumerate(
zip(
extended_value_broadcastable,
extended_value_static_shape,
static_shape,
)
):
# If value is not broadcastable and we don't know the target static shape: use value static shape
if (not v_bc) and (sh_st is None):
combined_static_shape[i] = v_st
# Otherwise check if static shapes are compatible
elif (v_st is not None) and (sh_st is not None):
# They must match or if not, the value must be broadcastable
if v_st != sh_st and not v_bc:
raise ValueError(
f"Alloc static input type and target shape are incompatible: {value.type} vs {static_shape}"
)

otype = TensorType(dtype=v.dtype, shape=static_shape)
return Apply(self, [v] + sh, [otype()])
otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
return Apply(self, [value] + shape, [otype()])

def perform(self, node, inputs, out_):
(out,) = out_
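A standalone sketch of the static-shape combination rule implemented in make_node above (plain Python with a hypothetical helper name; None marks a dimension whose length is unknown at compile time):

def combine_static_shape(value_bcast, value_static, target_static):
    # Assumes the value's shape/broadcast patterns were already left-padded
    # to the target's number of dimensions, as done in make_node.
    combined = list(target_static)
    for i, (v_bc, v_st, t_st) in enumerate(
        zip(value_bcast, value_static, target_static)
    ):
        if (not v_bc) and t_st is None:
            # A non-broadcastable value dim cannot grow, so its static
            # length carries over to the output.
            combined[i] = v_st
        elif (v_st is not None) and (t_st is not None):
            if v_st != t_st and not v_bc:
                raise ValueError("static input and target shapes are incompatible")
    return tuple(combined)

print(combine_static_shape((False,), (3,), (None,)))  # (3,)
print(combine_static_shape((True,), (1,), (None,)))   # (None,): dim may still grow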
159 changes: 15 additions & 144 deletions pytensor/tensor/extra_ops.py
@@ -23,7 +23,7 @@
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import alloc, second
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as pt_abs
from pytensor.tensor.math import all as pt_all
@@ -1584,141 +1584,6 @@ def broadcast_shape_iter(
return tuple(result_dims)


class BroadcastTo(COp):
Review comment (Member): BroadcastTo is imported in pymc a couple of times. Maybe we should leave an empty Op here, that is deprecated and doesn't do anything?

Reply (ricardoV94, Member, author, Aug 7, 2023): Some of the removed rewrites are also directly imported. This shouldn't be a problem, however: I marked this PR as a major release, so we will bump the version above the upper bound pinned by PyMC. When we update the pin on PyMC I'll address the changes. They require some manual review anyway, to check whether the logic that depended on BroadcastTo was valid under our new rules and can be transferred to Alloc. This was all in the logprob inference module AFAICT, so the impact should be pretty contained.

Reply (Member): Sounds good.

"""An `Op` for `numpy.broadcast_to`."""

_output_type_depends_on_input_value = True

__props__ = ()

view_map = {0: [0]}

def __call__(self, a, shape, **kwargs):
return super().__call__(a, *shape, **kwargs)

def make_node(self, a, *shape):
a = at.as_tensor_variable(a)

shape, static_shape = at.infer_static_shape(shape)

if len(shape) < a.ndim:
raise ValueError(
f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
)

out = TensorType(dtype=a.type.dtype, shape=static_shape)()

# Attempt to prevent in-place operations on this view-based output
out.tag.indestructible = True

return Apply(self, [a] + shape, [out])

def perform(self, node, inputs, output_storage):
a, *shape = inputs
z = output_storage[0]
z[0] = np.broadcast_to(a, shape)

def grad(self, inputs, outputs_gradients):
a, *shape = inputs
(dout,) = outputs_gradients

# Determine the dimensions that were added by broadcasting
new_dims = list(range(dout.ndim - a.ndim))

d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

# Determine the dimensions that were broadcast
_, static_shape = at.infer_static_shape(shape)

# TODO: This needs to be performed at run-time when static shape
# information isn't available.
bcast_sums = [
i
for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
if a_s == 1 and s_s != 1
]

if bcast_sums:
d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)

return [d_wrt_a] + [
grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
]

def infer_shape(self, fgraph, node, ins_shapes):
return [node.inputs[1:]]

def c_code(self, node, name, inputs, outputs, sub):
inp_dims = node.inputs[0].ndim
out_dims = node.outputs[0].ndim
new_dims = out_dims - inp_dims

(x, *shape) = inputs
(out,) = outputs
fail = sub["fail"]

# TODO: Could just use `PyArray_Return`, no?
dims_array = ", ".join(
[
f"((dtype_{shape}*)(PyArray_DATA({shape})))[0]"
for i, shape in enumerate(shape)
]
)

src = (
"""
npy_intp itershape[%(out_dims)s] = {%(dims_array)s};

NpyIter *iter;
PyArrayObject *ops[1] = {%(x)s};
npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
PyArray_Descr *op_dtypes[1] = {NULL};
int oa_ndim = %(out_dims)s;
int* op_axes[1] = {NULL};
npy_intp buffersize = 0;

for(int i = 0; i < %(inp_dims)s; i++)
{
if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
i,
(long long int) itershape[i + %(new_dims)s],
(long long int) PyArray_DIMS(%(x)s)[i]
);
%(fail)s
}
}

iter = NpyIter_AdvancedNew(
1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
);
%(out)s = NpyIter_GetIterView(iter, 0);

if(%(out)s == NULL){
NpyIter_Deallocate(iter);
%(fail)s;
}

if (NpyIter_Deallocate(iter) != NPY_SUCCEED) {
%(fail)s;
}

"""
% locals()
)

return src

def c_code_cache_version(self):
return (2,)


broadcast_to_ = BroadcastTo()
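For intuition about the removed grad method (illustrative NumPy, not part of the diff): the gradient of a broadcast sums the output gradient over every dimension the broadcast introduced or expanded.

import numpy as np

a = np.ones((1, 3))
dout = np.ones((4, 2, 3))  # cotangent of broadcast_to(a, (4, 2, 3))

da = dout.sum(axis=0)               # drop the dim added by broadcasting -> (2, 3)
da = da.sum(axis=0, keepdims=True)  # dim broadcast from 1 to 2 -> (1, 3)
print(da.shape)                     # (1, 3), matching a.shape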


def geomspace(start, end, steps, base=10.0):
from pytensor.tensor.math import log

@@ -1762,13 +1627,7 @@ def broadcast_to(
broadcasted array may refer to a single memory location.

"""
x = at.as_tensor(x)
shape_len = get_vector_length(shape)

if x.ndim == 0 and shape_len == 0:
return x

return broadcast_to_(x, shape)
return alloc(x, *shape)
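A quick usage sketch of the new lowering (assuming pytensor.tensor re-exports broadcast_to and that vector accepts a shape keyword):

import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
y = pt.broadcast_to(x, (5, 3))    # now builds an Alloc node
print(y.type.shape)               # (5, 3): static shape from Alloc.make_node
print(type(y.owner.op).__name__)  # Alloc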


def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
@@ -1780,7 +1639,19 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
The arrays to broadcast.

"""
return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)

def broadcast_with_others(a, others):
Review comment (ricardoV94, Member, author, Jul 13, 2023): We discussed with @aseyboldt that it may make sense to generalize Second so that it accepts arbitrarily many inputs and returns every variable as an output. That would become a flat broadcast_arrays once Elemwise-d, and would make rewrites easier to read. By overriding __str__ we can also make it much more readable in debug_print than the current nested Second.

for other in others:
a = second(other, a)
return a

broadcasted_vars = []
for i, a in enumerate(args):
# We use indexing and not identity in case there are duplicated variables
others = [a for j, a in enumerate(args) if j != i]
broadcasted_vars.append(broadcast_with_others(a, others))

return broadcasted_vars
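And a usage sketch for the nested-second construction (hypothetical shapes; static output shapes rely on Elemwise shape inference):

import pytensor.tensor as pt

a = pt.vector("a", shape=(3,))
b = pt.matrix("b", shape=(2, 3))

a_bc, b_bc = pt.broadcast_arrays(a, b)
print(a_bc.type.shape, b_bc.type.shape)  # (2, 3) (2, 3)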


__all__ = [