Skip to content

Commit 5264fdd

Browse files
committed
Smoke testing for xp.asarray()
1 parent 3626af3 commit 5264fdd

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

array_api_tests/shape_helpers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
6262

6363
def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
6464
"""Reshape a flat sequence"""
65+
if any(s == 0 for s in shape):
66+
raise ValueError(
67+
f"{shape=} contains 0-sided dimensions, "
68+
f"but that's not representable in lists"
69+
)
6570
if len(shape) == 0:
6671
assert len(flat_seq) == 1 # sanity check
6772
return flat_seq[0]

array_api_tests/test_creation_functions.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,19 @@ def test_arange(dtype, data):
186186
), f"out[0]={out[0]}, but should be {_start} {f_func}"
187187

188188

189+
@given(dtype=xps.scalar_dtypes(), shape=hh.shapes(min_side=1), data=st.data())
190+
def test_asarray_scalars(dtype, shape, data):
191+
obj = data.draw(hh.scalar_objects(dtype, shape), label="obj")
192+
kw = data.draw(
193+
hh.kwargs(dtype=st.sampled_from([None, dtype]), copy=st.none()), label="kw"
194+
)
195+
196+
xp.asarray(obj, **kw)
197+
198+
199+
# TODO: test asarray with arrays and copy (in a seperate method)
200+
201+
189202
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
190203
def test_empty(shape, kw):
191204
out = xp.empty(shape, **kw)

0 commit comments

Comments
 (0)