Add initial support for PyTorch backend #764


Merged · 41 commits · Jun 20, 2024
Changes from 3 commits

Commits
27e2526
Add pytorch support for some basic Ops
HarshvirSandhu May 13, 2024
629d00b
update variable names, docstrings
HarshvirSandhu May 13, 2024
3eceb56
Avoid numpy conversion of torch Tensors
HarshvirSandhu May 17, 2024
3cde964
Fix typify and CheckAndRaise
HarshvirSandhu May 17, 2024
c003aa5
Fix Elemwise Ops
HarshvirSandhu May 17, 2024
8dc406e
Fix Scalar Ops
HarshvirSandhu May 17, 2024
a8f6ddb
Fix ruff-format
HarshvirSandhu May 17, 2024
9d535f5
Initial setup for pytorch tests
HarshvirSandhu May 23, 2024
c5600da
Fix mode parameters for pytorch
HarshvirSandhu May 23, 2024
54b6248
Prevent conversion of scalars to numpy
HarshvirSandhu May 23, 2024
19454b3
Update TensorConstantSignature and map dtypes to Tensor types
HarshvirSandhu May 23, 2024
92d7114
Add tests for basic ops
HarshvirSandhu May 23, 2024
5aae0e5
Remove torch from user facing API
HarshvirSandhu May 29, 2024
8c174dd
Add function to convert numpy arrays to pytorch tensors
HarshvirSandhu May 29, 2024
0977c3a
Avoid copy when converting to tensor
HarshvirSandhu May 29, 2024
1c23825
Fix tests
HarshvirSandhu May 29, 2024
c9195a8
Remove dispatches that are not tested
HarshvirSandhu May 31, 2024
b07805c
set path for pytorch tests
HarshvirSandhu May 31, 2024
9e8d3fc
Remove tensorflow probability from yml
HarshvirSandhu Jun 4, 2024
a2d3afa
Add checks for runtime broadcasting
HarshvirSandhu Jun 4, 2024
a577a80
Remove IfElse
HarshvirSandhu Jun 4, 2024
499a174
Remove dev notebook
HarshvirSandhu Jun 12, 2024
2826613
Fix check and raise
HarshvirSandhu Jun 12, 2024
62ffcec
Fix compare_pytorch_and_py
HarshvirSandhu Jun 12, 2024
acdbba1
Fix DimShuffle
HarshvirSandhu Jun 12, 2024
2519c65
Add tests for Elemwise operations
HarshvirSandhu Jun 12, 2024
eb6d5c2
Fix test for CheckAndRaise
HarshvirSandhu Jun 14, 2024
9f02a4f
Remove duplicate function
HarshvirSandhu Jun 14, 2024
caf2965
Remove device from pytorch_typify
HarshvirSandhu Jun 15, 2024
bf87eb9
Merge branch 'main' of https://github.com/HarshvirSandhu/pytensor int…
HarshvirSandhu Jun 15, 2024
2c27683
Solve merge conflict
HarshvirSandhu Jun 15, 2024
c603c6b
Use micromamba for pytorch install
HarshvirSandhu Jun 15, 2024
3f17107
Fix pytorch linker
HarshvirSandhu Jun 16, 2024
e850d8d
Fix typify and deepcopy
HarshvirSandhu Jun 16, 2024
e682fc4
Parametrize device in all tests
HarshvirSandhu Jun 16, 2024
bf4cf92
Install torch with cuda
HarshvirSandhu Jun 16, 2024
899e7f9
Fix test_pytorch_FunctionGraph_once
HarshvirSandhu Jun 16, 2024
04d2935
Remove device argument from test
HarshvirSandhu Jun 16, 2024
8ec7661
remove device from elemwise tests and add assertions
HarshvirSandhu Jun 17, 2024
bb7df41
skip tests if cuda is not available
HarshvirSandhu Jun 17, 2024
0441cf2
Fix tests
HarshvirSandhu Jun 18, 2024

19 changes: 4 additions & 15 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -1,4 +1,3 @@
import warnings
from functools import singledispatch

import torch
@@ -18,7 +17,9 @@ def pytorch_typify(data, dtype=None, **kwargs):
@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a PyTorch compatible function from an PyTensor `Op`."""
raise NotImplementedError(f"No PyTorch conversion for the given `Op`: {op}")
raise NotImplementedError(
f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
)


@pytorch_funcify.register(FunctionGraph)
@@ -51,21 +52,9 @@ def assert_fn(x, *conditions):
return assert_fn


def pytorch_safe_copy(x):
# Cannot use try-except due to: https://github.com/pytorch/pytorch/issues/93720

if hasattr(x, "clone"):
res = torch.clone(x)
else:
warnings.warn(f"Object has no `clone` method: {x}")
res = x

return res


@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return pytorch_safe_copy(x)
return x.clone()

return deepcopyop
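
`pytorch_funcify` is the single-dispatch extension point of this backend: each PyTensor `Op` gets a registered translation that returns a callable over torch Tensors, and unregistered ops now raise the `NotImplementedError` above, which points to the tracking issue. A minimal sketch of the registration pattern, using a toy `Square` Op invented here purely for illustration (it mirrors the `TestOp` used in the tests further down):

import torch

from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.pytorch.dispatch import pytorch_funcify


class Square(Op):
    """Toy Op defined only for this sketch."""

    def make_node(self, x):
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        # Pure-Python fallback; not used once the PyTorch backend handles the graph.
        output_storage[0][0] = inputs[0] ** 2


@pytorch_funcify.register(Square)
def pytorch_funcify_Square(op, **kwargs):
    # The returned callable receives torch Tensors (after pytorch_typify has run)
    # and must return torch Tensors.
    def square(x):
        return torch.square(x)

    return square

Because dispatch is keyed on the `Op` class, the backend can grow incrementally, one registration at a time.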
68 changes: 4 additions & 64 deletions pytensor/link/pytorch/linker.py
@@ -1,75 +1,22 @@
from typing import Any

from numpy.random import Generator, RandomState

from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker


class PytorchLinker(JITLinker):
"""A `Linker` that compiles NumPy-based operations using torch.compile."""

def input_filter(self, inp: Any) -> Any:
import torch

from pytensor.link.pytorch.dispatch import pytorch_typify

if isinstance(inp, torch.Tensor):
return inp
return pytorch_typify(inp)

def output_filter(self, var: Variable, out: Any) -> Any:
return out.cpu()

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.random.type import RandomType

shared_rng_inputs = [
inp
for inp in fgraph.inputs
if (isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType))
]

# Replace any shared RNG inputs so that their values can be updated in place
# without affecting the original RNG container. This is necessary because
# JAX does not accept RandomState/Generators as inputs, and they will have to
# be tipyfied
if shared_rng_inputs:
# warnings.warn(
# f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
# f"in the compiled JAX graph. Instead a copy will be used.",
# UserWarning,
# )
new_shared_rng_inputs = [
shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
]

fgraph.replace_all(
zip(shared_rng_inputs, new_shared_rng_inputs),
import_missing=True,
reason="PytorchLinker.fgraph_convert",
)

for old_inp, new_inp in zip(shared_rng_inputs, new_shared_rng_inputs):
new_inp_storage = [new_inp.get_value(borrow=True)]
storage_map[new_inp] = new_inp_storage
old_inp_storage = storage_map.pop(old_inp)
# Find index of old_inp_storage in input_storage
for input_storage_idx, input_storage_item in enumerate(input_storage):
# We have to establish equality based on identity because input_storage may contain numpy arrays
if input_storage_item is old_inp_storage:
break
else: # no break
raise ValueError()
input_storage[input_storage_idx] = new_inp_storage
# We need to change the order of the inputs of the FunctionGraph
# so that the new input is in the same position as to old one,
# to align with the storage_map. We hope this is safe!
old_inp_fgrap_index = fgraph.inputs.index(old_inp)
fgraph.remove_input(
old_inp_fgrap_index,
reason="PytorchLinker.fgraph_convert",
)
fgraph.inputs.remove(new_inp)
fgraph.inputs.insert(old_inp_fgrap_index, new_inp)

return pytorch_funcify(
fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
@@ -81,16 +28,9 @@ def jit_compile(self, fn):
return torch.compile(fn)

def create_thunk_inputs(self, storage_map):
from pytensor.link.pytorch.dispatch import pytorch_typify

thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState | Generator):
new_value = pytorch_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
)
sinput[0] = new_value
thunk_inputs.append(sinput)

return thunk_inputs
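
With the dispatcher and linker in place, `mode="PYTORCH"` selects `PytorchLinker`: `fgraph_convert` translates the `FunctionGraph` through `pytorch_funcify`, `jit_compile` wraps the result in `torch.compile`, and `output_filter` moves outputs back to the CPU. A minimal end-to-end sketch, assuming only the elemwise addition this PR already dispatches:

import numpy as np

from pytensor import config, function
from pytensor.tensor.type import vector

x = vector("x")
y = vector("y")

# "PYTORCH" mode routes compilation through PytorchLinker and torch.compile.
fn = function([x, y], x + y, mode="PYTORCH")

out = fn(
    np.arange(3, dtype=config.floatX),
    np.ones(3, dtype=config.floatX),
)
# output_filter has already called .cpu() on the result, so host-side checks are safe.
np.testing.assert_allclose(out, [1.0, 2.0, 3.0])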
152 changes: 82 additions & 70 deletions tests/link/pytorch/test_basic.py
@@ -72,104 +72,116 @@ def compare_pytorch_and_py(
return pytensor_torch_fn, pytorch_res


def test_pytorch_FunctionGraph_once():
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
"""Make sure that an output is only computed once when it's referenced multiple times."""
from pytensor.link.pytorch.dispatch import pytorch_funcify

x = vector("x")
y = vector("y")
with torch.device(device):
x = vector("x")
y = vector("y")

class TestOp(Op):
def __init__(self):
self.called = 0
class TestOp(Op):
def __init__(self):
self.called = 0

def make_node(self, *args):
return Apply(self, list(args), [x.type() for x in args])
def make_node(self, *args):
return Apply(self, list(args), [x.type() for x in args])

def perform(self, inputs, outputs):
for i, inp in enumerate(inputs):
outputs[i][0] = inp[0]
def perform(self, inputs, outputs):
for i, inp in enumerate(inputs):
outputs[i][0] = inp[0]

@pytorch_funcify.register(TestOp)
def pytorch_funcify_TestOp(op, **kwargs):
def func(*args, op=op):
op.called += 1
return list(args)
@pytorch_funcify.register(TestOp)
def pytorch_funcify_TestOp(op, **kwargs):
def func(*args, op=op):
op.called += 1
return list(args)

return func
return func

op1 = TestOp()
op2 = TestOp()
op1 = TestOp()
op2 = TestOp()

q, r = op1(x, y)
outs = op2(q + r, q + r)
q, r = op1(x, y)
outs = op2(q + r, q + r)

out_fg = FunctionGraph([x, y], outs, clone=False)
assert len(out_fg.outputs) == 2
out_fg = FunctionGraph([x, y], outs, clone=False)
assert len(out_fg.outputs) == 2

out_torch = pytorch_funcify(out_fg)
out_torch = pytorch_funcify(out_fg)

x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX))
y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX))
x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX))
y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX))

res = out_torch(x_val, y_val)
assert len(res) == 2
assert op1.called == 1
assert op2.called == 1
res = out_torch(x_val, y_val)
assert len(res) == 2
assert op1.called == 1
assert op2.called == 1

res = out_torch(x_val, y_val)
assert len(res) == 2
assert op1.called == 2
assert op2.called == 2
res = out_torch(x_val, y_val)
assert len(res) == 2
assert op1.called == 2
assert op2.called == 2


def test_shared():
a = shared(np.array([1, 2, 3], dtype=config.floatX))
pytensor_torch_fn = function([], a, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared(device):
with torch.device(device):
a = shared(np.array([1, 2, 3], dtype=config.floatX))
pytensor_torch_fn = function([], a, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()

assert isinstance(pytorch_res, torch.Tensor)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())
assert isinstance(pytorch_res, torch.Tensor)
assert isinstance(a.get_value(), np.ndarray)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())

pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()
pytensor_torch_fn = function([], a * 2, mode="PYTORCH")
pytorch_res = pytensor_torch_fn()

assert isinstance(pytorch_res, torch.Tensor)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)
assert isinstance(pytorch_res, torch.Tensor)
assert isinstance(a.get_value(), np.ndarray)
np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)

new_a_value = np.array([3, 4, 5], dtype=config.floatX)
a.set_value(new_a_value)
new_a_value = np.array([3, 4, 5], dtype=config.floatX)
a.set_value(new_a_value)

pytorch_res = pytensor_torch_fn()
assert isinstance(pytorch_res, torch.Tensor)
np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
pytorch_res = pytensor_torch_fn()
assert isinstance(pytorch_res, torch.Tensor)
np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)


@pytest.mark.xfail(reason="Shared variables will be handled in later PRs")
def test_shared_updates():
a = shared(0)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared_updates(device):
with torch.device(device):
a = shared(0)

pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH")
res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2
pytensor_torch_fn = function([], a, updates={a: a + 1}, mode="PYTORCH")
res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
assert res1 == 0
assert res2 == 1
assert a.get_value() == 2
assert isinstance(a.get_value(), np.ndarray)

a.set_value(5)
res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7
a.set_value(5)
res1, res2 = pytensor_torch_fn(), pytensor_torch_fn()
assert res1 == 5
assert res2 == 6
assert a.get_value() == 7
assert isinstance(a.get_value(), np.ndarray)


def test_pytorch_checkandraise():
check_and_raise = CheckAndRaise(AssertionError, "testing")
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_checkandraise(device):
with torch.device(device):
check_and_raise = CheckAndRaise(AssertionError, "testing")

x = scalar("x")
conds = (x > 0, x > 3)
y = check_and_raise(x, *conds)
x = scalar("x")
conds = (x > 0, x > 3)
y = check_and_raise(x, *conds)

y_fn = function([x], y, mode="PYTORCH")
y_fn = function([x], y, mode="PYTORCH")

with pytest.raises(AssertionError, match="testing"):
y_fn(0.0)
with pytest.raises(AssertionError, match="testing"):
y_fn(0.0)
assert y_fn(4).item() == 4
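
The `device` parametrization above leans on `torch.device` acting as a context manager (available in PyTorch 2.x), so tensors constructed inside the `with` block land on that device by default. On machines without a GPU, the CUDA cases need the guard added in the "skip tests if cuda is not available" commit; a standalone sketch of that pattern (the test name is illustrative):

import pytest
import torch


@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_default_device(device):
    # Mirror the suite's guard: skip CUDA cases when no GPU is present.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    with torch.device(device):
        t = torch.zeros(3)  # allocated on `device` via the context manager

    assert t.device.type == device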