Blockwise improvements #532

Merged · 14 commits · Dec 10, 2023
6 changes: 5 additions & 1 deletion pytensor/graph/basic.py
@@ -1777,6 +1777,7 @@ def equal_computations(
     ys: list[Union[np.ndarray, Variable]],
     in_xs: Optional[list[Variable]] = None,
     in_ys: Optional[list[Variable]] = None,
+    strict_dtype=True,
 ) -> bool:
     """Checks if PyTensor graphs represent the same computations.
@@ -1908,7 +1909,10 @@ def compare_nodes(nd_x, nd_y, common, different):
         if dx != dy:
             if isinstance(dx, Constant) and isinstance(dy, Constant):
                 if not dx.equals(dy):
-                    return False
+                    if strict_dtype:
+                        return False
+                    elif not np.array_equal(dx.data, dy.data):
+                        return False
             else:
                 return False
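
For illustration, a hedged usage sketch of the new strict_dtype flag (not from the PR; the graph names are made up, and the outcome assumes the two graphs are otherwise identical):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.graph.basic import equal_computations

    x = pt.vector("x")
    # Two graphs that differ only in the dtype of one constant
    g_f64 = x + np.array(2.0, dtype="float64")
    g_f32 = x + np.array(2.0, dtype="float32")

    # Default (strict_dtype=True): the constants must compare equal, dtype included.
    equal_computations([g_f64], [g_f32])
    # strict_dtype=False: constants failing the strict check are compared by value
    # with np.array_equal, so a dtype mismatch alone no longer rules out equality.
    equal_computations([g_f64], [g_f32], strict_dtype=False)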
4 changes: 1 addition & 3 deletions pytensor/link/jax/dispatch/nlinalg.py
@@ -99,9 +99,7 @@ def jax_funcify_BatchedDot(op, **kwargs):
     def batched_dot(a, b):
         if a.shape[0] != b.shape[0]:
             raise TypeError("Shapes must match in the 0-th dimension")
-        if a.ndim == 2 or b.ndim == 2:
-            return jnp.einsum("n...j,nj...->n...", a, b)
-        return jnp.einsum("nij,njk->nik", a, b)
+        return jnp.matmul(a, b)

     return batched_dot
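
As a quick sanity check (not part of the PR) that jnp.matmul reproduces the batched einsum it replaces for rank-3 inputs:

    import numpy as np
    import jax.numpy as jnp

    a = np.random.rand(4, 2, 3)
    b = np.random.rand(4, 3, 5)

    # matmul broadcasts over the leading batch dimension, matching the old einsum
    np.testing.assert_allclose(
        jnp.einsum("nij,njk->nik", a, b),
        jnp.matmul(a, b),
        rtol=1e-6,
    )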
2 changes: 2 additions & 0 deletions pytensor/link/numba/dispatch/basic.py
@@ -895,6 +895,8 @@ def numba_funcify_BatchedDot(op, node, **kwargs):

     @numba_njit
     def batched_dot(x, y):
+        # Numba does not support 3D matmul
+        # https://github.com/numba/numba/issues/3804
         shape = x.shape[:-1] + y.shape[2:]
         z0 = np.empty(shape, dtype=dtype)
         for i in range(z0.shape[0]):
Member: Should this be prange instead of range? I don't know if we allow parallel computation in compiled Numba code, or if doing so would even be useful/interesting.

Member (Author): I don't know either. I'm also not sure how well that would play with further multiprocessing applied from the outside; we already have quite a few issues with BLAS in PyMC because of that.

Perhaps @aseyboldt can weigh in?

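For reference, a minimal sketch of what the prange variant under discussion could look like, assuming the loop is compiled with parallel=True; the function name and the np.dot body are illustrative, not the PR's actual implementation:

    import numpy as np
    from numba import njit, prange

    @njit(parallel=True)  # prange only parallelizes when parallel=True is set
    def batched_dot_parallel(x, y):
        # x: (batch, m, k), y: (batch, k, n) -> (batch, m, n)
        shape = x.shape[:-1] + y.shape[2:]
        out = np.empty(shape, x.dtype)
        for i in prange(x.shape[0]):  # each batch entry is an independent matmul
            out[i] = np.dot(x[i], y[i])
        return out

Whether this would play nicely with multiprocessing applied from the outside (the BLAS concern above) is exactly the open question.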
36 changes: 31 additions & 5 deletions pytensor/tensor/basic.py
@@ -26,6 +26,7 @@
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.graph.rewriting.db import EquilibriumDB
 from pytensor.graph.type import HasShape, Type
 from pytensor.link.c.op import COp
@@ -41,6 +42,7 @@
     as_tensor_variable,
     get_vector_length,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.shape import (
@@ -1657,16 +1659,22 @@ def do_constant_folding(self, fgraph, node):
         if not clients:
             return False

-        for client in clients:
-            if client[0] == "output":
+        for client, idx in clients:
+            if client == "output":
                 # If the output is a constant, it will have to be deepcopied
                 # each time the function is called. So we do not fold.
                 return False
+            # Allow alloc to be lifted out of Elemwise before constant folding it
+            elif isinstance(client.op, Elemwise):
+                return None
+            # Same for Blockwise, unless it has no batch_dims
+            elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
+                return None
             elif (
                 # The following ops work inplace of their input id 0.
-                client[1] == 0
+                idx == 0
                 and isinstance(
-                    client[0].op,
+                    client.op,
                     (
                         # Ops that will work inplace on the Alloc. So if they
                         # get constant_folded, they would copy the
@@ -3497,10 +3505,17 @@ def make_node(self, x):

         if x.ndim < 2:
             raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
+
+        out_shape = [
+            st_dim
+            for i, st_dim in enumerate(x.type.shape)
+            if i not in (self.axis1, self.axis2)
+        ] + [None]
+
         return Apply(
             self,
             [x],
-            [x.type.clone(dtype=x.dtype, shape=(None,) * (x.ndim - 1))()],
+            [x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
         )

     def perform(self, node, inputs, outputs):
@@ -3601,6 +3616,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
     return ExtractDiag(offset, axis1, axis2)(a)


+@_vectorize_node.register(ExtractDiag)
+def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
+    batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
+    return diagonal(
+        batched_x,
+        offset=op.offset,
+        axis1=op.axis1 + batched_ndims,
+        axis2=op.axis2 + batched_ndims,
+    ).owner
+
+
 def trace(a, offset=0, axis1=0, axis2=1):
     """
     Returns the sum along diagonals of the array.
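
In NumPy terms, the new vectorize rule for ExtractDiag simply shifts axis1/axis2 past the batch dimensions; an illustrative sketch (not PR code):

    import numpy as np

    x = np.random.rand(5, 3, 3)  # one leading batch dimension
    batched_ndims = 1
    # op.axis1/op.axis2 (0 and 1 here) shifted by the number of batch dims
    out = np.diagonal(x, offset=0, axis1=0 + batched_ndims, axis2=1 + batched_ndims)
    assert out.shape == (5, 3)  # one diagonal per batch entry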