diff --git a/.github/workflows/numpy.yml b/.github/workflows/numpy.yml index aed716f1..3581edee 100644 --- a/.github/workflows/numpy.yml +++ b/.github/workflows/numpy.yml @@ -34,8 +34,9 @@ jobs: array_api_tests/test_creation_functions.py::test_linspace # https://github.com/numpy/numpy/issues/20870 array_api_tests/test_data_type_functions.py::test_can_cast - # linalg tests generally need more mulling over - array_api_tests/test_linalg.py + # The return dtype for trace is not consistent in the spec + # (https://github.com/data-apis/array-api/issues/202#issuecomment-952529197) + array_api_tests/test_linalg.py::test_trace # waiting on NumPy to allow/revert distinct NaNs for np.unique # https://github.com/numpy/numpy/issues/20326#issuecomment-1012380448 array_api_tests/test_set_functions.py diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 2d8289d7..62c93562 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -17,8 +17,9 @@ from hypothesis import assume, given from hypothesis.strategies import (booleans, composite, none, tuples, integers, shared, sampled_from, data, just) +from ndindex import iter_indices -from .array_helpers import assert_exactly_equal, asarray, equal, zero, infinity +from .array_helpers import assert_exactly_equal, asarray from .hypothesis_helpers import (xps, dtypes, shapes, kwargs, matrix_shapes, square_matrix_shapes, symmetric_matrices, positive_definite_matrices, MAX_ARRAY_SIZE, @@ -43,29 +44,41 @@ # Standin strategy for not yet implemented tests todo = none() -def _test_stacks(f, *args, res=None, dims=2, true_val=None, **kw): +def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), + assert_equal=assert_exactly_equal, **kw): """ Test that f(*args, **kw) maps across stacks of matrices - dims is the number of dimensions f should have for a single n x m matrix - stack. + dims is the number of dimensions f(*args) should have for a single n x m + matrix stack. + + matrix_axes are the axes along which matrices (or vectors) are stacked in + the input. + + true_val may be a function such that true_val(*x_stacks, **kw) gives the + true value for f on a stack. + + res should be the result of f(*args, **kw). It is computed if not passed + in. - true_val may be a function such that true_val(*x_stacks) gives the true - value for f on a stack """ if res is None: res = f(*args, **kw) - shape = args[0].shape if len(args) == 1 else broadcast_shapes(*[x.shape - for x in args]) - for _idx in sh.ndindex(shape[:-2]): - idx = _idx + (slice(None),)*dims - res_stack = res[idx] - x_stacks = [x[_idx + (...,)] for x in args] + shapes = [x.shape for x in args] + + for (x_idxes, (res_idx,)) in zip( + iter_indices(*shapes, skip_axes=matrix_axes), + iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))): + x_idxes = [x_idx.raw for x_idx in x_idxes] + res_idx = res_idx.raw + + res_stack = res[res_idx] + x_stacks = [x[x_idx] for x, x_idx in zip(args, x_idxes)] decomp_res_stack = f(*x_stacks, **kw) - assert_exactly_equal(res_stack, decomp_res_stack) + assert_equal(res_stack, decomp_res_stack) if true_val: - assert_exactly_equal(decomp_res_stack, true_val(*x_stacks)) + assert_equal(decomp_res_stack, true_val(*x_stacks)) def _test_namedtuple(res, fields, func_name): """ @@ -452,10 +465,12 @@ def test_slogdet(x): # Check that when the determinant is 0, the sign and logabsdet are (0, # -inf). - d = linalg.det(x) - zero_det = equal(d, zero(d.shape, d.dtype)) - assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) - assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) + # TODO: This test does not necessarily hold exactly. Update it to test it + # approximately. + # d = linalg.det(x) + # zero_det = equal(d, zero(d.shape, d.dtype)) + # assert_exactly_equal(sign[zero_det], zero(sign[zero_det].shape, x.dtype)) + # assert_exactly_equal(logabsdet[zero_det], -infinity(logabsdet[zero_det].shape, x.dtype)) # More generally, det(x) should equal sign*exp(logabsdet), but this does # not hold exactly due to floating-point loss of precision. @@ -614,7 +629,7 @@ def true_trace(x_stack): @given( dtypes=mutually_promotable_dtypes(dtypes=dh.numeric_dtypes), - shape=shapes(), + shape=shapes(min_dims=1), data=data(), ) def test_vecdot(dtypes, shape, data): diff --git a/requirements.txt b/requirements.txt index 8773b0e9..95a49cfa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ pytest hypothesis>=6.31.1 +ndindex>=1.6 regex removestar