Skip to content

Commit 10a841f

Browse files
author
Ian Schweer
committed
Disable the opfromgraph inner function from compiling
1 parent d98e68e commit 10a841f

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from types import NoneType
33

44
import torch
5+
import torch.compiler
56

67
from pytensor.compile.builders import OpFromGraph
78
from pytensor.compile.ops import DeepCopyOp
@@ -139,14 +140,8 @@ def makevector(*x):
139140
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
140141
_ = kwargs.pop("storage_map", None)
141142

142-
# @todo: Torch compile doesn't capture the scope accounting
143-
# for op.fgraph, leading to an import error. Disable the
144-
# dynamo compile for these graphs
145-
import torch._dynamo.config
146-
147-
torch._dynamo.config.suppress_errors = True
148-
149143
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
144+
150145
if len(op.fgraph.outputs) > 1:
151146

152147
def inner(*args):
@@ -156,4 +151,11 @@ def inner(*args):
156151
def inner(*args):
157152
return fgraph_fn(*args)[0]
158153

159-
return inner
154+
# Don't compile the inner function
155+
# This is due torch failing to create
156+
# guards when parent scoped closure variables
157+
# are used in conditional statements.
158+
# Instead of rewriting many portions of code
159+
# this will allow for only this small section to
160+
# not be compiled by the outer graph
161+
return torch.compiler.disable(inner)

0 commit comments

Comments
 (0)