Skip to content

Commit c0833b9

Browse files
committed
Smoke testing for xp.asarray()
1 parent ca3ef30 commit c0833b9

File tree

4 files changed

+47
-35
lines changed

4 files changed

+47
-35
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
99
integers, just, lists, none, one_of,
1010
sampled_from, shared)
11-
1211
from . import _array_module as xp
1312
from . import dtype_helpers as dh
1413
from . import shape_helpers as sh
@@ -19,7 +18,7 @@
1918
from .algos import broadcast_shapes
2019
from .function_stubs import elementwise_functions
2120
from .pytest_helpers import nargs
22-
from .typing import Array, DataType, Shape
21+
from .typing import Array, DataType, Scalar, Shape
2322

2423
# Set this to True to not fail tests just because a dtype isn't implemented.
2524
# If no compatible dtype is implemented for a given test, the test will fail
@@ -431,3 +430,11 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
431430
axes_strats.append(integers(-ndim, ndim - 1))
432431
axes_strats.append(xps.valid_tuple_axes(ndim))
433432
return one_of(axes_strats)
433+
434+
435+
def scalar_objects(dtype: DataType, shape: Shape) -> SearchStrategy[List[Scalar]]:
436+
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
437+
size = math.prod(shape)
438+
return lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
439+
lambda l: sh.reshape(l, shape)
440+
)

array_api_tests/shape_helpers.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import math
12
from itertools import product
23
from typing import Iterator, List, Optional, Tuple, Union
34

4-
from .typing import Shape
5+
from .typing import Scalar, Shape
56

6-
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"]
7+
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
78

89

910
def normalise_axis(
@@ -57,3 +58,20 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
5758
idx = tuple(idx)
5859
indices.append(idx)
5960
yield list(indices)
61+
62+
63+
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
64+
"""Reshape a flat sequence"""
65+
if any(s == 0 for s in shape):
66+
raise ValueError(
67+
f"{shape=} contains 0-sided dimensions, "
68+
f"but that's not representable in lists"
69+
)
70+
if len(shape) == 0:
71+
assert len(flat_seq) == 1 # sanity check
72+
return flat_seq[0]
73+
elif len(shape) == 1:
74+
return flat_seq
75+
size = len(flat_seq)
76+
n = math.prod(shape[1:])
77+
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]

array_api_tests/test_array_object.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Sequence, Union, get_args
3+
from typing import get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -12,33 +12,13 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15-
from .typing import DataType, Param, Scalar, ScalarType, Shape
16-
17-
18-
def reshape(
19-
flat_seq: Sequence[Scalar], shape: Shape
20-
) -> Union[Scalar, Sequence[Scalar]]:
21-
"""Reshape a flat sequence"""
22-
if len(shape) == 0:
23-
assert len(flat_seq) == 1 # sanity check
24-
return flat_seq[0]
25-
elif len(shape) == 1:
26-
return flat_seq
27-
size = len(flat_seq)
28-
n = math.prod(shape[1:])
29-
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
15+
from .typing import DataType, Param, Scalar, ScalarType
3016

3117

3218
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
3319
def test_getitem(shape, data):
34-
size = math.prod(shape)
3520
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
36-
obj = data.draw(
37-
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
38-
lambda l: reshape(l, shape)
39-
),
40-
label="obj",
41-
)
21+
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
4222
x = xp.asarray(obj, dtype=dtype)
4323
note(f"{x=}")
4424
key = data.draw(xps.indices(shape=shape), label="key")
@@ -71,21 +51,15 @@ def test_getitem(shape, data):
7151
for i in idx:
7252
val = val[i]
7353
out_obj.append(val)
74-
out_obj = reshape(out_obj, out_shape)
54+
out_obj = sh.reshape(out_obj, out_shape)
7555
expected = xp.asarray(out_obj, dtype=dtype)
7656
ph.assert_array("__getitem__", out, expected)
7757

7858

7959
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
8060
def test_setitem(shape, data):
81-
size = math.prod(shape)
8261
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
83-
obj = data.draw(
84-
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
85-
lambda l: reshape(l, shape)
86-
),
87-
label="obj",
88-
)
62+
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
8963
x = xp.asarray(obj, dtype=dtype)
9064
note(f"{x=}")
9165
# TODO: test setting non-0d arrays

array_api_tests/test_creation_functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ def test_arange(dtype, data):
186186
), f"out[0]={out[0]}, but should be {_start} {f_func}"
187187

188188

189+
@given(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1), data=st.data())
190+
def test_asarray_scalars(dtype, shape, data):
191+
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
192+
kw = data.draw(
193+
hh.kwargs(dtype=st.sampled_from([None, dtype]), copy=st.none()), label="kw"
194+
)
195+
196+
xp.asarray(obj, **kw)
197+
198+
199+
# TODO: test asarray with arrays and copy (in a seperate method)
200+
201+
189202
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
190203
def test_empty(shape, kw):
191204
out = xp.empty(shape, **kw)

0 commit comments

Comments
 (0)