File tree 1 file changed +6
-8
lines changed
pytensor/link/pytorch/dispatch 1 file changed +6
-8
lines changed Original file line number Diff line number Diff line change 1
1
from functools import singledispatch
2
+ from operator import itemgetter
2
3
from types import NoneType
3
4
4
5
import torch
5
6
6
- from pytensor .compile import PYTORCH
7
7
from pytensor .compile .builders import OpFromGraph
8
8
from pytensor .compile .ops import DeepCopyOp
9
9
from pytensor .graph .fg import FunctionGraph
@@ -140,11 +140,9 @@ def makevector(*x):
140
140
def pytorch_funcify_OpFromGraph (op , node = None , ** kwargs ):
141
141
_ = kwargs .pop ("storage_map" , None )
142
142
143
- PYTORCH .optimizer (op .fgraph )
144
143
fgraph_fn = torch .compile (pytorch_funcify (op .fgraph , ** kwargs ))
145
-
146
- def opfromgraph (* inputs , dim = op .fgraph .outputs ):
147
- res = fgraph_fn (* inputs )
148
- return res [0 ]
149
-
150
- return opfromgraph
144
+ return (
145
+ fgraph_fn
146
+ if len (op .fgraph .outputs ) > 1
147
+ else lambda * args : itemgetter (0 )(fgraph_fn (* args ))
148
+ )
You can’t perform that action at this time.
0 commit comments