We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 14816dc commit 331ffe7Copy full SHA for 331ffe7
pytensor/link/pytorch/dispatch/elemwise.py
@@ -22,11 +22,9 @@ def elemwise_fn(*inputs):
22
def elemwise_fn(*inputs):
23
Elemwise._check_runtime_broadcast(node, inputs)
24
shaped_inputs = torch.broadcast_tensors(*inputs)
25
- if shaped_inputs[0].dim() == 1:
26
- ufunc = torch.vmap(base_fn)
27
- else:
28
- dims = (tuple(range(shaped_inputs[0].dim())),)
29
- ufunc = torch.vmap(base_fn, in_dims=dims)
+ ufunc = base_fn
+ for _ in range(shaped_inputs[0].dim()):
+ ufunc = torch.vmap(ufunc)
30
# @todo: This will fail for anything that calls
31
# `.item()`
32
return ufunc(*shaped_inputs)
0 commit comments