Skip to content

Commit d2f4ea9

Browse files
committed
Test shape and dtype for xp.asarray()
1 parent c0833b9 commit d2f4ea9

File tree

3 files changed

+64
-18
lines changed

3 files changed

+64
-18
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from .algos import broadcast_shapes
1919
from .function_stubs import elementwise_functions
2020
from .pytest_helpers import nargs
21-
from .typing import Array, DataType, Scalar, Shape
21+
from .typing import Array, DataType, Shape
2222

2323
# Set this to True to not fail tests just because a dtype isn't implemented.
2424
# If no compatible dtype is implemented for a given test, the test will fail
@@ -430,11 +430,3 @@ def axes(ndim: int) -> SearchStrategy[Optional[Union[int, Shape]]]:
430430
axes_strats.append(integers(-ndim, ndim - 1))
431431
axes_strats.append(xps.valid_tuple_axes(ndim))
432432
return one_of(axes_strats)
433-
434-
435-
def scalar_objects(dtype: DataType, shape: Shape) -> SearchStrategy[List[Scalar]]:
436-
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
437-
size = math.prod(shape)
438-
return lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
439-
lambda l: sh.reshape(l, shape)
440-
)

array_api_tests/test_array_object.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import get_args
3+
from typing import List, get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -12,13 +12,21 @@
1212
from . import pytest_helpers as ph
1313
from . import shape_helpers as sh
1414
from . import xps
15-
from .typing import DataType, Param, Scalar, ScalarType
15+
from .typing import DataType, Param, Scalar, ScalarType, Shape
16+
17+
18+
def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
19+
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
20+
size = math.prod(shape)
21+
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
22+
lambda l: sh.reshape(l, shape)
23+
)
1624

1725

1826
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
1927
def test_getitem(shape, data):
2028
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
21-
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
29+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
2230
x = xp.asarray(obj, dtype=dtype)
2331
note(f"{x=}")
2432
key = data.draw(xps.indices(shape=shape), label="key")
@@ -59,7 +67,7 @@ def test_getitem(shape, data):
5967
@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
6068
def test_setitem(shape, data):
6169
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
62-
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
70+
obj = data.draw(scalar_objects(dtype, shape), label="obj")
6371
x = xp.asarray(obj, dtype=dtype)
6472
note(f"{x=}")
6573
# TODO: test setting non-0d arrays

array_api_tests/test_creation_functions.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from . import dtype_helpers as dh
1111
from . import hypothesis_helpers as hh
1212
from . import pytest_helpers as ph
13+
from . import shape_helpers as sh
1314
from . import xps
1415
from .typing import DataType, Scalar
1516

@@ -186,14 +187,59 @@ def test_arange(dtype, data):
186187
), f"out[0]={out[0]}, but should be {_start} {f_func}"
187188

188189

189-
@given(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1), data=st.data())
190-
def test_asarray_scalars(dtype, shape, data):
191-
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
190+
@given(
191+
shape=hh.shapes(min_side=1),
192+
data=st.data(),
193+
)
194+
def test_asarray_scalars(shape, data):
192195
kw = data.draw(
193-
hh.kwargs(dtype=st.sampled_from([None, dtype]), copy=st.none()), label="kw"
196+
hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw"
194197
)
198+
dtype = kw.get("dtype", None)
199+
if dtype is None:
200+
dtype_family = data.draw(
201+
st.sampled_from(
202+
[(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)]
203+
),
204+
label="expected out dtypes",
205+
)
206+
_dtype = dtype_family[0]
207+
else:
208+
_dtype = dtype
209+
if dh.is_float_dtype(_dtype):
210+
elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32)
211+
elif dh.is_int_dtype(_dtype):
212+
elements_strat = xps.from_dtype(_dtype) | st.booleans()
213+
else:
214+
elements_strat = xps.from_dtype(_dtype)
215+
size = math.prod(shape)
216+
obj_strat = st.lists(elements_strat, min_size=size, max_size=size)
217+
if dtype is None:
218+
# For asarray to infer the dtype we're testing, obj requires at least
219+
# one element to be the scalar equivalent of the inferred dtype, and so
220+
# we filter out invalid examples. Note we use type() as Python booleans
221+
# instance check with ints e.g. isinstance(False, int) == True.
222+
scalar_type = dh.get_scalar_type(_dtype)
223+
obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l))
224+
obj_strat = obj_strat.map(lambda l: sh.reshape(l, shape))
225+
obj = data.draw(obj_strat, label="obj")
226+
227+
out = xp.asarray(obj, **kw)
195228

196-
xp.asarray(obj, **kw)
229+
if dtype is None:
230+
msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be "
231+
if dtype_family == (xp.float32, xp.float64):
232+
msg += "default floating-point dtype (float32 or float64)"
233+
elif dtype_family == (xp.int32, xp.int64):
234+
msg += "default integer dtype (int32 or int64)"
235+
else:
236+
msg += "boolean dtype"
237+
msg += " [asarray()]"
238+
assert out.dtype in dtype_family, msg
239+
else:
240+
assert kw["dtype"] == _dtype # sanity check
241+
ph.assert_kw_dtype("asarray", _dtype, out.dtype)
242+
ph.assert_shape("asarray", out.shape, shape)
197243

198244

199245
# TODO: test asarray with arrays and copy (in a seperate method)

0 commit comments

Comments
 (0)