Skip to content

Commit 3626af3

Browse files
committed
hh.scalar_objects() strategy to refactor indexing test logic
1 parent ca3ef30 commit 3626af3

File tree

3 files changed

+44
-39
lines changed

3 files changed

+44
-39
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
import itertools
2+
import math
23
from functools import reduce
3-
from math import sqrt
44
from operator import mul
55
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union
66

77
from hypothesis import assume
8-
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
9-
integers, just, lists, none, one_of,
10-
sampled_from, shared)
8+
from hypothesis.strategies import (
9+
SearchStrategy,
10+
booleans,
11+
composite,
12+
floats,
13+
integers,
14+
just,
15+
lists,
16+
none,
17+
one_of,
18+
sampled_from,
19+
shared,
20+
)
1121

1222
from . import _array_module as xp
1323
from . import dtype_helpers as dh
@@ -19,7 +29,7 @@
1929
from .algos import broadcast_shapes
2030
from .function_stubs import elementwise_functions
2131
from .pytest_helpers import nargs
22-
from .typing import Array, DataType, Shape
32+
from .typing import Array, DataType, Scalar, Shape
2333

2434
# Set this to True to not fail tests just because a dtype isn't implemented.
2535
# If no compatible dtype is implemented for a given test, the test will fail
@@ -125,7 +135,7 @@ def mutually_promotable_dtypes(
125135
# Limit the total size of an array shape
126136
MAX_ARRAY_SIZE = 10000
127137
# Size to use for 2-dim arrays
128-
SQRT_MAX_ARRAY_SIZE = int(sqrt(MAX_ARRAY_SIZE))
138+
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
129139

130140
# np.prod and others have overflow and math.prod is Python 3.8+ only
131141
def prod(seq):
@@ -431,3 +441,11 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
431441
axes_strats.append(integers(-ndim, ndim - 1))
432442
axes_strats.append(xps.valid_tuple_axes(ndim))
433443
return one_of(axes_strats)
444+
445+
446+
def scalar_objects(dtype: DataType, shape: Shape) -> SearchStrategy[List[Scalar]]:
447+
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
448+
size = math.prod(shape)
449+
return lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
450+
lambda l: sh.reshape(l, shape)
451+
)

array_api_tests/shape_helpers.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import math
12
from itertools import product
23
from typing import Iterator, List, Optional, Tuple, Union
34

4-
from .typing import Shape
5+
from .typing import Scalar, Shape
56

6-
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"]
7+
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
78

89

910
def normalise_axis(
@@ -57,3 +58,15 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
5758
idx = tuple(idx)
5859
indices.append(idx)
5960
yield list(indices)
61+
62+
63+
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
64+
"""Reshape a flat sequence"""
65+
if len(shape) == 0:
66+
assert len(flat_seq) == 1 # sanity check
67+
return flat_seq[0]
68+
elif len(shape) == 1:
69+
return flat_seq
70+
size = len(flat_seq)
71+
n = math.prod(shape[1:])
72+
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]

array_api_tests/test_array_object.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Sequence, Union, get_args
3+
from typing import get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -12,33 +12,13 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15-
from .typing import DataType, Param, Scalar, ScalarType, Shape
16-
17-
18-
def reshape(
19-
flat_seq: Sequence[Scalar], shape: Shape
20-
) -> Union[Scalar, Sequence[Scalar]]:
21-
"""Reshape a flat sequence"""
22-
if len(shape) == 0:
23-
assert len(flat_seq) == 1 # sanity check
24-
return flat_seq[0]
25-
elif len(shape) == 1:
26-
return flat_seq
27-
size = len(flat_seq)
28-
n = math.prod(shape[1:])
29-
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
15+
from .typing import DataType, Param, Scalar, ScalarType
3016

3117

3218
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
3319
def test_getitem(shape, data):
34-
size = math.prod(shape)
3520
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
36-
obj = data.draw(
37-
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
38-
lambda l: reshape(l, shape)
39-
),
40-
label="obj",
41-
)
21+
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
4222
x = xp.asarray(obj, dtype=dtype)
4323
note(f"{x=}")
4424
key = data.draw(xps.indices(shape=shape), label="key")
@@ -71,21 +51,15 @@ def test_getitem(shape, data):
7151
for i in idx:
7252
val = val[i]
7353
out_obj.append(val)
74-
out_obj = reshape(out_obj, out_shape)
54+
out_obj = sh.reshape(out_obj, out_shape)
7555
expected = xp.asarray(out_obj, dtype=dtype)
7656
ph.assert_array("__getitem__", out, expected)
7757

7858

7959
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
8060
def test_setitem(shape, data):
81-
size = math.prod(shape)
8261
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
83-
obj = data.draw(
84-
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
85-
lambda l: reshape(l, shape)
86-
),
87-
label="obj",
88-
)
62+
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
8963
x = xp.asarray(obj, dtype=dtype)
9064
note(f"{x=}")
9165
# TODO: test setting non-0d arrays

0 commit comments

Comments
 (0)