-
Notifications
You must be signed in to change notification settings - Fork 132
Add initial support for PyTorch backend #764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
27e2526
Add pytorch support for some basic Ops
HarshvirSandhu 629d00b
update variable names, docstrings
HarshvirSandhu 3eceb56
Avoid numpy conversion of torch Tensors
HarshvirSandhu 3cde964
Fix typify and CheckAndRaise
HarshvirSandhu c003aa5
Fix Elemwise Ops
HarshvirSandhu 8dc406e
Fix Scalar Ops
HarshvirSandhu a8f6ddb
Fix ruff-format
HarshvirSandhu 9d535f5
Initial setup for pytorch tests
HarshvirSandhu c5600da
Fix mode parameters for pytorch
HarshvirSandhu 54b6248
Prevent conversion of scalars to numpy
HarshvirSandhu 19454b3
Update TensorConstantSignature and map dtypes to Tensor types
HarshvirSandhu 92d7114
Add tests for basic ops
HarshvirSandhu 5aae0e5
Remove torch from user facing API
HarshvirSandhu 8c174dd
Add function to convert numpy arrays to pytorch tensors
HarshvirSandhu 0977c3a
Avoid copy when converting to tensor
HarshvirSandhu 1c23825
Fix tests
HarshvirSandhu c9195a8
Remove dispatches that are not tested
HarshvirSandhu b07805c
set path for pytorch tests
HarshvirSandhu 9e8d3fc
Remove tensorflow probability from yml
HarshvirSandhu a2d3afa
Add checks for runtime broadcasting
HarshvirSandhu a577a80
Remove IfElse
HarshvirSandhu 499a174
Remove dev notebook
HarshvirSandhu 2826613
Fix check and raise
HarshvirSandhu 62ffcec
Fix compare_pytorch_and_py
HarshvirSandhu acdbba1
Fix DimShuffle
HarshvirSandhu 2519c65
Add tests for Elemwise operations
HarshvirSandhu eb6d5c2
Fix test for CheckAndRaise
HarshvirSandhu 9f02a4f
Remove duplicate function
HarshvirSandhu caf2965
Remove device from pytorch_typify
HarshvirSandhu bf87eb9
Merge branch 'main' of https://github.com/HarshvirSandhu/pytensor int…
HarshvirSandhu 2c27683
Solve merge conflict
HarshvirSandhu c603c6b
Use micromamba for pytorch install
HarshvirSandhu 3f17107
Fix pytorch linker
HarshvirSandhu e850d8d
Fix typify and deepcopy
HarshvirSandhu e682fc4
Parametrize device in all tests
HarshvirSandhu bf4cf92
Install torch with cuda
HarshvirSandhu 899e7f9
Fix test_pytorch_FunctionGraph_once
HarshvirSandhu 04d2935
Remove device argument from test
HarshvirSandhu 8ec7661
remove device from elemwise tests and add assertions
HarshvirSandhu bb7df41
skip tests if cuda is not available
HarshvirSandhu 0441cf2
Fix tests
HarshvirSandhu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# isort: off | ||
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify | ||
|
||
# # Load dispatch specializations | ||
import pytensor.link.pytorch.dispatch.scalar | ||
import pytensor.link.pytorch.dispatch.elemwise | ||
# isort: on |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from functools import singledispatch | ||
|
||
import torch | ||
|
||
from pytensor.compile.ops import DeepCopyOp | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.link.utils import fgraph_to_python | ||
from pytensor.raise_op import CheckAndRaise | ||
|
||
|
||
@singledispatch | ||
def pytorch_typify(data, dtype=None, **kwargs): | ||
r"""Convert instances of PyTensor `Type`\s to PyTorch types.""" | ||
return torch.as_tensor(data, dtype=dtype) | ||
|
||
|
||
@singledispatch | ||
def pytorch_funcify(op, node=None, storage_map=None, **kwargs): | ||
"""Create a PyTorch compatible function from an PyTensor `Op`.""" | ||
raise NotImplementedError( | ||
f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation" | ||
) | ||
|
||
|
||
@pytorch_funcify.register(FunctionGraph) | ||
def pytorch_funcify_FunctionGraph( | ||
fgraph, | ||
node=None, | ||
fgraph_name="pytorch_funcified_fgraph", | ||
**kwargs, | ||
): | ||
return fgraph_to_python( | ||
fgraph, | ||
pytorch_funcify, | ||
type_conversion_fn=pytorch_typify, | ||
fgraph_name=fgraph_name, | ||
**kwargs, | ||
) | ||
|
||
|
||
@pytorch_funcify.register(CheckAndRaise) | ||
def pytorch_funcify_CheckAndRaise(op, **kwargs): | ||
error = op.exc_type | ||
msg = op.msg | ||
|
||
def assert_fn(x, *conditions): | ||
for cond in conditions: | ||
if not cond.item(): | ||
raise error(msg) | ||
return x | ||
|
||
return assert_fn | ||
|
||
|
||
@pytorch_funcify.register(DeepCopyOp) | ||
def pytorch_funcify_DeepCopyOp(op, **kwargs): | ||
def deepcopyop(x): | ||
return x.clone() | ||
|
||
return deepcopyop |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
|
||
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify | ||
from pytensor.tensor.elemwise import DimShuffle, Elemwise | ||
|
||
|
||
@pytorch_funcify.register(Elemwise) | ||
def pytorch_funcify_Elemwise(op, node, **kwargs): | ||
scalar_op = op.scalar_op | ||
base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) | ||
|
||
def elemwise_fn(*inputs): | ||
Elemwise._check_runtime_broadcast(node, inputs) | ||
return base_fn(*inputs) | ||
|
||
return elemwise_fn | ||
|
||
|
||
@pytorch_funcify.register(DimShuffle) | ||
def pytorch_funcify_DimShuffle(op, **kwargs): | ||
def dimshuffle(x): | ||
res = torch.permute(x, op.transposition) | ||
|
||
shape = list(res.shape[: len(op.shuffle)]) | ||
|
||
for augm in op.augment: | ||
shape.insert(augm, 1) | ||
|
||
res = torch.reshape(res, shape) | ||
|
||
if not op.inplace: | ||
res = res.clone() | ||
|
||
return res | ||
|
||
return dimshuffle |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
|
||
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify | ||
from pytensor.scalar.basic import ( | ||
ScalarOp, | ||
) | ||
|
||
|
||
@pytorch_funcify.register(ScalarOp) | ||
def pytorch_funcify_ScalarOp(op, node, **kwargs): | ||
"""Return pytorch function that implements the same computation as the Scalar Op. | ||
|
||
This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does, | ||
even though it's dispatched on the Scalar Op. | ||
""" | ||
|
||
nfunc_spec = getattr(op, "nfunc_spec", None) | ||
if nfunc_spec is None: | ||
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}") | ||
|
||
func_name = nfunc_spec[0] | ||
|
||
pytorch_func = getattr(torch, func_name) | ||
|
||
if len(node.inputs) > op.nfunc_spec[1]: | ||
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function, | ||
# even though the base Op from `func_name` is specified as a binary Op. | ||
# This happens with `Add`, which can work as a `Sum` for multiple scalars. | ||
pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None) | ||
if not pytorch_variadic_func: | ||
raise NotImplementedError( | ||
f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs" | ||
) | ||
|
||
def pytorch_func(*args): | ||
return pytorch_variadic_func( | ||
torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0 | ||
) | ||
|
||
return pytorch_func |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
from typing import Any | ||
|
||
from pytensor.graph.basic import Variable | ||
from pytensor.link.basic import JITLinker | ||
|
||
|
||
class PytorchLinker(JITLinker): | ||
"""A `Linker` that compiles NumPy-based operations using torch.compile.""" | ||
|
||
def input_filter(self, inp: Any) -> Any: | ||
from pytensor.link.pytorch.dispatch import pytorch_typify | ||
|
||
return pytorch_typify(inp) | ||
|
||
def output_filter(self, var: Variable, out: Any) -> Any: | ||
return out.cpu() | ||
|
||
def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): | ||
from pytensor.link.pytorch.dispatch import pytorch_funcify | ||
|
||
return pytorch_funcify( | ||
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs | ||
) | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def jit_compile(self, fn): | ||
import torch | ||
|
||
return torch.compile(fn) | ||
|
||
def create_thunk_inputs(self, storage_map): | ||
thunk_inputs = [] | ||
for n in self.fgraph.inputs: | ||
sinput = storage_map[n] | ||
thunk_inputs.append(sinput) | ||
|
||
return thunk_inputs |
Empty file.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.