We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b7d15ea commit fe4fc3bCopy full SHA for fe4fc3b
pytensor/link/pytorch/dispatch/elemwise.py
@@ -22,11 +22,9 @@ def elemwise_fn(*inputs):
22
def elemwise_fn(*inputs):
23
Elemwise._check_runtime_broadcast(node, inputs)
24
shaped_inputs = torch.broadcast_tensors(*inputs)
25
- if shaped_inputs[0].dim() == 1:
26
- ufunc = torch.vmap(base_fn)
27
- else:
28
- dims = (tuple(range(shaped_inputs[0].dim())),)
29
- ufunc = torch.vmap(base_fn, in_dims=dims)
+ ufunc = base_fn
+ for _ in range(shaped_inputs[0].dim()):
+ ufunc = torch.vmap(ufunc)
30
# @todo: This will fail for anything that calls
31
# `.item()`
32
return ufunc(*shaped_inputs)
0 commit comments