Skip to content

Commit fe4fc3b

Browse files
committed
Change ufunc
1 parent b7d15ea commit fe4fc3b

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

pytensor/link/pytorch/dispatch/elemwise.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,9 @@ def elemwise_fn(*inputs):
2222
def elemwise_fn(*inputs):
2323
Elemwise._check_runtime_broadcast(node, inputs)
2424
shaped_inputs = torch.broadcast_tensors(*inputs)
25-
if shaped_inputs[0].dim() == 1:
26-
ufunc = torch.vmap(base_fn)
27-
else:
28-
dims = (tuple(range(shaped_inputs[0].dim())),)
29-
ufunc = torch.vmap(base_fn, in_dims=dims)
25+
ufunc = base_fn
26+
for _ in range(shaped_inputs[0].dim()):
27+
ufunc = torch.vmap(ufunc)
3028
# @todo: This will fail for anything that calls
3129
# `.item()`
3230
return ufunc(*shaped_inputs)

0 commit comments

Comments
 (0)