13
13
14
14
"""
15
15
16
+ import pytest
16
17
from hypothesis import assume , given
17
18
from hypothesis .strategies import (booleans , composite , none , tuples , integers ,
18
19
shared , sampled_from )
33
34
from . import _array_module
34
35
from ._array_module import linalg
35
36
37
+
36
38
# Standin strategy for not yet implemented tests
37
39
todo = none ()
38
40
@@ -74,6 +76,7 @@ def _test_namedtuple(res, fields, func_name):
74
76
assert hasattr (res , field ), f"{ func_name } () result namedtuple doesn't have the '{ field } ' field"
75
77
assert res [i ] is getattr (res , field ), f"{ func_name } () result namedtuple '{ field } ' field is not in position { i } "
76
78
79
+ @pytest .mark .xp_extension ('linalg' )
77
80
@given (
78
81
x = positive_definite_matrices (),
79
82
kw = kwargs (upper = booleans ())
@@ -121,6 +124,7 @@ def cross_args(draw, dtype_objects=dh.numeric_dtypes):
121
124
)
122
125
return draw (arrays1 ), draw (arrays2 ), kw
123
126
127
+ @pytest .mark .xp_extension ('linalg' )
124
128
@given (
125
129
cross_args ()
126
130
)
@@ -159,6 +163,7 @@ def test_cross(x1_x2_kw):
159
163
], dtype = res .dtype )
160
164
assert_exactly_equal (res_stack , exact_cross )
161
165
166
+ @pytest .mark .xp_extension ('linalg' )
162
167
@given (
163
168
x = xps .arrays (dtype = xps .floating_dtypes (), shape = square_matrix_shapes ),
164
169
)
@@ -172,6 +177,7 @@ def test_det(x):
172
177
173
178
# TODO: Test that res actually corresponds to the determinant of x
174
179
180
+ @pytest .mark .xp_extension ('linalg' )
175
181
@given (
176
182
x = xps .arrays (dtype = dtypes , shape = matrix_shapes ),
177
183
# offset may produce an overflow if it is too large. Supporting offsets
@@ -206,6 +212,7 @@ def true_diag(x_stack):
206
212
207
213
_test_stacks (linalg .diagonal , x , ** kw , res = res , dims = 1 , true_val = true_diag )
208
214
215
+ @pytest .mark .xp_extension ('linalg' )
209
216
@given (x = symmetric_matrices (finite = True ))
210
217
def test_eigh (x ):
211
218
res = linalg .eigh (x )
@@ -229,6 +236,7 @@ def test_eigh(x):
229
236
# TODO: Test that res actually corresponds to the eigenvalues and
230
237
# eigenvectors of x
231
238
239
+ @pytest .mark .xp_extension ('linalg' )
232
240
@given (x = symmetric_matrices (finite = True ))
233
241
def test_eigvalsh (x ):
234
242
res = linalg .eigvalsh (x )
@@ -242,6 +250,7 @@ def test_eigvalsh(x):
242
250
243
251
# TODO: Test that res actually corresponds to the eigenvalues of x
244
252
253
+ @pytest .mark .xp_extension ('linalg' )
245
254
@given (x = invertible_matrices ())
246
255
def test_inv (x ):
247
256
res = linalg .inv (x )
@@ -286,6 +295,7 @@ def test_matmul(x1, x2):
286
295
assert res .shape == stack_shape + (x1 .shape [- 2 ], x2 .shape [- 1 ])
287
296
_test_stacks (_array_module .matmul , x1 , x2 , res = res )
288
297
298
+ @pytest .mark .xp_extension ('linalg' )
289
299
@given (
290
300
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
291
301
kw = kwargs (axis = todo , keepdims = todo , ord = todo )
@@ -295,6 +305,7 @@ def test_matrix_norm(x, kw):
295
305
pass
296
306
297
307
matrix_power_n = shared (integers (- 1000 , 1000 ), key = 'matrix_power n' )
308
+ @pytest .mark .xp_extension ('linalg' )
298
309
@given (
299
310
# Generate any square matrix if n >= 0 but only invertible matrices if n < 0
300
311
x = matrix_power_n .flatmap (lambda n : invertible_matrices () if n < 0 else
@@ -316,6 +327,7 @@ def test_matrix_power(x, n):
316
327
func = lambda x : linalg .matrix_power (x , n )
317
328
_test_stacks (func , x , res = res , true_val = true_val )
318
329
330
+ @pytest .mark .xp_extension ('linalg' )
319
331
@given (
320
332
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
321
333
kw = kwargs (rtol = todo )
@@ -341,6 +353,7 @@ def test_matrix_transpose(x):
341
353
342
354
_test_stacks (_array_module .matrix_transpose , x , res = res , true_val = true_val )
343
355
356
+ @pytest .mark .xp_extension ('linalg' )
344
357
@given (
345
358
* two_mutual_arrays (dtype_objs = dh .numeric_dtypes ,
346
359
two_shapes = tuples (one_d_shapes , one_d_shapes ))
@@ -364,6 +377,7 @@ def test_outer(x1, x2):
364
377
365
378
assert_exactly_equal (res , true_res )
366
379
380
+ @pytest .mark .xp_extension ('linalg' )
367
381
@given (
368
382
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
369
383
kw = kwargs (rtol = todo )
@@ -372,6 +386,7 @@ def test_pinv(x, kw):
372
386
# res = linalg.pinv(x, **kw)
373
387
pass
374
388
389
+ @pytest .mark .xp_extension ('linalg' )
375
390
@given (
376
391
x = xps .arrays (dtype = xps .floating_dtypes (), shape = matrix_shapes ),
377
392
kw = kwargs (mode = sampled_from (['reduced' , 'complete' ]))
@@ -407,6 +422,7 @@ def test_qr(x, kw):
407
422
# Check that r is upper-triangular.
408
423
assert_exactly_equal (r , _array_module .triu (r ))
409
424
425
+ @pytest .mark .xp_extension ('linalg' )
410
426
@given (
411
427
x = xps .arrays (dtype = xps .floating_dtypes (), shape = square_matrix_shapes ),
412
428
)
@@ -464,6 +480,7 @@ def x2_shapes(draw):
464
480
x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes ())
465
481
return x1 , x2
466
482
483
+ @pytest .mark .xp_extension ('linalg' )
467
484
@given (* solve_args ())
468
485
def test_solve (x1 , x2 ):
469
486
# TODO: solve() is currently ambiguous, in that some inputs can be
@@ -476,6 +493,7 @@ def test_solve(x1, x2):
476
493
# res = linalg.solve(x1, x2)
477
494
pass
478
495
496
+ @pytest .mark .xp_extension ('linalg' )
479
497
@given (
480
498
x = finite_matrices ,
481
499
kw = kwargs (full_matrices = booleans ())
@@ -503,6 +521,7 @@ def test_svd(x, kw):
503
521
assert u .shape == (* stack , M , K )
504
522
assert vh .shape == (* stack , K , N )
505
523
524
+ @pytest .mark .xp_extension ('linalg' )
506
525
@given (
507
526
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
508
527
)
@@ -519,6 +538,7 @@ def test_tensordot(x1, x2, kw):
519
538
# res = _array_module.tensordot(x1, x2, **kw)
520
539
pass
521
540
541
+ @pytest .mark .xp_extension ('linalg' )
522
542
@given (
523
543
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
524
544
kw = kwargs (offset = todo )
@@ -536,6 +556,7 @@ def test_vecdot(x1, x2, kw):
536
556
# res = _array_module.vecdot(x1, x2, **kw)
537
557
pass
538
558
559
+ @pytest .mark .xp_extension ('linalg' )
539
560
@given (
540
561
x = xps .arrays (dtype = xps .floating_dtypes (), shape = shapes ),
541
562
kw = kwargs (axis = todo , keepdims = todo , ord = todo )
0 commit comments