Skip to content

Commit fdd5d5c

Browse files
committed
Disable torch dynamo
1 parent 9c64320 commit fdd5d5c

File tree

1 file changed

+17
-6
lines changed

1 file changed

+17
-6
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from functools import singledispatch
2-
from operator import itemgetter
32
from types import NoneType
43

54
import torch
@@ -140,9 +139,21 @@ def makevector(*x):
140139
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
141140
_ = kwargs.pop("storage_map", None)
142141

142+
# @todo: Torch compile doesn't capture the scope accounting
143+
# for op.fgraph, leading to an import error. Disable the
144+
# dynamo compile for these graphs
145+
import torch._dynamo.config
146+
147+
torch._dynamo.config.suppress_errors = True
148+
143149
fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
144-
return (
145-
fgraph_fn
146-
if len(op.fgraph.outputs) > 1
147-
else lambda *args: itemgetter(0)(fgraph_fn(*args))
148-
)
150+
if len(op.fgraph.outputs) > 1:
151+
152+
def inner(*args):
153+
return fgraph_fn(*args)
154+
else:
155+
156+
def inner(*args):
157+
return fgraph_fn(*args)[0]
158+
159+
return inner

0 commit comments

Comments
 (0)