File tree 1 file changed +17
-6
lines changed
pytensor/link/pytorch/dispatch
1 file changed +17
-6
lines changed Original file line number Diff line number Diff line change 1
1
from functools import singledispatch
2
- from operator import itemgetter
3
2
from types import NoneType
4
3
5
4
import torch
@@ -140,9 +139,21 @@ def makevector(*x):
140
139
def pytorch_funcify_OpFromGraph (op , node = None , ** kwargs ):
141
140
_ = kwargs .pop ("storage_map" , None )
142
141
142
+ # @todo: Torch compile doesn't capture the scope accounting
143
+ # for op.fgraph, leading to an import error. Disable the
144
+ # dynamo compile for these graphs
145
+ import torch ._dynamo .config
146
+
147
+ torch ._dynamo .config .suppress_errors = True
148
+
143
149
fgraph_fn = torch .compile (pytorch_funcify (op .fgraph , ** kwargs ))
144
- return (
145
- fgraph_fn
146
- if len (op .fgraph .outputs ) > 1
147
- else lambda * args : itemgetter (0 )(fgraph_fn (* args ))
148
- )
150
+ if len (op .fgraph .outputs ) > 1 :
151
+
152
+ def inner (* args ):
153
+ return fgraph_fn (* args )
154
+ else :
155
+
156
+ def inner (* args ):
157
+ return fgraph_fn (* args )[0 ]
158
+
159
+ return inner
You can’t perform that action at this time.
0 commit comments