Skip to content

Commit dfca243

Browse files
authored
Merge pull request #79 from honno/asarray
`xp.asarray()` testing
2 parents ca3ef30 + 4750977 commit dfca243

File tree

4 files changed

+131
-32
lines changed

4 files changed

+131
-32
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def assert_fill(
231231

232232

233233
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
234-
assert_dtype(func_name, out.dtype, expected.dtype, **kw)
234+
assert_dtype(func_name, out.dtype, expected.dtype)
235235
assert_shape(func_name, out.shape, expected.shape, **kw)
236236
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
237237
if dh.is_float_dtype(out.dtype):

array_api_tests/shape_helpers.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import math
12
from itertools import product
23
from typing import Iterator, List, Optional, Tuple, Union
34

4-
from .typing import Shape
5+
from .typing import Scalar, Shape
56

6-
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"]
7+
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]
78

89

910
def normalise_axis(
@@ -57,3 +58,20 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
5758
idx = tuple(idx)
5859
indices.append(idx)
5960
yield list(indices)
61+
62+
63+
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
64+
"""Reshape a flat sequence"""
65+
if any(s == 0 for s in shape):
66+
raise ValueError(
67+
f"{shape=} contains 0-sided dimensions, "
68+
f"but that's not representable in lists"
69+
)
70+
if len(shape) == 0:
71+
assert len(flat_seq) == 1 # sanity check
72+
return flat_seq[0]
73+
elif len(shape) == 1:
74+
return flat_seq
75+
size = len(flat_seq)
76+
n = math.prod(shape[1:])
77+
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]

array_api_tests/test_array_object.py

Lines changed: 10 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import Sequence, Union, get_args
3+
from typing import List, get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -15,30 +15,18 @@
1515
from .typing import DataType, Param, Scalar, ScalarType, Shape
1616

1717

18-
def reshape(
19-
flat_seq: Sequence[Scalar], shape: Shape
20-
) -> Union[Scalar, Sequence[Scalar]]:
21-
"""Reshape a flat sequence"""
22-
if len(shape) == 0:
23-
assert len(flat_seq) == 1 # sanity check
24-
return flat_seq[0]
25-
elif len(shape) == 1:
26-
return flat_seq
27-
size = len(flat_seq)
28-
n = math.prod(shape[1:])
29-
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
18+
def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
19+
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
20+
size = math.prod(shape)
21+
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
22+
lambda l: sh.reshape(l, shape)
23+
)
3024

3125

3226
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
3327
def test_getitem(shape, data):
34-
size = math.prod(shape)
3528
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
36-
obj = data.draw(
37-
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
38-
lambda l: reshape(l, shape)
39-
),
40-
label="obj",
41-
)
29+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
4230
x = xp.asarray(obj, dtype=dtype)
4331
note(f"{x=}")
4432
key = data.draw(xps.indices(shape=shape), label="key")
@@ -71,21 +59,15 @@ def test_getitem(shape, data):
7159
for i in idx:
7260
val = val[i]
7361
out_obj.append(val)
74-
out_obj = reshape(out_obj, out_shape)
62+
out_obj = sh.reshape(out_obj, out_shape)
7563
expected = xp.asarray(out_obj, dtype=dtype)
7664
ph.assert_array("__getitem__", out, expected)
7765

7866

7967
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
8068
def test_setitem(shape, data):
81-
size = math.prod(shape)
8269
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
83-
obj = data.draw(
84-
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
85-
lambda l: reshape(l, shape)
86-
),
87-
label="obj",
88-
)
70+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
8971
x = xp.asarray(obj, dtype=dtype)
9072
note(f"{x=}")
9173
# TODO: test setting non-0d arrays

array_api_tests/test_creation_functions.py

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
from itertools import count
33
from typing import Iterator, NamedTuple, Union
44

5-
from hypothesis import assume, given
5+
from hypothesis import assume, given, note
66
from hypothesis import strategies as st
77

88
from . import _array_module as xp
99
from . import array_helpers as ah
1010
from . import dtype_helpers as dh
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
13+
from . import shape_helpers as sh
1314
from . import xps
1415
from .typing import DataType, Scalar
1516

@@ -186,6 +187,104 @@ def test_arange(dtype, data):
186187
), f"out[0]={out[0]}, but should be {_start} {f_func}"
187188

188189

190+
@given(
191+
shape=hh.shapes(min_side=1),
192+
data=st.data(),
193+
)
194+
def test_asarray_scalars(shape, data):
195+
kw = data.draw(
196+
hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw"
197+
)
198+
dtype = kw.get("dtype", None)
199+
if dtype is None:
200+
dtype_family = data.draw(
201+
st.sampled_from(
202+
[(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)]
203+
),
204+
label="expected out dtypes",
205+
)
206+
_dtype = dtype_family[0]
207+
else:
208+
_dtype = dtype
209+
if dh.is_float_dtype(_dtype):
210+
elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32)
211+
elif dh.is_int_dtype(_dtype):
212+
elements_strat = xps.from_dtype(_dtype) | st.booleans()
213+
else:
214+
elements_strat = xps.from_dtype(_dtype)
215+
size = math.prod(shape)
216+
obj_strat = st.lists(elements_strat, min_size=size, max_size=size)
217+
scalar_type = dh.get_scalar_type(_dtype)
218+
if dtype is None:
219+
# For asarray to infer the dtype we're testing, obj requires at least
220+
# one element to be the scalar equivalent of the inferred dtype, and so
221+
# we filter out invalid examples. Note we use type() as Python booleans
222+
# instance check with ints e.g. isinstance(False, int) == True.
223+
obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l))
224+
_obj = data.draw(obj_strat, label="_obj")
225+
obj = sh.reshape(_obj, shape)
226+
note(f"{obj=}")
227+
228+
out = xp.asarray(obj, **kw)
229+
230+
if dtype is None:
231+
msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be "
232+
if dtype_family == (xp.float32, xp.float64):
233+
msg += "default floating-point dtype (float32 or float64)"
234+
elif dtype_family == (xp.int32, xp.int64):
235+
msg += "default integer dtype (int32 or int64)"
236+
else:
237+
msg += "boolean dtype"
238+
msg += " [asarray()]"
239+
assert out.dtype in dtype_family, msg
240+
else:
241+
assert kw["dtype"] == _dtype # sanity check
242+
ph.assert_kw_dtype("asarray", _dtype, out.dtype)
243+
ph.assert_shape("asarray", out.shape, shape)
244+
for idx, v_expect in zip(sh.ndindex(out.shape), _obj):
245+
v = scalar_type(out[idx])
246+
ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw)
247+
248+
249+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data())
250+
def test_asarray_arrays(x, data):
251+
# TODO: test other valid dtypes
252+
kw = data.draw(
253+
hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()),
254+
label="kw",
255+
)
256+
257+
out = xp.asarray(x, **kw)
258+
259+
dtype = kw.get("dtype", None)
260+
if dtype is None:
261+
ph.assert_dtype("asarray", x.dtype, out.dtype)
262+
else:
263+
ph.assert_kw_dtype("asarray", dtype, out.dtype)
264+
ph.assert_shape("asarray", out.shape, x.shape)
265+
if dtype is None or dtype == x.dtype:
266+
ph.assert_array("asarray", out, x, **kw)
267+
else:
268+
pass # TODO
269+
copy = kw.get("copy", None)
270+
if copy is not None:
271+
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
272+
_dtype = x.dtype if dtype is None else dtype
273+
old_value = x[idx]
274+
value = data.draw(
275+
xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value),
276+
label="mutating value",
277+
)
278+
x[idx] = value
279+
note(f"mutated {x=}")
280+
if copy:
281+
assert not xp.all(
282+
out == x
283+
), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
284+
elif copy is False:
285+
pass # TODO
286+
287+
189288
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
190289
def test_empty(shape, kw):
191290
out = xp.empty(shape, **kw)

0 commit comments

Comments
 (0)