Skip to content

xp.asarray() testing #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def assert_fill(


def assert_array(func_name: str, out: Array, expected: Array, /, **kw):
assert_dtype(func_name, out.dtype, expected.dtype, **kw)
assert_dtype(func_name, out.dtype, expected.dtype)
assert_shape(func_name, out.shape, expected.shape, **kw)
msg = f"out not as expected [{func_name}({fmt_kw(kw)})]\n{out=}\n{expected=}"
if dh.is_float_dtype(out.dtype):
Expand Down
22 changes: 20 additions & 2 deletions array_api_tests/shape_helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import math
from itertools import product
from typing import Iterator, List, Optional, Tuple, Union

from .typing import Shape
from .typing import Scalar, Shape

__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex"]
__all__ = ["normalise_axis", "ndindex", "axis_ndindex", "axes_ndindex", "reshape"]


def normalise_axis(
Expand Down Expand Up @@ -57,3 +58,20 @@ def axes_ndindex(shape: Shape, axes: Tuple[int, ...]) -> Iterator[List[Shape]]:
idx = tuple(idx)
indices.append(idx)
yield list(indices)


def reshape(flat_seq: List[Scalar], shape: Shape) -> Union[Scalar, List[Scalar]]:
"""Reshape a flat sequence"""
if any(s == 0 for s in shape):
raise ValueError(
f"{shape=} contains 0-sided dimensions, "
f"but that's not representable in lists"
)
if len(shape) == 0:
assert len(flat_seq) == 1 # sanity check
return flat_seq[0]
elif len(shape) == 1:
return flat_seq
size = len(flat_seq)
n = math.prod(shape[1:])
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
38 changes: 10 additions & 28 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import product
from typing import Sequence, Union, get_args
from typing import List, get_args

import pytest
from hypothesis import assume, given, note
Expand All @@ -15,30 +15,18 @@
from .typing import DataType, Param, Scalar, ScalarType, Shape


def reshape(
flat_seq: Sequence[Scalar], shape: Shape
) -> Union[Scalar, Sequence[Scalar]]:
"""Reshape a flat sequence"""
if len(shape) == 0:
assert len(flat_seq) == 1 # sanity check
return flat_seq[0]
elif len(shape) == 1:
return flat_seq
size = len(flat_seq)
n = math.prod(shape[1:])
return [reshape(flat_seq[i * n : (i + 1) * n], shape[1:]) for i in range(size // n)]
def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
size = math.prod(shape)
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
lambda l: sh.reshape(l, shape)
)


@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
def test_getitem(shape, data):
size = math.prod(shape)
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
obj = data.draw(
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
lambda l: reshape(l, shape)
),
label="obj",
)
obj = data.draw(scalar_objects(dtype, shape), label="obj")
x = xp.asarray(obj, dtype=dtype)
note(f"{x=}")
key = data.draw(xps.indices(shape=shape), label="key")
Expand Down Expand Up @@ -71,21 +59,15 @@ def test_getitem(shape, data):
for i in idx:
val = val[i]
out_obj.append(val)
out_obj = reshape(out_obj, out_shape)
out_obj = sh.reshape(out_obj, out_shape)
expected = xp.asarray(out_obj, dtype=dtype)
ph.assert_array("__getitem__", out, expected)


@given(hh.shapes(min_side=1), st.data()) # TODO: test 0-sided arrays
def test_setitem(shape, data):
size = math.prod(shape)
dtype = data.draw(xps.scalar_dtypes(), label="dtype")
obj = data.draw(
st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
lambda l: reshape(l, shape)
),
label="obj",
)
obj = data.draw(scalar_objects(dtype, shape), label="obj")
x = xp.asarray(obj, dtype=dtype)
note(f"{x=}")
# TODO: test setting non-0d arrays
Expand Down
101 changes: 100 additions & 1 deletion array_api_tests/test_creation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from itertools import count
from typing import Iterator, NamedTuple, Union

from hypothesis import assume, given
from hypothesis import assume, given, note
from hypothesis import strategies as st

from . import _array_module as xp
from . import array_helpers as ah
from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from .typing import DataType, Scalar

Expand Down Expand Up @@ -186,6 +187,104 @@ def test_arange(dtype, data):
), f"out[0]={out[0]}, but should be {_start} {f_func}"


@given(
shape=hh.shapes(min_side=1),
data=st.data(),
)
def test_asarray_scalars(shape, data):
kw = data.draw(
hh.kwargs(dtype=st.none() | xps.scalar_dtypes(), copy=st.none()), label="kw"
)
dtype = kw.get("dtype", None)
if dtype is None:
dtype_family = data.draw(
st.sampled_from(
[(xp.bool,), (xp.int32, xp.int64), (xp.float32, xp.float64)]
),
label="expected out dtypes",
)
_dtype = dtype_family[0]
else:
_dtype = dtype
if dh.is_float_dtype(_dtype):
elements_strat = xps.from_dtype(_dtype) | xps.from_dtype(xp.int32)
elif dh.is_int_dtype(_dtype):
elements_strat = xps.from_dtype(_dtype) | st.booleans()
else:
elements_strat = xps.from_dtype(_dtype)
size = math.prod(shape)
obj_strat = st.lists(elements_strat, min_size=size, max_size=size)
scalar_type = dh.get_scalar_type(_dtype)
if dtype is None:
# For asarray to infer the dtype we're testing, obj requires at least
# one element to be the scalar equivalent of the inferred dtype, and so
# we filter out invalid examples. Note we use type() as Python booleans
# instance check with ints e.g. isinstance(False, int) == True.
obj_strat = obj_strat.filter(lambda l: any(type(e) == scalar_type for e in l))
_obj = data.draw(obj_strat, label="_obj")
obj = sh.reshape(_obj, shape)
note(f"{obj=}")

out = xp.asarray(obj, **kw)

if dtype is None:
msg = f"out.dtype={dh.dtype_to_name[out.dtype]}, should be "
if dtype_family == (xp.float32, xp.float64):
msg += "default floating-point dtype (float32 or float64)"
elif dtype_family == (xp.int32, xp.int64):
msg += "default integer dtype (int32 or int64)"
else:
msg += "boolean dtype"
msg += " [asarray()]"
assert out.dtype in dtype_family, msg
else:
assert kw["dtype"] == _dtype # sanity check
ph.assert_kw_dtype("asarray", _dtype, out.dtype)
ph.assert_shape("asarray", out.shape, shape)
for idx, v_expect in zip(sh.ndindex(out.shape), _obj):
v = scalar_type(out[idx])
ph.assert_scalar_equals("asarray", scalar_type, idx, v, v_expect, **kw)


@given(xps.arrays(dtype=xps.scalar_dtypes(), shape=hh.shapes()), st.data())
def test_asarray_arrays(x, data):
# TODO: test other valid dtypes
kw = data.draw(
hh.kwargs(dtype=st.none() | st.just(x.dtype), copy=st.none() | st.booleans()),
label="kw",
)

out = xp.asarray(x, **kw)

dtype = kw.get("dtype", None)
if dtype is None:
ph.assert_dtype("asarray", x.dtype, out.dtype)
else:
ph.assert_kw_dtype("asarray", dtype, out.dtype)
ph.assert_shape("asarray", out.shape, x.shape)
if dtype is None or dtype == x.dtype:
ph.assert_array("asarray", out, x, **kw)
else:
pass # TODO
copy = kw.get("copy", None)
if copy is not None:
idx = data.draw(xps.indices(x.shape, max_dims=0), label="mutating idx")
_dtype = x.dtype if dtype is None else dtype
old_value = x[idx]
value = data.draw(
xps.arrays(dtype=_dtype, shape=()).filter(lambda y: y != old_value),
label="mutating value",
)
x[idx] = value
note(f"mutated {x=}")
if copy:
assert not xp.all(
out == x
), "xp.all(out == x)=True, but should be False after x was mutated\n{out=}"
elif copy is False:
pass # TODO


@given(hh.shapes(), hh.kwargs(dtype=st.none() | hh.shared_dtypes))
def test_empty(shape, kw):
out = xp.empty(shape, **kw)
Expand Down