Skip to content

Commit 4750977

Browse files
committed
Test xp.asarray() with array inputs
1 parent 8a327a3 commit 4750977

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

array_api_tests/pytest_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def assert_fill(
231231

232232

233233
def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
234-
assert_dtype(func_name, out.dtype, expected.dtype, **kw)
234+
assert_dtype(func_name, out.dtype, expected.dtype)
235235
assert_shape(func_name, out.shape, expected.shape, **kw)
236236
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
237237
if dh.is_float_dtype(out.dtype):

array_api_tests/test_creation_functions.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,43 @@ def test_asarray_scalars(shape, data):
246246
ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw)
247247

248248

249-
# TODO: test asarray with arrays and copy (in a seperate method)
249+
@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data())
250+
def test_asarray_arrays(x, data):
251+
# TODO: test other valid dtypes
252+
kw = data.draw(
253+
hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()),
254+
label="kw",
255+
)
256+
257+
out = xp.asarray(x, **kw)
258+
259+
dtype = kw.get("dtype", None)
260+
if dtype is None:
261+
ph.assert_dtype("asarray", x.dtype, out.dtype)
262+
else:
263+
ph.assert_kw_dtype("asarray", dtype, out.dtype)
264+
ph.assert_shape("asarray", out.shape, x.shape)
265+
if dtype is None or dtype == x.dtype:
266+
ph.assert_array("asarray", out, x, **kw)
267+
else:
268+
pass # TODO
269+
copy = kw.get("copy", None)
270+
if copy is not None:
271+
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
272+
_dtype = x.dtype if dtype is None else dtype
273+
old_value = x[idx]
274+
value = data.draw(
275+
xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value),
276+
label="mutating value",
277+
)
278+
x[idx] = value
279+
note(f"mutated {x=}")
280+
if copy:
281+
assert not xp.all(
282+
out == x
283+
), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
284+
elif copy is False:
285+
pass # TODO
250286

251287

252288
@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))

0 commit comments

Comments
 (0)