Skip to content

Commit 5ceb81d

Browse files
committed
Update linalg tests to test complex dtypes
Also diagonal and matrix_transpose now test against all dtypes, since they have no dtype restrictions in the spec.
1 parent 012ca19 commit 5ceb81d

File tree

2 files changed

+34
-19
lines changed

2 files changed

+34
-19
lines changed

array_api_tests/dtype_helpers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,16 @@ class MinMax(NamedTuple):
231231
{"complex64": xp.float32, "complex128": xp.float64}
232232
)
233233

234+
def as_real_dtype(dtype):
235+
"""
236+
Return the corresponding real dtype for a given floating-point dtype.
237+
"""
238+
if dtype in real_float_dtypes:
239+
return dtype
240+
elif dtype_to_name[dtype] in complex_names:
241+
return dtype_components[dtype]
242+
else:
243+
raise ValueError("as_real_dtype requires a floating-point dtype")
234244

235245
if not hasattr(xp, "asarray"):
236246
default_int = xp.int32

array_api_tests/test_linalg.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
required, but we don't yet have a clean way to disable only those tests (see https://github.com/data-apis/array-api-tests/issues/25).
1313
1414
"""
15-
# TODO: test with complex dtypes where appropriate
16-
1715
import pytest
1816
from hypothesis import assume, given
1917
from hypothesis.strategies import (booleans, composite, tuples, floats,
@@ -24,8 +22,9 @@
2422
import itertools
2523

2624
from .array_helpers import assert_exactly_equal, asarray
27-
from .hypothesis_helpers import (arrays, xps, shapes, kwargs, matrix_shapes,
28-
square_matrix_shapes, symmetric_matrices,
25+
from .hypothesis_helpers import (arrays, all_floating_dtypes, xps, shapes,
26+
kwargs, matrix_shapes, square_matrix_shapes,
27+
symmetric_matrices,
2928
positive_definite_matrices, MAX_ARRAY_SIZE,
3029
invertible_matrices, two_mutual_arrays,
3130
mutually_promotable_dtypes, one_d_shapes,
@@ -210,7 +209,7 @@ def exact_cross(a, b):
210209

211210
@pytest.mark.xp_extension('linalg')
212211
@given(
213-
x=arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
212+
x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes),
214213
)
215214
def test_det(x):
216215
res = linalg.det(x)
@@ -224,7 +223,7 @@ def test_det(x):
224223

225224
@pytest.mark.xp_extension('linalg')
226225
@given(
227-
x=arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()),
226+
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
228227
# offset may produce an overflow if it is too large. Supporting offsets
229228
# that are way larger than the array shape isn't very important.
230229
kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
@@ -382,7 +381,7 @@ def test_matrix_norm(x, kw):
382381
@given(
383382
# Generate any square matrix if n >= 0 but only invertible matrices if n < 0
384383
x=matrix_power_n.flatmap(lambda n: invertible_matrices() if n < 0 else
385-
arrays(dtype=xps.floating_dtypes(),
384+
arrays(dtype=all_floating_dtypes(),
386385
shape=square_matrix_shapes)),
387386
n=matrix_power_n,
388387
)
@@ -409,7 +408,7 @@ def test_matrix_rank(x, kw):
409408
linalg.matrix_rank(x, **kw)
410409

411410
@given(
412-
x=arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()),
411+
x=arrays(dtype=xps.scalar_dtypes(), shape=matrix_shapes()),
413412
)
414413
def test_matrix_transpose(x):
415414
res = _array_module.matrix_transpose(x)
@@ -459,7 +458,7 @@ def test_pinv(x, kw):
459458

460459
@pytest.mark.xp_extension('linalg')
461460
@given(
462-
x=arrays(dtype=xps.floating_dtypes(), shape=matrix_shapes()),
461+
x=arrays(dtype=all_floating_dtypes(), shape=matrix_shapes()),
463462
kw=kwargs(mode=sampled_from(['reduced', 'complete']))
464463
)
465464
def test_qr(x, kw):
@@ -495,7 +494,7 @@ def test_qr(x, kw):
495494

496495
@pytest.mark.xp_extension('linalg')
497496
@given(
498-
x=arrays(dtype=xps.floating_dtypes(), shape=square_matrix_shapes),
497+
x=arrays(dtype=all_floating_dtypes(), shape=square_matrix_shapes),
499498
)
500499
def test_slogdet(x):
501500
res = linalg.slogdet(x)
@@ -504,11 +503,16 @@ def test_slogdet(x):
504503

505504
sign, logabsdet = res
506505

507-
assert sign.dtype == x.dtype, "slogdet().sign did not return the correct dtype"
508-
assert sign.shape == x.shape[:-2], "slogdet().sign did not return the correct shape"
509-
assert logabsdet.dtype == x.dtype, "slogdet().logabsdet did not return the correct dtype"
510-
assert logabsdet.shape == x.shape[:-2], "slogdet().logabsdet did not return the correct shape"
511-
506+
ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=sign.dtype,
507+
expected=x.dtype, repr_name="sign.dtype")
508+
ph.assert_shape("slogdet", out_shape=sign.shape, expected=x.shape[:-2],
509+
repr_name="sign.shape")
510+
expected_dtype = dh.as_real_dtype(x.dtype)
511+
ph.assert_dtype("slogdet", in_dtype=x.dtype, out_dtype=logabsdet.dtype,
512+
expected=expected_dtype, repr_name="logabsdet.dtype")
513+
ph.assert_shape("slogdet", out_shape=logabsdet.shape,
514+
expected=x.shape[:-2],
515+
repr_name="logabsdet.shape")
512516

513517
_test_stacks(lambda x: linalg.slogdet(x).sign, x,
514518
res=sign, dims=0)
@@ -550,7 +554,7 @@ def _x2_shapes(draw):
550554
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
551555

552556
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
553-
x2 = arrays(dtype=xps.floating_dtypes(), shape=x2_shapes)
557+
x2 = arrays(dtype=all_floating_dtypes(), shape=x2_shapes)
554558
return x1, x2
555559

556560
@pytest.mark.xp_extension('linalg')
@@ -734,7 +738,7 @@ def test_tensordot(x1, x2, kw):
734738

735739
@pytest.mark.xp_extension('linalg')
736740
@given(
737-
x=arrays(dtype=xps.real_dtypes(), shape=matrix_shapes()),
741+
x=arrays(dtype=xps.numeric_dtypes(), shape=matrix_shapes()),
738742
# offset may produce an overflow if it is too large. Supporting offsets
739743
# that are way larger than the array shape isn't very important.
740744
kw=kwargs(offset=integers(-MAX_ARRAY_SIZE, MAX_ARRAY_SIZE))
@@ -812,7 +816,7 @@ def true_val(x, y, axis=-1):
812816

813817
@pytest.mark.xp_extension('linalg')
814818
@given(
815-
x=arrays(dtype=xps.floating_dtypes(), shape=shapes(min_side=1)),
819+
x=arrays(dtype=all_floating_dtypes(), shape=shapes(min_side=1)),
816820
data=data(),
817821
)
818822
def test_vector_norm(x, data):
@@ -838,8 +842,9 @@ def test_vector_norm(x, data):
838842
ph.assert_keepdimable_shape('linalg.vector_norm', out_shape=res.shape,
839843
in_shape=x.shape, axes=_axes,
840844
keepdims=keepdims, kw=kw)
845+
expected_dtype = dh.as_real_dtype(x.dtype)
841846
ph.assert_dtype('linalg.vector_norm', in_dtype=x.dtype,
842-
out_dtype=res.dtype)
847+
out_dtype=res.dtype, expected=expected_dtype)
843848

844849
_kw = kw.copy()
845850
_kw.pop('axis', None)

0 commit comments

Comments
 (0)