Skip to content

Commit 503d939

Browse files
committed
Vectorize ExtractDiag
Also adds better static shapes
1 parent d469a61 commit 503d939

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

pytensor/tensor/basic.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pytensor.graph.basic import Apply, Constant, Variable
2727
from pytensor.graph.fg import FunctionGraph
2828
from pytensor.graph.op import Op
29+
from pytensor.graph.replace import _vectorize_node
2930
from pytensor.graph.rewriting.db import EquilibriumDB
3031
from pytensor.graph.type import HasShape, Type
3132
from pytensor.link.c.op import COp
@@ -3497,10 +3498,17 @@ def make_node(self, x):
34973498

34983499
if x.ndim < 2:
34993500
raise ValueError("ExtractDiag needs an input with 2 or more dimensions", x)
3501+
3502+
out_shape = [
3503+
st_dim
3504+
for i, st_dim in enumerate(x.type.shape)
3505+
if i not in (self.axis1, self.axis2)
3506+
] + [None]
3507+
35003508
return Apply(
35013509
self,
35023510
[x],
3503-
[x.type.clone(dtype=x.dtype, shape=(None,) * (x.ndim - 1))()],
3511+
[x.type.clone(dtype=x.dtype, shape=tuple(out_shape))()],
35043512
)
35053513

35063514
def perform(self, node, inputs, outputs):
@@ -3601,6 +3609,17 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
36013609
return ExtractDiag(offset, axis1, axis2)(a)
36023610

36033611

3612+
@_vectorize_node.register(ExtractDiag)
3613+
def vectorize_extract_diag(op: ExtractDiag, node, batched_x):
3614+
batched_ndims = batched_x.type.ndim - node.inputs[0].type.ndim
3615+
return diagonal(
3616+
batched_x,
3617+
offset=op.offset,
3618+
axis1=op.axis1 + batched_ndims,
3619+
axis2=op.axis2 + batched_ndims,
3620+
).owner
3621+
3622+
36043623
def trace(a, offset=0, axis1=0, axis2=1):
36053624
"""
36063625
Returns the sum along diagonals of the array.

tests/tensor/test_basic.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytensor.misc.safe_asarray import _asarray
2121
from pytensor.raise_op import Assert
2222
from pytensor.scalar import autocast_float, autocast_float_as
23-
from pytensor.tensor import NoneConst
23+
from pytensor.tensor import NoneConst, vectorize
2424
from pytensor.tensor.basic import (
2525
Alloc,
2626
AllocEmpty,
@@ -88,6 +88,7 @@
8888
vertical_stack,
8989
zeros_like,
9090
)
91+
from pytensor.tensor.blockwise import Blockwise
9192
from pytensor.tensor.elemwise import DimShuffle
9293
from pytensor.tensor.exceptions import NotScalarConstantError
9394
from pytensor.tensor.math import dense_dot
@@ -4517,3 +4518,26 @@ def test_trace():
45174518
trace(x, offset=-1, axis1=0, axis2=-1).eval(),
45184519
np.trace(x_val, offset=-1, axis1=0, axis2=-1),
45194520
)
4521+
4522+
4523+
def test_vectorize_extract_diag():
4524+
signature = "(a1,b,a2)->(b,a)"
4525+
4526+
def core_pt(x):
4527+
return at.diagonal(x, offset=1, axis1=0, axis2=2)
4528+
4529+
def core_np(x):
4530+
return np.diagonal(x, offset=1, axis1=0, axis2=2)
4531+
4532+
x = tensor(shape=(5, 5, 5, 5))
4533+
vectorize_pt = function([x], vectorize(core_pt, signature=signature)(x))
4534+
assert not any(
4535+
isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
4536+
)
4537+
4538+
x_test = np.random.normal(size=x.type.shape).astype(x.type.dtype)
4539+
vectorize_np = np.vectorize(core_np, signature=signature)
4540+
np.testing.assert_allclose(
4541+
vectorize_pt(x_test),
4542+
vectorize_np(x_test),
4543+
)

0 commit comments

Comments
 (0)