12
12
required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25).
13
13
14
14
"""
15
- # TODO: test with complex dtypes where appropriate
16
-
17
15
import pytest
18
16
from hypothesis import assume , given
19
17
from hypothesis .strategies import (booleans , composite , tuples , floats ,
24
22
import itertools
25
23
26
24
from .array_helpers import assert_exactly_equal , asarray
27
- from .hypothesis_helpers import (arrays , xps , shapes , kwargs , matrix_shapes ,
28
- square_matrix_shapes , symmetric_matrices ,
25
+ from .hypothesis_helpers import (arrays , all_floating_dtypes , xps , shapes ,
26
+ kwargs , matrix_shapes , square_matrix_shapes ,
27
+ symmetric_matrices ,
29
28
positive_definite_matrices , MAX_ARRAY_SIZE ,
30
29
invertible_matrices , two_mutual_arrays ,
31
30
mutually_promotable_dtypes , one_d_shapes ,
@@ -210,7 +209,7 @@ def exact_cross(a, b):
210
209
211
210
@pytest .mark .xp_extension ('linalg' )
212
211
@given (
213
- x = arrays (dtype = xps . floating_dtypes (), shape = square_matrix_shapes ),
212
+ x = arrays (dtype = all_floating_dtypes (), shape = square_matrix_shapes ),
214
213
)
215
214
def test_det (x ):
216
215
res = linalg .det (x )
@@ -224,7 +223,7 @@ def test_det(x):
224
223
225
224
@pytest .mark .xp_extension ('linalg' )
226
225
@given (
227
- x = arrays (dtype = xps .real_dtypes (), shape = matrix_shapes ()),
226
+ x = arrays (dtype = xps .scalar_dtypes (), shape = matrix_shapes ()),
228
227
# offset may produce an overflow if it is too large. Supporting offsets
229
228
# that are way larger than the array shape isn't very important.
230
229
kw = kwargs (offset = integers (- MAX_ARRAY_SIZE , MAX_ARRAY_SIZE ))
@@ -382,7 +381,7 @@ def test_matrix_norm(x, kw):
382
381
@given (
383
382
# Generate any square matrix if n >= 0 but only invertible matrices if n < 0
384
383
x = matrix_power_n .flatmap (lambda n : invertible_matrices () if n < 0 else
385
- arrays (dtype = xps . floating_dtypes (),
384
+ arrays (dtype = all_floating_dtypes (),
386
385
shape = square_matrix_shapes )),
387
386
n = matrix_power_n ,
388
387
)
@@ -409,7 +408,7 @@ def test_matrix_rank(x, kw):
409
408
linalg .matrix_rank (x , ** kw )
410
409
411
410
@given (
412
- x = arrays (dtype = xps .real_dtypes (), shape = matrix_shapes ()),
411
+ x = arrays (dtype = xps .scalar_dtypes (), shape = matrix_shapes ()),
413
412
)
414
413
def test_matrix_transpose (x ):
415
414
res = _array_module .matrix_transpose (x )
@@ -459,7 +458,7 @@ def test_pinv(x, kw):
459
458
460
459
@pytest .mark .xp_extension ('linalg' )
461
460
@given (
462
- x = arrays (dtype = xps . floating_dtypes (), shape = matrix_shapes ()),
461
+ x = arrays (dtype = all_floating_dtypes (), shape = matrix_shapes ()),
463
462
kw = kwargs (mode = sampled_from (['reduced' , 'complete' ]))
464
463
)
465
464
def test_qr (x , kw ):
@@ -495,7 +494,7 @@ def test_qr(x, kw):
495
494
496
495
@pytest .mark .xp_extension ('linalg' )
497
496
@given (
498
- x = arrays (dtype = xps . floating_dtypes (), shape = square_matrix_shapes ),
497
+ x = arrays (dtype = all_floating_dtypes (), shape = square_matrix_shapes ),
499
498
)
500
499
def test_slogdet (x ):
501
500
res = linalg .slogdet (x )
@@ -504,11 +503,16 @@ def test_slogdet(x):
504
503
505
504
sign , logabsdet = res
506
505
507
- assert sign .dtype == x .dtype , "slogdet().sign did not return the correct dtype"
508
- assert sign .shape == x .shape [:- 2 ], "slogdet().sign did not return the correct shape"
509
- assert logabsdet .dtype == x .dtype , "slogdet().logabsdet did not return the correct dtype"
510
- assert logabsdet .shape == x .shape [:- 2 ], "slogdet().logabsdet did not return the correct shape"
511
-
506
+ ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = sign .dtype ,
507
+ expected = x .dtype , repr_name = "sign.dtype" )
508
+ ph .assert_shape ("slogdet" , out_shape = sign .shape , expected = x .shape [:- 2 ],
509
+ repr_name = "sign.shape" )
510
+ expected_dtype = dh .as_real_dtype (x .dtype )
511
+ ph .assert_dtype ("slogdet" , in_dtype = x .dtype , out_dtype = logabsdet .dtype ,
512
+ expected = expected_dtype , repr_name = "logabsdet.dtype" )
513
+ ph .assert_shape ("slogdet" , out_shape = logabsdet .shape ,
514
+ expected = x .shape [:- 2 ],
515
+ repr_name = "logabsdet.shape" )
512
516
513
517
_test_stacks (lambda x : linalg .slogdet (x ).sign , x ,
514
518
res = sign , dims = 0 )
@@ -550,7 +554,7 @@ def _x2_shapes(draw):
550
554
return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
551
555
552
556
x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
553
- x2 = arrays (dtype = xps . floating_dtypes (), shape = x2_shapes )
557
+ x2 = arrays (dtype = all_floating_dtypes (), shape = x2_shapes )
554
558
return x1 , x2
555
559
556
560
@pytest .mark .xp_extension ('linalg' )
@@ -734,7 +738,7 @@ def test_tensordot(x1, x2, kw):
734
738
735
739
@pytest .mark .xp_extension ('linalg' )
736
740
@given (
737
- x = arrays (dtype = xps .real_dtypes (), shape = matrix_shapes ()),
741
+ x = arrays (dtype = xps .numeric_dtypes (), shape = matrix_shapes ()),
738
742
# offset may produce an overflow if it is too large. Supporting offsets
739
743
# that are way larger than the array shape isn't very important.
740
744
kw = kwargs (offset = integers (- MAX_ARRAY_SIZE , MAX_ARRAY_SIZE ))
@@ -812,7 +816,7 @@ def true_val(x, y, axis=-1):
812
816
813
817
@pytest .mark .xp_extension ('linalg' )
814
818
@given (
815
- x = arrays (dtype = xps . floating_dtypes (), shape = shapes (min_side = 1 )),
819
+ x = arrays (dtype = all_floating_dtypes (), shape = shapes (min_side = 1 )),
816
820
data = data (),
817
821
)
818
822
def test_vector_norm (x , data ):
@@ -838,8 +842,9 @@ def test_vector_norm(x, data):
838
842
ph .assert_keepdimable_shape ('linalg.vector_norm' , out_shape = res .shape ,
839
843
in_shape = x .shape , axes = _axes ,
840
844
keepdims = keepdims , kw = kw )
845
+ expected_dtype = dh .as_real_dtype (x .dtype )
841
846
ph .assert_dtype ('linalg.vector_norm' , in_dtype = x .dtype ,
842
- out_dtype = res .dtype )
847
+ out_dtype = res .dtype , expected = expected_dtype )
843
848
844
849
_kw = kw .copy ()
845
850
_kw .pop ('axis' , None )
0 commit comments