diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 3033dac3..60c683c4 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -312,8 +312,8 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()): @composite def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()): # For now, just generate stacks of diagonal matrices. - n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),) stack_shape = draw(stack_shapes) + n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),) dtype = draw(dtypes) elements = one_of( from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False), diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 465f54be..a20792ae 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -19,12 +19,14 @@ data) from ndindex import iter_indices +import math import itertools +from typing import Tuple from .array_helpers import assert_exactly_equal, asarray from .hypothesis_helpers import (arrays, all_floating_dtypes, xps, shapes, kwargs, matrix_shapes, square_matrix_shapes, - symmetric_matrices, + symmetric_matrices, SearchStrategy, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes, one_d_shapes, @@ -35,6 +37,7 @@ from . import dtype_helpers as dh from . import pytest_helpers as ph from . import shape_helpers as sh +from .typing import Array from . import _array_module from . import _array_module as xp @@ -589,7 +592,7 @@ def test_slogdet(x): # TODO: Test this when we have tests for floating-point values. # assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps) -def solve_args(): +def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: """ Strategy for the x1 and x2 arguments to test_solve() @@ -608,8 +611,9 @@ def solve_args(): @composite def _x2_shapes(draw): - end = draw(integers(0, SQRT_MAX_ARRAY_SIZE)) - return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,) + base_shape = draw(stack_shapes)[1] + draw(x1).shape[-1:] + end = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(base_shape), 1))) + return base_shape + (end,) x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes()) x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1]))