Skip to content

Commit 9c64320

Browse files
author
Ian Schweer
committed
Clean up num args based on graph
1 parent c4b20ec commit 9c64320

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from functools import singledispatch
2+
from operator import itemgetter
23
from types import NoneType
34

45
import torch
56

6-
from pytensor.compile import PYTORCH
77
from pytensor.compile.builders import OpFromGraph
88
from pytensor.compile.ops import DeepCopyOp
99
from pytensor.graph.fg import FunctionGraph
@@ -140,11 +140,9 @@ def makevector(*x):
140140
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
141141
_ = kwargs.pop("storage_map", None)
142142

143-
PYTORCH.optimizer(op.fgraph)
144143
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
145-
146-
def opfromgraph(*inputs, dim=op.fgraph.outputs):
147-
res = fgraph_fn(*inputs)
148-
return res[0]
149-
150-
return opfromgraph
144+
return (
145+
fgraph_fn
146+
if len(op.fgraph.outputs) > 1
147+
else lambda *args: itemgetter(0)(fgraph_fn(*args))
148+
)

0 commit comments

Comments
 (0)