Skip to content

Implement nlinalg Ops in PyTorch #920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.shape
import pytensor.link.pytorch.dispatch.sort

import pytensor.link.pytorch.dispatch.nlinalg
# isort: on
103 changes: 103 additions & 0 deletions pytensor/link/pytorch/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.nlinalg import (
SVD,
Det,
Eig,
Eigh,
KroneckerProduct,
MatrixInverse,
MatrixPinv,
QRFull,
SLogDet,
)


@pytorch_funcify.register(SVD)
def pytorch_funcify_SVD(op, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv

def svd(x):
U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)

Check warning on line 23 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L23

Added line #L23 was not covered by tests
if compute_uv:
return U, S, V
return S

Check warning on line 26 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L25-L26

Added lines #L25 - L26 were not covered by tests

return svd


@pytorch_funcify.register(Det)
def pytorch_funcify_Det(op, **kwargs):
def det(x):
return torch.linalg.det(x)

Check warning on line 34 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L34

Added line #L34 was not covered by tests

return det


@pytorch_funcify.register(SLogDet)
def pytorch_funcify_SLogDet(op, **kwargs):
def slogdet(x):
return torch.linalg.slogdet(x)

Check warning on line 42 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L42

Added line #L42 was not covered by tests

return slogdet


@pytorch_funcify.register(Eig)
def pytorch_funcify_Eig(op, **kwargs):
def eig(x):
return torch.linalg.eig(x)

Check warning on line 50 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L50

Added line #L50 was not covered by tests

return eig


@pytorch_funcify.register(Eigh)
def pytorch_funcify_Eigh(op, **kwargs):
uplo = op.UPLO

def eigh(x, uplo=uplo):
return torch.linalg.eigh(x, UPLO=uplo)

Check warning on line 60 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L60

Added line #L60 was not covered by tests

return eigh


@pytorch_funcify.register(MatrixInverse)
def pytorch_funcify_MatrixInverse(op, **kwargs):
def matrix_inverse(x):
return torch.linalg.inv(x)

Check warning on line 68 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L68

Added line #L68 was not covered by tests

return matrix_inverse


@pytorch_funcify.register(QRFull)
def pytorch_funcify_QRFull(op, **kwargs):
mode = op.mode
if mode == "raw":
raise NotImplementedError("raw mode not implemented in PyTorch")

def qr_full(x):
Q, R = torch.linalg.qr(x, mode=mode)

Check warning on line 80 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L80

Added line #L80 was not covered by tests
if mode == "r":
return R
return Q, R

Check warning on line 83 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L82-L83

Added lines #L82 - L83 were not covered by tests

return qr_full


@pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs):
hermitian = op.hermitian

def pinv(x):
return torch.linalg.pinv(x, hermitian=hermitian)

Check warning on line 93 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L93

Added line #L93 was not covered by tests

return pinv


@pytorch_funcify.register(KroneckerProduct)
def pytorch_funcify_KroneckerProduct(op, **kwargs):
def _kron(x, y):
return torch.kron(x, y)

Check warning on line 101 in pytensor/link/pytorch/dispatch/nlinalg.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/dispatch/nlinalg.py#L101

Added line #L101 was not covered by tests

return _kron
111 changes: 111 additions & 0 deletions tests/link/pytorch/test_nlinalg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.fixture
def matrix_test():
rng = np.random.default_rng(213234)

M = rng.normal(size=(3, 3))
test_value = M.dot(M.T).astype(config.floatX)

x = matrix("x")
return (x, test_value)


@pytest.mark.parametrize(
"func",
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
x, test_value = matrix_test

out = func(x)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

def assert_fn(x, y):
np.testing.assert_allclose(x, y, rtol=1e-3)

compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)


@pytest.mark.parametrize(
"mode",
(
"complete",
"reduced",
"r",
pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
),
)
def test_qr(mode, matrix_test):
x, test_value = matrix_test
outs = pt_nla.qr(x, mode=mode)
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
compare_pytorch_and_py(out_fg, [test_value])


@pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test):
x, test_value = matrix_test

out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

compare_pytorch_and_py(out_fg, [test_value])


def test_pinv():
x = matrix("x")
x_inv = pt_nla.pinv(x)

fgraph = FunctionGraph([x], [x_inv])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
compare_pytorch_and_py(fgraph, [x_np])


@pytest.mark.parametrize("hermitian", [False, True])
def test_pinv_hermitian(hermitian):
A = matrix("A", dtype="complex128")
A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
A_not_h_test = A_h_test + 0 + 1j

A_inv = pt_nla.pinv(A, hermitian=hermitian)
torch_fn = function([A], A_inv, mode="PYTORCH")

assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))

assert (
np.allclose(
torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
)
is not hermitian
)

assert (
np.allclose(
torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
)
is hermitian
)


def test_kron():
x = matrix("x")
y = matrix("y")
z = pt_nla.kron(x, y)

fgraph = FunctionGraph([x, y], [z])
x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

compare_pytorch_and_py(fgraph, [x_np, y_np])