Skip to content

Commit a10e6f4

Browse files
author
Ian Schweer
committed
Use squeeze_output
1 parent f4b5a79 commit a10e6f4

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

pytensor/link/pytorch/dispatch/blockwise.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
1111
batched_dims = op.batch_ndim(node)
1212
core_node = op._create_dummy_core_node(node.inputs)
1313
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
14-
inner_func = pytorch_funcify(core_fgraph)
14+
inner_func = pytorch_funcify(core_fgraph, squeeze_output=len(node.outputs) == 1)
1515

1616
for _ in range(batched_dims):
1717
inner_func = torch.vmap(inner_func)
@@ -27,9 +27,6 @@ def batcher(*inputs):
2727
for i in inputs
2828
]
2929
res = inner_func(*broadcast_inputs)
30-
if len(node.outputs) == 1:
31-
return res[0]
32-
else:
33-
return res
30+
return res
3431

3532
return batcher

0 commit comments

Comments
 (0)