Commit f075b50

Implemented nlinalg in PyTorch

Implemented Ops:
- Argmax
- Max
- Dot
- SVD
- Det
- SLogDet
- Eig
- Eigh
- KroneckerProduct
- MatrixInverse
- MatrixPinv
- QRFull
- BatchedDot

1 parent a6b9585 commit f075b50

File tree: 3 files changed, +344 −0 lines changed
- pytensor/link/pytorch/dispatch/__init__.py
- pytensor/link/pytorch/dispatch/nlinalg.py
- tests/link/pytorch/test_nlinalg.py
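
For orientation: with these dispatches in place, a graph containing the new Ops compiles with the PyTorch backend. A minimal usage sketch mirroring test_pinv from the new test file (mode="PYTORCH" as used there):

import numpy as np

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.type import matrix

x = matrix("x")
f = function([x], pt_nla.pinv(x), mode="PYTORCH")  # dispatched to torch.linalg.pinv

x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
np.testing.assert_allclose(f(x_np), np.linalg.pinv(x_np), rtol=1e-4)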

pytensor/link/pytorch/dispatch/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -6,4 +6,5 @@
 import pytensor.link.pytorch.dispatch.elemwise
 import pytensor.link.pytorch.dispatch.extra_ops
 import pytensor.link.pytorch.dispatch.sort
+import pytensor.link.pytorch.dispatch.nlinalg
 # isort: on
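
Adding this import is what activates the new conversions: pytorch_funcify follows the functools.singledispatch registration pattern, so each @pytorch_funcify.register(Op) decorator in nlinalg.py runs when the module is imported. A minimal standalone sketch of that pattern (MyOp and funcify are illustrative stand-ins, not PyTensor APIs):

from functools import singledispatch


@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"No PyTorch conversion for {op}")


class MyOp:
    """Hypothetical Op stand-in."""


@funcify.register(MyOp)
def funcify_MyOp(op, **kwargs):
    # the dispatch returns the callable that implements the Op
    def impl(x):
        return x

    return impl


print(funcify(MyOp())(42))  # 42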
pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 173 additions & 0 deletions

import torch

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Argmax, Dot, Max
from pytensor.tensor.nlinalg import (
    SVD,
    Det,
    Eig,
    Eigh,
    KroneckerProduct,
    MatrixInverse,
    MatrixPinv,
    QRFull,
    SLogDet,
)


@pytorch_funcify.register(SVD)
def pytorch_funcify_SVD(op, **kwargs):
    full_matrices = op.full_matrices
    compute_uv = op.compute_uv

    def svd(x):
        U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)
        return (U, S, V) if compute_uv else S

    return svd


@pytorch_funcify.register(Det)
def pytorch_funcify_Det(op, **kwargs):
    def det(x):
        return torch.linalg.det(x)

    return det


@pytorch_funcify.register(SLogDet)
def pytorch_funcify_SLogDet(op, **kwargs):
    def slogdet(x):
        return torch.linalg.slogdet(x)

    return slogdet


@pytorch_funcify.register(Eig)
def pytorch_funcify_Eig(op, **kwargs):
    def eig(x):
        return torch.linalg.eig(x)

    return eig


@pytorch_funcify.register(Eigh)
def pytorch_funcify_Eigh(op, **kwargs):
    uplo = op.UPLO

    def eigh(x, uplo=uplo):
        return torch.linalg.eigh(x, UPLO=uplo)

    return eigh


@pytorch_funcify.register(MatrixInverse)
def pytorch_funcify_MatrixInverse(op, **kwargs):
    def matrix_inverse(x):
        return torch.linalg.inv(x)

    return matrix_inverse


@pytorch_funcify.register(QRFull)
def pytorch_funcify_QRFull(op, **kwargs):
    mode = op.mode
    if mode == "raw":
        raise NotImplementedError("raw mode not implemented in PyTorch")

    def qr_full(x):
        Q, R = torch.linalg.qr(x, mode=mode)
        if mode == "r":
            return R
        return Q, R

    return qr_full


@pytorch_funcify.register(Dot)
def pytorch_funcify_Dot(op, **kwargs):
    def dot(x, y):
        return torch.dot(x, y)

    return dot


@pytorch_funcify.register(MatrixPinv)
def pytorch_funcify_Pinv(op, **kwargs):
    hermitian = op.hermitian

    def pinv(x):
        return torch.linalg.pinv(x, hermitian=hermitian)

    return pinv


@pytorch_funcify.register(BatchedDot)
def pytorch_funcify_BatchedDot(op, **kwargs):
    def batched_dot(a, b):
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match in the 0-th dimension")
        return torch.matmul(a, b)

    return batched_dot


@pytorch_funcify.register(KroneckerProduct)
def pytorch_funcify_KroneckerProduct(op, **kwargs):
    def _kron(x, y):
        return torch.kron(x, y)

    return _kron


@pytorch_funcify.register(Max)
def pytorch_funcify_Max(op, **kwargs):
    axis = op.axis

    def max(x):
        if axis is None:
            max_res = torch.max(x.flatten())
            return max_res

        # PyTorch doesn't support multiple axes for max;
        # this is a work-around
        axes = [int(ax) for ax in axis]

        new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item()
        keep_axes = [i for i in range(x.ndim) if i not in axes]
        permute_order = keep_axes + axes
        permuted_x = x.permute(*permute_order)
        kept_shape = permuted_x.shape[: len(keep_axes)]

        new_shape = (*kept_shape, new_dim)
        reshaped_x = permuted_x.reshape(new_shape)
        max_res, _ = torch.max(reshaped_x, dim=-1)
        return max_res

    return max


@pytorch_funcify.register(Argmax)
def pytorch_funcify_Argmax(op, **kwargs):
    axis = op.axis

    def argmax(x):
        if axis is None:
            return torch.argmax(x.view(-1))

        # PyTorch doesn't support multiple axes for argmax;
        # this is a work-around
        axes = [int(ax) for ax in axis]

        new_dim = torch.prod(torch.tensor([x.size(ax) for ax in axes])).item()
        keep_axes = [i for i in range(x.ndim) if i not in axes]
        permute_order = keep_axes + axes
        permuted_x = x.permute(*permute_order)
        kept_shape = permuted_x.shape[: len(keep_axes)]

        new_shape = (*kept_shape, new_dim)
        reshaped_x = permuted_x.reshape(new_shape)
        return torch.argmax(reshaped_x, dim=-1)

    return argmax
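
Both reducers above share the same work-around: permute the kept axes to the front, fuse all reduced axes into one trailing dimension, then reduce once over dim=-1. A standalone sketch of the trick (shapes and values chosen here purely for illustration):

import numpy as np
import torch

x = torch.arange(24.0).reshape(4, 3, 2)
axes = [0, 2]  # axes to reduce, as op.axis would supply
keep_axes = [i for i in range(x.ndim) if i not in axes]

permuted = x.permute(*(keep_axes + axes))        # kept dims first, reduced dims last
collapsed = permuted.reshape(*permuted.shape[: len(keep_axes)], -1)
values, flat_idx = torch.max(collapsed, dim=-1)  # one reduction over the fused axis

# matches a genuine multi-axis reduction
assert np.array_equal(values.numpy(), x.numpy().max(axis=(0, 2)))

Here flat_idx indexes into the fused axis, matching what the Argmax dispatcher returns for a multi-axis reduction.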

tests/link/pytorch/test_nlinalg.py

Lines changed: 170 additions & 0 deletions
import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import blas as pt_blas
from pytensor.tensor import nlinalg as pt_nla
from pytensor.tensor.math import argmax, dot, max
from pytensor.tensor.type import matrix, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.fixture
def matrix_test():
    rng = np.random.default_rng(213234)

    M = rng.normal(size=(3, 3))
    test_value = M.dot(M.T).astype(config.floatX)

    x = matrix("x")
    return (x, test_value)


def test_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    )
    b = tensor3("b")
    b.tag.test_value = (
        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
    )
    out = pt_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    inputs = [get_test_value(a)[:-1], get_test_value(b)]
    pytensor_torch_fn = function(fgraph.inputs, fgraph.outputs, mode="PYTORCH")
    with pytest.raises(TypeError):
        pytensor_torch_fn(*inputs)


@pytest.mark.parametrize(
    "func",
    (
        pt_nla.eig,
        pt_nla.eigh,
        pt_nla.slogdet,
        pytest.param(
            pt_nla.inv, marks=pytest.mark.xfail(reason="Blockwise not implemented")
        ),
        pytest.param(
            pt_nla.det, marks=pytest.mark.xfail(reason="Blockwise not implemented")
        ),
    ),
)
def test_lin_alg_no_params(func, matrix_test):
    x, test_value = matrix_test

    outs = func(x)
    out_fg = FunctionGraph([x], outs)

    def assert_fn(x, y):
        np.testing.assert_allclose(x, y, rtol=1e-3)

    compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)


@pytest.mark.parametrize(
    "mode",
    (
        "complete",
        "reduced",
        "r",
        pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
    ),
)
def test_qr(mode, matrix_test):
    x, test_value = matrix_test
    outs = pt_nla.qr(x, mode=mode)
    out_fg = FunctionGraph([x], [outs] if mode == "r" else outs)
    compare_pytorch_and_py(out_fg, [test_value])


@pytest.mark.xfail(reason="Blockwise not implemented")
@pytest.mark.parametrize("compute_uv", [False, True])
@pytest.mark.parametrize("full_matrices", [False, True])
def test_svd(compute_uv, full_matrices, matrix_test):
    x, test_value = matrix_test

    outs = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
    out_fg = FunctionGraph([x], outs)

    def assert_fn(x, y):
        np.testing.assert_allclose(x, y, rtol=1e-3)

    compare_pytorch_and_py(out_fg, [test_value], assert_fn=assert_fn)


def test_pinv():
    x = matrix("x")
    x_inv = pt_nla.pinv(x)

    fgraph = FunctionGraph([x], [x_inv])
    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    compare_pytorch_and_py(fgraph, [x_np])


@pytest.mark.parametrize("hermitian", [False, True])
def test_pinv_hermitian(hermitian):
    A = matrix("A", dtype="complex128")
    A_h_test = np.c_[[3, 3 + 2j], [3 - 2j, 2]]
    A_not_h_test = A_h_test + 0 + 1j

    A_inv = pt_nla.pinv(A, hermitian=hermitian)
    torch_fn = function([A], A_inv, mode="PYTORCH")

    assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=False))
    assert np.allclose(torch_fn(A_h_test), np.linalg.pinv(A_h_test, hermitian=True))

    assert (
        np.allclose(
            torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=False)
        )
        is not hermitian
    )

    assert (
        np.allclose(
            torch_fn(A_not_h_test), np.linalg.pinv(A_not_h_test, hermitian=True)
        )
        is hermitian
    )


def test_kron():
    x = matrix("x")
    y = matrix("y")
    z = pt_nla.kron(x, y)

    fgraph = FunctionGraph([x, y], [z])
    x_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
    y_np = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)

    compare_pytorch_and_py(fgraph, [x_np, y_np])


@pytest.mark.parametrize("func", (max, argmax))
@pytest.mark.parametrize("axis", [None, [0], [0, 1], [0, 2], [0, 1, 2]])
def test_max_and_argmax(func, axis):
    x = tensor3("x")
    np.random.seed(42)
    test_value = np.random.randint(0, 20, (4, 3, 2))

    out = func(x, axis=axis)
    out_fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(out_fg, [test_value])


def test_dot():
    x = vector("x")
    test_value = np.array([1, 2, 3])

    out = dot(x, x)
    out_fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(out_fg, [test_value])
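
Every test here follows the same pattern: build a FunctionGraph, then let compare_pytorch_and_py compile it with the PYTORCH mode and the reference Python backend and assert the outputs agree. Assuming a development checkout with torch and pytest installed, the file can be run in isolation:

pytest tests/link/pytorch/test_nlinalg.py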
