Implement OpFromGraph in PyTorch backend #956
Changes from 2 commits
Commits: c4b20ec · 9c64320 · fdd5d5c · d98e68e · 10a841f · b29be45 · cefec02 · 0f18d8d
```diff
@@ -30,6 +30,7 @@
     OPT_O3,
     OPT_STABILIZE,
     OPT_UNSAFE,
+    PYTORCH,
     AddDestroyHandler,
     AddFeatureOptimizer,
     Mode,
```
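For orientation, a minimal sketch of how the backend behind this import is selected when compiling a function. The example graph and the `"PYTORCH"` mode string are illustrative assumptions, not code from this PR.

```python
# Minimal sketch, assuming PYTORCH is registered as a predefined mode
# (mirroring the existing "JAX" and "NUMBA" modes); not part of this diff.
import numpy as np

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# Request the PyTorch backend by mode name when compiling the function.
fn = pytensor.function([x], y, mode="PYTORCH")
print(fn(np.arange(3, dtype="float64")))
```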
```diff
@@ -1,8 +1,10 @@
 from functools import singledispatch
+from operator import itemgetter
 from types import NoneType

 import torch

+from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.fg import FunctionGraph
 from pytensor.link.utils import fgraph_to_python

@@ -132,3 +134,15 @@ def makevector(*x):
         return torch.tensor(x, dtype=torch_dtype)

     return makevector
+
+
+@pytorch_funcify.register(OpFromGraph)
+def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
+    _ = kwargs.pop("storage_map", None)
+
+    fgraph_fn = torch.compile(pytorch_funcify(op.fgraph, **kwargs))
```
Review discussion on `torch.compile(pytorch_funcify(op.fgraph, **kwargs))`:

- Do you need to compile the inner function? Is that a thing in PyTorch?
- I was following what numba does, where it jits the inner function. We could remove the inner torch.compile and just return op.fgraph if that seems more reasonable. That will still lead to some c-linker issues fwiw.
- I removed the inner function; you only need to do indexing if the number of return values is more than 1.
- Numba can only have inner compiled functions. I don't know if that's a requirement in pytorch, and whether it has any advantages. We don't do it for JAX.
- I do not see / know of any requirement to have an inner compiled function.

(A sketch of the variant without the inner torch.compile follows the diff below.)
```diff
+
+    return (
+        fgraph_fn
+        if len(op.fgraph.outputs) > 1
+        else lambda *args: itemgetter(0)(fgraph_fn(*args))
+    )
```
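Following up on the review thread above, here is a sketch of the discussed variant that skips the inner `torch.compile` and only unwraps single-output graphs. The import path of `pytorch_funcify` is an assumption, and this is not the code merged in this PR.

```python
# Sketch of the no-inner-compile variant discussed in the review thread.
# Assumes pytorch_funcify is the dispatcher extended by this diff (the
# import path below is an assumption); not the code merged here.
from operator import itemgetter

from pytensor.compile.builders import OpFromGraph
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify


@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node=None, **kwargs):
    kwargs.pop("storage_map", None)

    # Funcify the inner FunctionGraph directly; leave torch.compile to the
    # outer linker instead of compiling the inner function here.
    fgraph_fn = pytorch_funcify(op.fgraph, **kwargs)

    if len(op.fgraph.outputs) > 1:
        # Multiple outputs: return the tuple-producing function as-is.
        return fgraph_fn

    # Single output: unwrap the 1-tuple the funcified graph returns.
    return lambda *args: itemgetter(0)(fgraph_fn(*args))
```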
Ch0ronomato marked this conversation as resolved.
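To make the effect of the change concrete, a usage sketch of compiling a graph that contains an OpFromGraph node with the PyTorch backend. The example graph and the `"PYTORCH"` mode string are illustrative assumptions, not tests from this PR.

```python
# Usage sketch: an OpFromGraph node compiled through the PyTorch backend.
# The example graph and mode="PYTORCH" are illustrative assumptions.
import numpy as np

import pytensor
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
y = pt.vector("y")

# Wrap a small subgraph as a reusable OpFromGraph.
ofg = OpFromGraph([x, y], [x * y + pt.exp(x)])

a = pt.vector("a")
b = pt.vector("b")
z = ofg(a, b)

# With this PR, the PyTorch dispatcher knows how to translate the OpFromGraph node.
fn = pytensor.function([a, b], z, mode="PYTORCH")
print(fn(np.ones(3), np.full(3, 2.0)))
```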