diff --git a/array_api_tests/pytest_helpers.py b/array_api_tests/pytest_helpers.py index 64e39aa4..d1b48830 100644 --- a/array_api_tests/pytest_helpers.py +++ b/array_api_tests/pytest_helpers.py @@ -231,7 +231,7 @@ def assert_fill( def assert_array(func_name: str, out: Array, expected: Array, /, **kw): - assert_dtype(func_name, out.dtype, expected.dtype, **kw) + assert_dtype(func_name, out.dtype, expected.dtype) assert_shape(func_name, out.shape, expected.shape, **kw) msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}" if dh.is_float_dtype(out.dtype): diff --git a/array_api_tests/shape_helpers.py b/array_api_tests/shape_helpers.py index 751b8d49..17dd7f6e 100644 --- a/array_api_tests/shape_helpers.py +++ b/array_api_tests/shape_helpers.py @@ -1,9 +1,10 @@ +import math from itertools import product from typing import Iterator, List, Optional, Tuple, Union -from .typing import Shape +from .typing import Scalar, Shape -__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"] +__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"] def normalise_axis( @@ -57,3 +58,20 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]: idx = tuple(idx) indices.append(idx) yield list(indices) + + +def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]: + """Reshape a flat sequence""" + if any(s == 0 for s in shape): + raise ValueError( + f"{shape=} contains 0-sided dimensions, " + f"but that's not representable in lists" + ) + if len(shape) == 0: + assert len(flat_seq) == 1 # sanity check + return flat_seq[0] + elif len(shape) == 1: + return flat_seq + size = len(flat_seq) + n = math.prod(shape[1:]) + return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)] diff --git a/array_api_tests/test_array_object.py b/array_api_tests/test_array_object.py index 397d823d..ce6ae596 100644 --- a/array_api_tests/test_array_object.py +++ b/array_api_tests/test_array_object.py @@ -1,6 +1,6 @@ import math from itertools import product -from typing import Sequence, Union, get_args +from typing import List, get_args import pytest from hypothesis import assume, given, note @@ -15,30 +15,18 @@ from .typing import DataType, Param, Scalar, ScalarType, Shape -def reshape( - flat_seq: Sequence[Scalar], shape: Shape -) -> Union[Scalar, Sequence[Scalar]]: - """Reshape a flat sequence""" - if len(shape) == 0: - assert len(flat_seq) == 1 # sanity check - return flat_seq[0] - elif len(shape) == 1: - return flat_seq - size = len(flat_seq) - n = math.prod(shape[1:]) - return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)] +def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]: + """Generates scalars or nested sequences which are valid for xp.asarray()""" + size = math.prod(shape) + return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( + lambda l: sh.reshape(l, shape) + ) @given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays def test_getitem(shape, data): - size = math.prod(shape) dtype = data.draw(xps.scalar_dtypes(), label="dtype") - obj = data.draw( - st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( - lambda l: reshape(l, shape) - ), - label="obj", - ) + obj = data.draw(scalar_objects(dtype, shape), label="obj") x = xp.asarray(obj, dtype=dtype) note(f"{x=}") key = data.draw(xps.indices(shape=shape), label="key") @@ -71,21 +59,15 @@ def test_getitem(shape, data): for i in idx: val = val[i] out_obj.append(val) - out_obj = reshape(out_obj, out_shape) + out_obj = sh.reshape(out_obj, out_shape) expected = xp.asarray(out_obj, dtype=dtype) ph.assert_array("__getitem__", out, expected) @given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays def test_setitem(shape, data): - size = math.prod(shape) dtype = data.draw(xps.scalar_dtypes(), label="dtype") - obj = data.draw( - st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map( - lambda l: reshape(l, shape) - ), - label="obj", - ) + obj = data.draw(scalar_objects(dtype, shape), label="obj") x = xp.asarray(obj, dtype=dtype) note(f"{x=}") # TODO: test setting non-0d arrays diff --git a/array_api_tests/test_creation_functions.py b/array_api_tests/test_creation_functions.py index 1098f1e1..4b96151a 100644 --- a/array_api_tests/test_creation_functions.py +++ b/array_api_tests/test_creation_functions.py @@ -2,7 +2,7 @@ from itertools import count from typing import Iterator, NamedTuple, Union -from hypothesis import assume, given +from hypothesis import assume, given, note from hypothesis import strategies as st from . import _array_module as xp @@ -10,6 +10,7 @@ from . import dtype_helpers as dh from . import hypothesis_helpers as hh from . import pytest_helpers as ph +from . import shape_helpers as sh from . import xps from .typing import DataType, Scalar @@ -186,6 +187,104 @@ def test_arange(dtype, data): ), f"out[0]={out[0]}, but should be {_start} {f_func}" +@given( + shape=hh.shapes(min_side=1), + data=st.data(), +) +def test_asarray_scalars(shape, data): + kw = data.draw( + hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw" + ) + dtype = kw.get("dtype", None) + if dtype is None: + dtype_family = data.draw( + st.sampled_from( + [(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)] + ), + label="expected out dtypes", + ) + _dtype = dtype_family[0] + else: + _dtype = dtype + if dh.is_float_dtype(_dtype): + elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32) + elif dh.is_int_dtype(_dtype): + elements_strat = xps.from_dtype(_dtype) | st.booleans() + else: + elements_strat = xps.from_dtype(_dtype) + size = math.prod(shape) + obj_strat = st.lists(elements_strat, min_size=size, max_size=size) + scalar_type = dh.get_scalar_type(_dtype) + if dtype is None: + # For asarray to infer the dtype we're testing, obj requires at least + # one element to be the scalar equivalent of the inferred dtype, and so + # we filter out invalid examples. Note we use type() as Python booleans + # instance check with ints e.g. isinstance(False, int) == True. + obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l)) + _obj = data.draw(obj_strat, label="_obj") + obj = sh.reshape(_obj, shape) + note(f"{obj=}") + + out = xp.asarray(obj, **kw) + + if dtype is None: + msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be " + if dtype_family == (xp.float32, xp.float64): + msg += "default floating-point dtype (float32 or float64)" + elif dtype_family == (xp.int32, xp.int64): + msg += "default integer dtype (int32 or int64)" + else: + msg += "boolean dtype" + msg += " [asarray()]" + assert out.dtype in dtype_family, msg + else: + assert kw["dtype"] == _dtype # sanity check + ph.assert_kw_dtype("asarray", _dtype, out.dtype) + ph.assert_shape("asarray", out.shape, shape) + for idx, v_expect in zip(sh.ndindex(out.shape), _obj): + v = scalar_type(out[idx]) + ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw) + + +@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data()) +def test_asarray_arrays(x, data): + # TODO: test other valid dtypes + kw = data.draw( + hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()), + label="kw", + ) + + out = xp.asarray(x, **kw) + + dtype = kw.get("dtype", None) + if dtype is None: + ph.assert_dtype("asarray", x.dtype, out.dtype) + else: + ph.assert_kw_dtype("asarray", dtype, out.dtype) + ph.assert_shape("asarray", out.shape, x.shape) + if dtype is None or dtype == x.dtype: + ph.assert_array("asarray", out, x, **kw) + else: + pass # TODO + copy = kw.get("copy", None) + if copy is not None: + idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx") + _dtype = x.dtype if dtype is None else dtype + old_value = x[idx] + value = data.draw( + xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value), + label="mutating value", + ) + x[idx] = value + note(f"mutated {x=}") + if copy: + assert not xp.all( + out == x + ), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}" + elif copy is False: + pass # TODO + + @given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes)) def test_empty(shape, kw): out = xp.empty(shape, **kw)