Commit 58fec45

Implement nlinalg Ops in PyTorch (#920)

1 parent 367351f
File tree: 3 files changed, +215 -1 lines changed
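For orientation, the new dispatches let a PyTensor graph built from `pytensor.tensor.nlinalg` Ops be compiled with the PyTorch backend. A minimal sketch of that usage, mirroring the pattern in the tests added below (the `mode="PYTORCH"` compile mode and the `floatX` cast are taken from the test file; the rest is standard PyTensor API):

import numpy as np

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix

x = matrix("x")
x_inv = pt_nla.inv(x)  # lowered to torch.linalg.inv by the dispatch added in this commit

# mode="PYTORCH" selects the PyTorch backend, as in the tests below.
fn = function([x], x_inv, mode="PYTORCH")
print(fn(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)))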

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -9,5 +9,5 @@
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.shape
 import pytensor.link.pytorch.dispatch.sort
-
+import pytensor.link.pytorch.dispatch.nlinalg
 # isort: on

pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.nlinalg import (
    SVD,
    Det,
    Eig,
    Eigh,
    KroneckerProduct,
    MatrixInverse,
    MatrixPinv,
    QRFull,
    SLogDet,
)


@pytorch_funcify.register(SVD)
def pytorch_funcify_SVD(op, **kwargs):
    full_matrices = op.full_matrices
    compute_uv = op.compute_uv

    def svd(x):
        U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)
        if compute_uv:
            return U, S, V
        return S

    return svd


@pytorch_funcify.register(Det)
def pytorch_funcify_Det(op, **kwargs):
    def det(x):
        return torch.linalg.det(x)

    return det


@pytorch_funcify.register(SLogDet)
def pytorch_funcify_SLogDet(op, **kwargs):
    def slogdet(x):
        return torch.linalg.slogdet(x)

    return slogdet


@pytorch_funcify.register(Eig)
def pytorch_funcify_Eig(op, **kwargs):
    def eig(x):
        return torch.linalg.eig(x)

    return eig


@pytorch_funcify.register(Eigh)
def pytorch_funcify_Eigh(op, **kwargs):
    uplo = op.UPLO

    def eigh(x, uplo=uplo):
        return torch.linalg.eigh(x, UPLO=uplo)

    return eigh


@pytorch_funcify.register(MatrixInverse)
def pytorch_funcify_MatrixInverse(op, **kwargs):
    def matrix_inverse(x):
        return torch.linalg.inv(x)

    return matrix_inverse


@pytorch_funcify.register(QRFull)
def pytorch_funcify_QRFull(op, **kwargs):
    mode = op.mode
    if mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")

    def qr_full(x):
        Q, R = torch.linalg.qr(x, mode=mode)
        if mode == "r":
            return R
        return Q, R

    return qr_full


@pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs):
    hermitian = op.hermitian

    def pinv(x):
        return torch.linalg.pinv(x, hermitian=hermitian)

    return pinv


@pytorch_funcify.register(KroneckerProduct)
def pytorch_funcify_KroneckerProduct(op, **kwargs):
    def _kron(x, y):
        return torch.kron(x, y)

    return _kron
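Each registration above turns an Op instance into a plain callable over torch tensors. A minimal sketch of that contract, calling the dispatch by hand (constructing a bare `Det()` outside a graph is for illustration only; in normal use the PyTorch linker invokes `pytorch_funcify` while compiling a `FunctionGraph`):

import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.nlinalg import Det

# Dispatch on the Op type; the result is a plain function of torch tensors.
det_fn = pytorch_funcify(Det())
print(det_fn(torch.eye(3)))  # tensor(1.)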

tests/link/pytorch/test_nlinalg.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.fixture
def matrix_test():
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")
    return (x, test_value)


@pytest.mark.parametrize(
    "func",
    (pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.inv, pt_nla.det),
)
def test_lin_alg_no_params(func, matrix_test):
    x, test_value = matrix_test

    out = func(x)
    out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

    def assert_fn(x, y):
        np.testing.assert_allclose(x, y, rtol=1e-3)

    compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)


@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    x, test_value = matrix_test
    outs = pt_nla.qr(x, mode=mode)
    out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
    compare_pytorch_and_py(out_fg, [test_value])


@pytest.mark.parametrize("compute_uv", [True, False])
@pytest.mark.parametrize("full_matrices", [True, False])
def test_svd(compute_uv, full_matrices, matrix_test):
    x, test_value = matrix_test

    out = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
    out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])

    compare_pytorch_and_py(out_fg, [test_value])


def test_pinv():
    x = matrix("x")
    x_inv = pt_nla.pinv(x)

    fgraph = FunctionGraph([x], [x_inv])
    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    compare_pytorch_and_py(fgraph, [x_np])


@pytest.mark.parametrize("hermitian", [False, True])
def test_pinv_hermitian(hermitian):
    A = matrix("A", dtype="complex128")
    A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
    A_not_h_test = A_h_test + 0 + 1j

    A_inv = pt_nla.pinv(A, hermitian=hermitian)
    torch_fn = function([A], A_inv, mode="PYTORCH")

    assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
    assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))

    assert (
        np.allclose(
            torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
        )
        is not hermitian
    )

    assert (
        np.allclose(
            torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
        )
        is hermitian
    )


def test_kron():
    x = matrix("x")
    y = matrix("y")
    z = pt_nla.kron(x, y)

    fgraph = FunctionGraph([x, y], [z])
    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

    compare_pytorch_and_py(fgraph, [x_np, y_np])
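A note on `test_pinv_hermitian`: `pinv(..., hermitian=True)` is free to assume the input equals its conjugate transpose, so on a genuinely non-Hermitian matrix it generally does not return the true pseudo-inverse; the `is hermitian` / `is not hermitian` assertions above rely on exactly that. A NumPy-only sketch of the idea (independent of PyTensor and PyTorch):

import numpy as np

# The same non-Hermitian matrix the test constructs.
A = np.c_[[3, 3 + 2j], [3 - 2j, 2]] + 0 + 1j

exact = np.linalg.pinv(A, hermitian=False)    # true pseudo-inverse
shortcut = np.linalg.pinv(A, hermitian=True)  # only valid if A were Hermitian

# For this non-Hermitian A the two disagree, which is what the test's
# truth-value comparisons against the `hermitian` flag depend on.
assert not np.allclose(exact, shortcut)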
