Skip to content

Commit 2854c37

Browse files
committed
Modified tests of nlinalg in pytorch implementation
Replaced instances using Blockwise by the Op constructor.
1 parent 4e1391c commit 2854c37

File tree

2 files changed

+11
-20
lines changed

2 files changed

+11
-20
lines changed

pytensor/link/pytorch/dispatch/nlinalg.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ def pytorch_funcify_SVD(op, **kwargs):
2121

2222
def svd(x):
2323
U, S, V = torch.linalg.svd(x, full_matrices=full_matrices)
24-
return U, S, V if compute_uv else S
24+
if compute_uv:
25+
return U, S, V
26+
return S
2527

2628
return svd
2729

tests/link/pytorch/test_nlinalg.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,13 @@ def matrix_test():
2222

2323
@pytest.mark.parametrize(
2424
"func",
25-
(
26-
pt_nla.eig,
27-
pt_nla.eigh,
28-
pt_nla.slogdet,
29-
pytest.param(
30-
pt_nla.inv, marks=pytest.mark.xfail(reason="Blockwise not implemented")
31-
),
32-
pytest.param(
33-
pt_nla.det, marks=pytest.mark.xfail(reason="Blockwise not implemented")
34-
),
35-
),
25+
(pt_nla.eig, pt_nla.eigh, pt_nla.slogdet, pt_nla.MatrixInverse(), pt_nla.Det()),
3626
)
3727
def test_lin_alg_no_params(func, matrix_test):
3828
x, test_value = matrix_test
3929

40-
outs = func(x)
41-
out_fg = FunctionGraph([x], outs)
30+
out = func(x)
31+
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
4232

4333
def assert_fn(x, y):
4434
np.testing.assert_allclose(x, y, rtol=1e-3)
@@ -58,18 +48,17 @@ def assert_fn(x, y):
5848
def test_qr(mode, matrix_test):
5949
x, test_value = matrix_test
6050
outs = pt_nla.qr(x, mode=mode)
61-
out_fg = FunctionGraph([x], [outs] if mode == "r" else outs)
51+
out_fg = FunctionGraph([x], outs if isinstance(outs, list) else [outs])
6252
compare_pytorch_and_py(out_fg, [test_value])
6353

6454

65-
@pytest.mark.xfail(reason="Blockwise not implemented")
66-
@pytest.mark.parametrize("compute_uv", [False, True])
67-
@pytest.mark.parametrize("full_matrices", [False, True])
55+
@pytest.mark.parametrize("compute_uv", [True, False])
56+
@pytest.mark.parametrize("full_matrices", [True, False])
6857
def test_svd(compute_uv, full_matrices, matrix_test):
6958
x, test_value = matrix_test
7059

71-
outs = pt_nla.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
72-
out_fg = FunctionGraph([x], outs)
60+
out = pt_nla.SVD(full_matrices=full_matrices, compute_uv=compute_uv)(x)
61+
out_fg = FunctionGraph([x], out if isinstance(out, list) else [out])
7362

7463
def assert_fn(x, y):
7564
np.testing.assert_allclose(x, y, rtol=1e-3)

0 commit comments

Comments
 (0)