Skip to content

Commit acf8cb5

Browse files
committed
Ensure inputs are shaped
1 parent 60d6350 commit acf8cb5

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

pytensor/link/pytorch/dispatch/blockwise.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pytensor.graph import FunctionGraph
44
from pytensor.link.pytorch.dispatch import pytorch_funcify
55
from pytensor.tensor.blockwise import Blockwise
6+
from pytensor.tensor.random.utils import params_broadcast_shapes
67

78

89
@pytorch_funcify.register(Blockwise)
@@ -23,6 +24,16 @@ def inner_func(*inputs):
2324

2425
def batcher(*inputs):
2526
op._check_runtime_broadcast(node, inputs)
26-
return inner_func(*inputs)
27+
# broadcast on batched_dims
28+
all_batched_dims = tuple(tuple(t.shape) for t in inputs)
29+
new_shapes = params_broadcast_shapes(
30+
all_batched_dims,
31+
ndims_params=[batched_dims] * len(inputs),
32+
use_pytensor=False,
33+
)
34+
broadcast_inputs = [
35+
torch.broadcast_to(i, s) for i, s in zip(inputs, new_shapes)
36+
]
37+
return inner_func(*broadcast_inputs)
2738

2839
return batcher

tests/link/pytorch/test_blockwise.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
1+
import numpy as np
12
import pytest
23

4+
import pytensor
5+
import pytensor.tensor as pt
36
from pytensor.graph.replace import vectorize_node
47
from pytensor.tensor import tensor
58
from pytensor.tensor.blockwise import Blockwise
69
from pytensor.tensor.nlinalg import MatrixInverse
10+
from pytensor.tensor.shape import specify_broadcastable
711

812

913
torch = pytest.importorskip("torch")
@@ -30,3 +34,16 @@ def test_vectorize_blockwise():
3034
new_vect_node.op.core_op, MatrixInverse
3135
)
3236
assert new_vect_node.inputs[0] is tns4
37+
38+
39+
def test_blockwise_broadcast():
40+
_x = np.random.rand(5, 1, 2, 3)
41+
_y = np.random.rand(3, 3, 2)
42+
43+
x = specify_broadcastable(pt.tensor4("x"), 1)
44+
y = pt.tensor3("y")
45+
46+
f = pytensor.function([x, y], [x @ y], mode="PYTORCH")
47+
[res] = f(_x, _y)
48+
assert tuple(res.shape) == (5, 3, 2, 2)
49+
np.testing.assert_allclose(res, _x @ _y)

0 commit comments

Comments
 (0)