2
2
from types import NoneType
3
3
4
4
import torch
5
+ import torch .compiler
5
6
6
7
from pytensor .compile .builders import OpFromGraph
7
8
from pytensor .compile .ops import DeepCopyOp
@@ -139,14 +140,8 @@ def makevector(*x):
139
140
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
    """Return a PyTorch-callable wrapper for an ``OpFromGraph``'s inner graph.

    Parameters
    ----------
    op
        The ``OpFromGraph`` instance whose ``fgraph`` is converted.
    node
        The apply node being funcified (unused here; kept for dispatch
        signature compatibility).
    **kwargs
        Forwarded to ``pytorch_funcify`` for the inner graph.

    Returns
    -------
    callable
        A function mapping the op's inputs to its output (single output
        unwrapped) or outputs (multi-output case).
    """
    # storage_map is meaningless for the inner graph conversion; discard it
    # so it is not forwarded to pytorch_funcify below.
    _ = kwargs.pop("storage_map", None)

    # Compile the inner graph itself with dynamo.
    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))

    if len(op.fgraph.outputs) == 1:

        def wrapped(*args):
            # Single output: unwrap the one-element result sequence.
            return fgraph_fn(*args)[0]

    else:
        # NOTE(review): multi-output branch reconstructed from surrounding
        # context (it returns the result sequence as-is) — verify upstream.

        def wrapped(*args):
            return fgraph_fn(*args)

    # Don't compile this thin wrapper itself: torch fails to create guards
    # when parent-scope closure variables are used in conditional statements.
    # Excluding only this small section from the outer graph's compilation
    # avoids rewriting large portions of code.
    return torch.compiler.disable(wrapped)
0 commit comments