Skip to content

Commit 14816dc

Browse files
committed
Use broadcast tensor
1 parent a570dbf commit 14816dc

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,25 @@ def pytorch_funcify_Elemwise(op, node, **kwargs):
1111
scalar_op = op.scalar_op
1212
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)
1313

14-
def elemwise_fn(*inputs):
15-
Elemwise._check_runtime_broadcast(node, inputs)
16-
return base_fn(*inputs)
14+
if hasattr(scalar_op, "nfunc_spec") and hasattr(torch, scalar_op.nfunc_spec[0]):
15+
# torch can handle this scalar
16+
# broadcast, we'll let it.
17+
def elemwise_fn(*inputs):
18+
Elemwise._check_runtime_broadcast(node, inputs)
19+
return base_fn(*inputs)
20+
else:
21+
22+
def elemwise_fn(*inputs):
23+
Elemwise._check_runtime_broadcast(node, inputs)
24+
shaped_inputs = torch.broadcast_tensors(*inputs)
25+
if shaped_inputs[0].dim() == 1:
26+
ufunc = torch.vmap(base_fn)
27+
else:
28+
dims = (tuple(range(shaped_inputs[0].dim())),)
29+
ufunc = torch.vmap(base_fn, in_dims=dims)
30+
# @todo: This will fail for anything that calls
31+
# `.item()`
32+
return ufunc(*shaped_inputs)
1733

1834
return elemwise_fn
1935

0 commit comments

Comments
 (0)