From b2115b7395d2a517f94337a303a9b0491412931b Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Mar 2024 10:36:11 +0000 Subject: [PATCH 1/3] Return type hint `solve_args()` --- array_api_tests/test_linalg.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index 465f54be..c09fced7 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -20,11 +20,12 @@ from ndindex import iter_indices import itertools +from typing import Tuple from .array_helpers import assert_exactly_equal, asarray from .hypothesis_helpers import (arrays, all_floating_dtypes, xps, shapes, kwargs, matrix_shapes, square_matrix_shapes, - symmetric_matrices, + symmetric_matrices, SearchStrategy, positive_definite_matrices, MAX_ARRAY_SIZE, invertible_matrices, two_mutual_arrays, mutually_promotable_dtypes, one_d_shapes, @@ -35,6 +36,7 @@ from . import dtype_helpers as dh from . import pytest_helpers as ph from . import shape_helpers as sh +from .typing import Array from . import _array_module from . import _array_module as xp @@ -589,7 +591,7 @@ def test_slogdet(x): # TODO: Test this when we have tests for floating-point values. # assert all(abs(linalg.det(x) - sign*exp(logabsdet)) < eps) -def solve_args(): +def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: """ Strategy for the x1 and x2 arguments to test_solve() From 762f1e77861281cdfdd9f449e19df7096e2c875e Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Mon, 11 Mar 2024 10:45:13 +0000 Subject: [PATCH 2/3] Draw end shape relative to base shape --- array_api_tests/test_linalg.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_tests/test_linalg.py b/array_api_tests/test_linalg.py index c09fced7..a20792ae 100644 --- a/array_api_tests/test_linalg.py +++ b/array_api_tests/test_linalg.py @@ -19,6 +19,7 @@ data) from ndindex import iter_indices +import math import itertools from typing import Tuple @@ -610,8 +611,9 @@ def solve_args() -> Tuple[SearchStrategy[Array], SearchStrategy[Array]]: @composite def _x2_shapes(draw): - end = draw(integers(0, SQRT_MAX_ARRAY_SIZE)) - return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,) + base_shape = draw(stack_shapes)[1] + draw(x1).shape[-1:] + end = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(base_shape), 1))) + return base_shape + (end,) x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes()) x2 = arrays(shape=x2_shapes, dtype=mutual_dtypes.map(lambda pair: pair[1])) From 695c67e3f98217bb1b85141ab09929a879d3a2f0 Mon Sep 17 00:00:00 2001 From: Matthew Barber Date: Tue, 12 Mar 2024 10:37:59 +0000 Subject: [PATCH 3/3] Final shape not exceed max size in `hh.invertible_matrices()` --- array_api_tests/hypothesis_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/array_api_tests/hypothesis_helpers.py b/array_api_tests/hypothesis_helpers.py index 3033dac3..60c683c4 100644 --- a/array_api_tests/hypothesis_helpers.py +++ b/array_api_tests/hypothesis_helpers.py @@ -312,8 +312,8 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()): @composite def invertible_matrices(draw, dtypes=xps.floating_dtypes(), stack_shapes=shapes()): # For now, just generate stacks of diagonal matrices. - n = draw(integers(0, SQRT_MAX_ARRAY_SIZE),) stack_shape = draw(stack_shapes) + n = draw(integers(0, SQRT_MAX_ARRAY_SIZE // max(math.prod(stack_shape), 1)),) dtype = draw(dtypes) elements = one_of( from_dtype(dtype, min_value=0.5, allow_nan=False, allow_infinity=False),