diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index a0adc8c9..9a7735c4 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -159,10 +159,21 @@ def matrix_shapes(draw, stack_shapes=shapes()): square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2]) -finite_matrices = xps.arrays(dtype=xps.floating_dtypes(), - shape=matrix_shapes(), - elements=dict(allow_nan=False, - allow_infinity=False)) +@composite +def finite_matrices(draw, shape=matrix_shapes()): + return draw(xps.arrays(dtype=xps.floating_dtypes(), + shape=shape, + elements=dict(allow_nan=False, + allow_infinity=False))) + +rtol_shared_matrix_shapes = shared(matrix_shapes()) +# Should we set a max_value here? +_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0) +rtols = one_of(floats(**_rtol_float_kw), + xps.arrays(dtype=xps.floating_dtypes(), + shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]), + elements=_rtol_float_kw)) + def mutually_broadcastable_shapes( num_shapes: int, diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 62c93562..cdf77ce7 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -16,7 +16,7 @@ import pytest from hypothesis import assume, given from hypothesis.strategies import (booleans, composite, none, tuples, integers, - shared, sampled_from, data, just) + shared, sampled_from, one_of, data, just) from ndindex import iter_indices from .array_helpers import assert_exactly_equal, asarray @@ -26,10 +26,10 @@ invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes, one_d_shapes, two_mutually_broadcastable_shapes, - SQRT_MAX_ARRAY_SIZE, finite_matrices) + SQRT_MAX_ARRAY_SIZE, finite_matrices, + rtol_shared_matrix_shapes, rtols) from . import dtype_helpers as dh from . import pytest_helpers as ph -from . import shape_helpers as sh from .algos import broadcast_shapes @@ -39,18 +39,17 @@ pytestmark = pytest.mark.ci - - # Standin strategy for not yet implemented tests todo = none() -def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1), +def _test_stacks(f, *args, res=None, dims=2, true_val=None, + matrix_axes=(-2, -1), assert_equal=assert_exactly_equal, **kw): """ Test that f(*args, **kw) maps across stacks of matrices - dims is the number of dimensions f(*args) should have for a single n x m - matrix stack. + dims is the number of dimensions f(*args, *kw) should have for a single n + x m matrix stack. matrix_axes are the axes along which matrices (or vectors) are stacked in the input. @@ -67,9 +66,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1) shapes = [x.shape for x in args] + # Assume the result is stacked along the last 'dims' axes of matrix_axes. + # This holds for all the functions tested in this file + res_axes = matrix_axes[::-1][:dims] + for (x_idxes, (res_idx,)) in zip( iter_indices(*shapes, skip_axes=matrix_axes), - iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))): + iter_indices(res.shape, skip_axes=res_axes)): x_idxes = [x_idx.raw for x_idx in x_idxes] res_idx = res_idx.raw @@ -161,26 +164,18 @@ def test_cross(x1_x2_kw): assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype" assert res.shape == shape, "cross() did not return the correct shape" - # cross is too different from other functions to use _test_stacks, and it - # is the only function that works the way it does, so it's not really - # worth generalizing _test_stacks to handle it. - a = axis if axis >= 0 else axis + len(shape) - for _idx in sh.ndindex(shape[:a] + shape[a+1:]): - idx = _idx[:a] + (slice(None),) + _idx[a:] - assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite." - res_stack = res[idx] - x1_stack = x1[idx] - x2_stack = x2[idx] - assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." - decomp_res_stack = linalg.cross(x1_stack, x2_stack) - assert_exactly_equal(res_stack, decomp_res_stack) - - exact_cross = asarray([ - x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1], - x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2], - x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0], - ], dtype=res.dtype) - assert_exactly_equal(res_stack, exact_cross) + def exact_cross(a, b): + assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite." + return asarray([ + a[1]*b[2] - a[2]*b[1], + a[2]*b[0] - a[0]*b[2], + a[0]*b[1] - a[1]*b[0], + ], dtype=res.dtype) + + # We don't want to pass in **kw here because that would pass axis to + # cross() on a single stack, but the axis is not meaningful on unstacked + # vectors. + _test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross) @pytest.mark.xp_extension('linalg') @given( @@ -314,14 +309,30 @@ def test_matmul(x1, x2): assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1]) _test_stacks(_array_module.matmul, x1, x2, res=res) +matrix_norm_shapes = shared(matrix_shapes()) + @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()), - kw=kwargs(axis=todo, keepdims=todo, ord=todo) + x=finite_matrices(), + kw=kwargs(keepdims=booleans(), + ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc'])) ) def test_matrix_norm(x, kw): - # res = linalg.matrix_norm(x, **kw) - pass + res = linalg.matrix_norm(x, **kw) + + keepdims = kw.get('keepdims', False) + # TODO: Check that the ord values give the correct norms. + # ord = kw.get('ord', 'fro') + + if keepdims: + expected_shape = x.shape[:-2] + (1, 1) + else: + expected_shape = x.shape[:-2] + assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape" + assert res.dtype == x.dtype, "matrix_norm() did not return the correct dtype" + + _test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0, + res=res) matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n') @pytest.mark.xp_extension('linalg') @@ -348,12 +359,11 @@ def test_matrix_power(x, n): @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()), - kw=kwargs(rtol=todo) + x=finite_matrices(shape=rtol_shared_matrix_shapes), + kw=kwargs(rtol=rtols) ) def test_matrix_rank(x, kw): - # res = linalg.matrix_rank(x, **kw) - pass + res = linalg.matrix_rank(x, **kw) @given( x=xps.arrays(dtype=dtypes, shape=matrix_shapes()), @@ -398,12 +408,11 @@ def test_outer(x1, x2): @pytest.mark.xp_extension('linalg') @given( - x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()), - kw=kwargs(rtol=todo) + x=finite_matrices(shape=rtol_shared_matrix_shapes), + kw=kwargs(rtol=rtols) ) def test_pinv(x, kw): - # res = linalg.pinv(x, **kw) - pass + res = linalg.pinv(x, **kw) @pytest.mark.xp_extension('linalg') @given( @@ -483,7 +492,7 @@ def solve_args(): Strategy for the x1 and x2 arguments to test_solve() solve() takes x1, x2, where x1 is any stack of square invertible matrices - of shape (..., M, M), and x2 is either shape (..., M) or (..., M, K), + of shape (..., M, M), and x2 is either shape (M,) or (..., M, K), where the ... parts of x1 and x2 are broadcast compatible. """ stack_shapes = shared(two_mutually_broadcastable_shapes) @@ -493,30 +502,22 @@ def solve_args(): pair[0]))) @composite - def x2_shapes(draw): - end = draw(xps.array_shapes(min_dims=0, max_dims=1, min_side=0, - max_side=SQRT_MAX_ARRAY_SIZE)) - return draw(stack_shapes)[1] + draw(x1).shape[-1:] + end + def _x2_shapes(draw): + end = draw(integers(0, SQRT_MAX_ARRAY_SIZE)) + return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,) - x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes()) + x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes()) + x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes) return x1, x2 @pytest.mark.xp_extension('linalg') @given(*solve_args()) def test_solve(x1, x2): - # TODO: solve() is currently ambiguous, in that some inputs can be - # interpreted in two different ways. For example, if x1 is shape (2, 2, 2) - # and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack - # of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after - # broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in - # (2, 2, 2, 2). - - # res = linalg.solve(x1, x2) - pass + res = linalg.solve(x1, x2) @pytest.mark.xp_extension('linalg') @given( - x=finite_matrices, + x=finite_matrices(), kw=kwargs(full_matrices=booleans()) ) def test_svd(x, kw): @@ -552,7 +553,7 @@ def test_svd(x, kw): @pytest.mark.xp_extension('linalg') @given( - x=finite_matrices, + x=finite_matrices(), ) def test_svdvals(x): res = linalg.svdvals(x)