@@ -22,23 +22,13 @@ def matrix_test():
22
22
23
23
@pytest .mark .parametrize (
24
24
"func" ,
25
- (
26
- pt_nla .eig ,
27
- pt_nla .eigh ,
28
- pt_nla .slogdet ,
29
- pytest .param (
30
- pt_nla .inv , marks = pytest .mark .xfail (reason = "Blockwise not implemented" )
31
- ),
32
- pytest .param (
33
- pt_nla .det , marks = pytest .mark .xfail (reason = "Blockwise not implemented" )
34
- ),
35
- ),
25
+ (pt_nla .eig , pt_nla .eigh , pt_nla .slogdet , pt_nla .MatrixInverse (), pt_nla .Det ()),
36
26
)
37
27
def test_lin_alg_no_params (func , matrix_test ):
38
28
x , test_value = matrix_test
39
29
40
- outs = func (x )
41
- out_fg = FunctionGraph ([x ], outs )
30
+ out = func (x )
31
+ out_fg = FunctionGraph ([x ], out if isinstance ( out , list ) else [ out ] )
42
32
43
33
def assert_fn (x , y ):
44
34
np .testing .assert_allclose (x , y , rtol = 1e-3 )
@@ -58,18 +48,17 @@ def assert_fn(x, y):
58
48
def test_qr (mode , matrix_test ):
59
49
x , test_value = matrix_test
60
50
outs = pt_nla .qr (x , mode = mode )
61
- out_fg = FunctionGraph ([x ], [ outs ] if mode == "r" else outs )
51
+ out_fg = FunctionGraph ([x ], outs if isinstance ( outs , list ) else [ outs ] )
62
52
compare_pytorch_and_py (out_fg , [test_value ])
63
53
64
54
65
- @pytest .mark .xfail (reason = "Blockwise not implemented" )
66
- @pytest .mark .parametrize ("compute_uv" , [False , True ])
67
- @pytest .mark .parametrize ("full_matrices" , [False , True ])
55
+ @pytest .mark .parametrize ("compute_uv" , [True , False ])
56
+ @pytest .mark .parametrize ("full_matrices" , [True , False ])
68
57
def test_svd (compute_uv , full_matrices , matrix_test ):
69
58
x , test_value = matrix_test
70
59
71
- outs = pt_nla .svd ( x , full_matrices = full_matrices , compute_uv = compute_uv )
72
- out_fg = FunctionGraph ([x ], outs )
60
+ out = pt_nla .SVD ( full_matrices = full_matrices , compute_uv = compute_uv )( x )
61
+ out_fg = FunctionGraph ([x ], out if isinstance ( out , list ) else [ out ] )
73
62
74
63
def assert_fn (x , y ):
75
64
np .testing .assert_allclose (x , y , rtol = 1e-3 )
0 commit comments