diff --git a/pytensor/link/__init__.py b/pytensor/link/__init__.py index e69de29bb2..c8c236a854 100644 --- a/pytensor/link/__init__.py +++ b/pytensor/link/__init__.py @@ -0,0 +1 @@ +from pytensor.link.pytorch.linker import PytorchLinker diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 017e57df64..fa47908d74 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -2,9 +2,12 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify # # Load dispatch specializations +import pytensor.link.pytorch.dispatch.blas import pytensor.link.pytorch.dispatch.scalar import pytensor.link.pytorch.dispatch.elemwise +import pytensor.link.pytorch.dispatch.math import pytensor.link.pytorch.dispatch.extra_ops -import pytensor.link.pytorch.dispatch.sort import pytensor.link.pytorch.dispatch.shape +import pytensor.link.pytorch.dispatch.sort + # isort: on diff --git a/pytensor/link/pytorch/dispatch/blas.py b/pytensor/link/pytorch/dispatch/blas.py new file mode 100644 index 0000000000..5691551998 --- /dev/null +++ b/pytensor/link/pytorch/dispatch/blas.py @@ -0,0 +1,14 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.blas import BatchedDot + + +@pytorch_funcify.register(BatchedDot) +def pytorch_funcify_BatchedDot(op, **kwargs): + def batched_dot(a, b): + if a.shape[0] != b.shape[0]: + raise TypeError("Shapes must match in the 0-th dimension") + return torch.bmm(a, b) + + return batched_dot diff --git a/pytensor/link/pytorch/dispatch/math.py b/pytensor/link/pytorch/dispatch/math.py new file mode 100644 index 0000000000..4275424f0a --- /dev/null +++ b/pytensor/link/pytorch/dispatch/math.py @@ -0,0 +1,12 @@ +import torch + +from pytensor.link.pytorch.dispatch import pytorch_funcify +from pytensor.tensor.math import Dot + + +@pytorch_funcify.register(Dot) +def pytorch_funcify_Dot(op, **kwargs): + def dot(x, y): + return torch.matmul(x, y) + + return dot diff --git a/tests/link/pytorch/test_blas.py b/tests/link/pytorch/test_blas.py new file mode 100644 index 0000000000..35f7dd7b6a --- /dev/null +++ b/tests/link/pytorch/test_blas.py @@ -0,0 +1,24 @@ +import numpy as np +import pytest + +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor import blas as pt_blas +from pytensor.tensor.type import tensor3 +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_BatchedDot(): + # tensor3 . tensor3 + a = tensor3("a") + a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) + b = tensor3("b") + b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) + out = pt_blas.BatchedDot()(a, b) + fgraph = FunctionGraph([a, b], [out]) + pytensor_pytorch_fn, _ = compare_pytorch_and_py(fgraph, [a_test, b_test]) + + # A dimension mismatch should raise a TypeError for compatibility + inputs = [a_test[:-1], b_test] + with pytest.raises(TypeError): + pytensor_pytorch_fn(*inputs) diff --git a/tests/link/pytorch/test_math.py b/tests/link/pytorch/test_math.py new file mode 100644 index 0000000000..affca4ad32 --- /dev/null +++ b/tests/link/pytorch/test_math.py @@ -0,0 +1,29 @@ +import numpy as np + +from pytensor.configdefaults import config +from pytensor.graph.fg import FunctionGraph +from pytensor.tensor.type import matrix, scalar, vector +from tests.link.pytorch.test_basic import compare_pytorch_and_py + + +def test_pytorch_dot(): + y = vector("y") + y_test = np.r_[1.0, 2.0].astype(config.floatX) + x = vector("x") + x_test = np.r_[3.0, 4.0].astype(config.floatX) + A = matrix("A") + A_test = np.array([[6, 3], [3, 0]], dtype=config.floatX) + alpha = scalar("alpha") + alpha_test = np.array(3.0, dtype=config.floatX) + beta = scalar("beta") + beta_test = np.array(5.0, dtype=config.floatX) + + # 2D * 2D + out = A.dot(A * alpha) + beta * A + fgraph = FunctionGraph([A, alpha, beta], [out]) + compare_pytorch_and_py(fgraph, [A_test, alpha_test, beta_test]) + + # 1D * 2D and 1D * 1D + out = y.dot(alpha * A).dot(x) + beta * y + fgraph = FunctionGraph([y, x, A, alpha, beta], [out]) + compare_pytorch_and_py(fgraph, [y_test, x_test, A_test, alpha_test, beta_test])