|
4 | 4 | import torch
|
5 | 5 | import torch.compiler
|
6 | 6 |
|
| 7 | +from pytensor.compile import PYTORCH |
7 | 8 | from pytensor.compile.builders import OpFromGraph
|
8 | 9 | from pytensor.compile.ops import DeepCopyOp
|
9 | 10 | from pytensor.graph.fg import FunctionGraph
|
@@ -137,25 +138,13 @@ def makevector(*x):
|
137 | 138 |
|
138 | 139 |
|
139 | 140 | @pytorch_funcify.register(OpFromGraph)
|
140 |
| -def pytorch_funcify_OpFromGraph(op, node=None, **kwargs): |
141 |
| - _ = kwargs.pop("storage_map", None) |
| 141 | +def pytorch_funcify_OpFromGraph(op, node, **kwargs): |
| 142 | + kwargs.pop("storage_map", None) |
142 | 143 |
|
143 |
| - fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs)) |
| 144 | + # Apply inner rewrites |
| 145 | + PYTORCH.optimizer(op.fgraph) |
144 | 146 |
|
145 |
| - if len(op.fgraph.outputs) > 1: |
146 |
| - |
147 |
| - def inner(*args): |
148 |
| - return fgraph_fn(*args) |
149 |
| - else: |
150 |
| - |
151 |
| - def inner(*args): |
152 |
| - return fgraph_fn(*args)[0] |
153 |
| - |
154 |
| - # Don't compile the inner function |
155 |
| - # This is due torch failing to create |
156 |
| - # guards when parent scoped closure variables |
157 |
| - # are used in conditional statements. |
158 |
| - # Instead of rewriting many portions of code |
159 |
| - # this will allow for only this small section to |
160 |
| - # not be compiled by the outer graph |
161 |
| - return torch.compiler.disable(inner) |
| 147 | + fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True) |
| 148 | + # Disable one step inlining to prevent torch from trying to import local functions |
| 149 | + # defined in `pytorch_funcify` |
| 150 | + return torch.compiler.disable(fgraph_fn, recursive=False) |
0 commit comments