Skip to content

Commit f4b5a79

Browse files
committed
Address some PR comments
1 parent d199188 commit f4b5a79

File tree

2 files changed

+16
-21
lines changed

2 files changed

+16
-21
lines changed
Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,35 @@
11
import torch
2+
import torch.compiler
23

34
from pytensor.graph import FunctionGraph
45
from pytensor.link.pytorch.dispatch import pytorch_funcify
56
from pytensor.tensor.blockwise import Blockwise
6-
from pytensor.tensor.random.utils import params_broadcast_shapes
77

88

99
@pytorch_funcify.register(Blockwise)
1010
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
1111
batched_dims = op.batch_ndim(node)
1212
core_node = op._create_dummy_core_node(node.inputs)
1313
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
14-
core_func = pytorch_funcify(core_fgraph)
15-
if len(node.outputs) == 1:
16-
17-
def inner_func(*inputs):
18-
return core_func(*inputs)[0]
19-
else:
20-
inner_func = core_func
14+
inner_func = pytorch_funcify(core_fgraph)
2115

2216
for _ in range(batched_dims):
2317
inner_func = torch.vmap(inner_func)
2418

19+
@torch.compiler.disable(recursive=False)
2520
def batcher(*inputs):
2621
op._check_runtime_broadcast(node, inputs)
2722
# broadcast on batched_dims
28-
all_batched_dims = tuple(tuple(t.shape) for t in inputs)
29-
new_shapes = params_broadcast_shapes(
30-
all_batched_dims,
31-
ndims_params=[batched_dims] * len(inputs),
32-
use_pytensor=False,
33-
)
23+
all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs)
24+
batched_shape = torch.broadcast_shapes(*all_batched_dims)
3425
broadcast_inputs = [
35-
torch.broadcast_to(i, s) for i, s in zip(inputs, new_shapes)
26+
torch.broadcast_to(i, batched_shape + i.shape[batched_dims:])
27+
for i in inputs
3628
]
37-
return inner_func(*broadcast_inputs)
29+
res = inner_func(*broadcast_inputs)
30+
if len(node.outputs) == 1:
31+
return res[0]
32+
else:
33+
return res
3834

3935
return batcher

tests/link/pytorch/test_blockwise.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import pytensor
55
import pytensor.tensor as pt
6-
from pytensor.tensor.shape import specify_broadcastable
76

87

98
torch = pytest.importorskip("torch")
@@ -13,10 +12,10 @@ def test_blockwise_broadcast():
1312
_x = np.random.rand(5, 1, 2, 3)
1413
_y = np.random.rand(3, 3, 2)
1514

16-
x = specify_broadcastable(pt.tensor4("x"), 1)
17-
y = pt.tensor3("y")
15+
x = pt.tensor4("x", shape=(5, 1, 2, 3))
16+
y = pt.tensor3("y", shape=(3, 3, 2))
1817

19-
f = pytensor.function([x, y], [x @ y], mode="PYTORCH")
20-
[res] = f(_x, _y)
18+
f = pytensor.function([x, y], x @ y, mode="PYTORCH")
19+
res = f(_x, _y)
2120
assert tuple(res.shape) == (5, 3, 2, 2)
2221
np.testing.assert_allclose(res, _x @ _y)

0 commit comments

Comments
 (0)