
Commit b29be45
Only disable one level of inlining
1 parent: 10a841f

File tree: 2 files changed, +11 -22 lines changed

pytensor/link/pytorch/dispatch/basic.py (9 additions, 20 deletions)

@@ -4,6 +4,7 @@
 import torch
 import torch.compiler
 
+from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
@@ -137,25 +138,13 @@ def makevector(*x):
 
 
 @pytorch_funcify.register(OpFromGraph)
-def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
-    _ = kwargs.pop("storage_map", None)
+def pytorch_funcify_OpFromGraph(op, node, **kwargs):
+    kwargs.pop("storage_map", None)
 
-    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
+    # Apply inner rewrites
+    PYTORCH.optimizer(op.fgraph)
 
-    if len(op.fgraph.outputs) > 1:
-
-        def inner(*args):
-            return fgraph_fn(*args)
-
-    else:
-
-        def inner(*args):
-            return fgraph_fn(*args)[0]
-
-    # Don't compile the inner function
-    # This is due torch failing to create
-    # guards when parent scoped closure variables
-    # are used in conditional statements.
-    # Instead of rewriting many portions of code
-    # this will allow for only this small section to
-    # not be compiled by the outer graph
-    return torch.compiler.disable(inner)
+    fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
+    # Disable one step inlining to prevent torch from trying to import local functions
+    # defined in `pytorch_funcify`
+    return torch.compiler.disable(fgraph_fn, recursive=False)
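The key change is the `recursive=False` flag: `torch.compiler.disable(fn, recursive=False)` skips compilation of only the wrapped frame itself while still allowing functions it calls to be compiled, whereas the previous code disabled compilation for the entire call tree reachable from the inner wrapper. Below is a minimal standalone sketch of that one-level behaviour; it is not part of the commit, and the function names (`callee`, `wrapper`, `model`) are purely illustrative.

    import torch


    def callee(x):
        # With recursive=False applied to the caller, this frame can
        # still be traced and compiled by torch.compile.
        return torch.sin(x) * 2


    def wrapper(x):
        return callee(x) + 1


    # Skip compiling only `wrapper` itself (one level of inlining disabled).
    one_level = torch.compiler.disable(wrapper, recursive=False)

    # Default behaviour: skip `wrapper` and everything it calls.
    whole_subtree = torch.compiler.disable(wrapper, recursive=True)


    @torch.compile
    def model(x):
        return one_level(x)


    print(model(torch.ones(3)))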

tests/link/pytorch/test_basic.py (2 additions, 2 deletions)

@@ -68,9 +68,9 @@ def compare_pytorch_and_py(
 
     if len(fgraph.outputs) > 1:
         for j, p in zip(pytorch_res, py_res):
-            assert_fn(j.cpu(), p)
+            assert_fn(j.detach().cpu().numpy(), p)
     else:
-        assert_fn([pytorch_res[0].cpu()], py_res)
+        assert_fn(pytorch_res[0].detach().cpu().numpy(), py_res[0])
 
     return pytensor_torch_fn, pytorch_res
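The test now converts the PyTorch result to a NumPy array before comparing it with the Python-mode result. A short sketch of why the `.detach().cpu().numpy()` chain is needed; this is not from the repository, and `np.testing.assert_allclose` stands in for whatever `assert_fn` is passed.

    import numpy as np
    import torch


    def to_numpy(t: torch.Tensor) -> np.ndarray:
        # detach() drops the autograd graph (Tensor.numpy() raises if the
        # tensor requires grad), cpu() moves the data off any accelerator,
        # and numpy() hands np.testing an ndarray instead of a Tensor.
        return t.detach().cpu().numpy()


    pytorch_res = torch.tensor([1.0, 2.0], requires_grad=True) * 3
    py_res = np.array([3.0, 6.0])
    np.testing.assert_allclose(to_numpy(pytorch_res), py_res)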
