Skip to content

Some fixes for v2022.12 #189

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,19 +157,18 @@ def is_int_dtype(dtype):
return dtype in all_int_dtypes


def is_float_dtype(dtype):
def is_float_dtype(dtype, *, include_complex=True):
# None equals NumPy's xp.float64 object, so we specifically check it here.
# xp.float64 is in fact an alias of np.dtype('float64'), and its equality
# with None is meant to be deprecated at some point.
# See https://github.com/numpy/numpy/issues/18434
if dtype is None:
return False
valid_dtypes = real_float_dtypes
if api_version > "2021.12":
if api_version > "2021.12" and include_complex:
valid_dtypes += complex_dtypes
return dtype in valid_dtypes


def get_scalar_type(dtype: DataType) -> ScalarType:
if dtype in all_int_dtypes:
return int
Expand Down
4 changes: 2 additions & 2 deletions array_api_tests/pytest_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,8 +464,8 @@ def assert_array_elements(
f"{sh.fmt_idx(out_repr, idx)}={at_out}, should be {at_expected} "
f"{f_func}"
)
_assert_float_element(at_out.real, at_expected.real, msg)
_assert_float_element(at_out.imag, at_expected.imag, msg)
_assert_float_element(xp.real(at_out), xp.real(at_expected), msg)
_assert_float_element(xp.imag(at_out), xp.imag(at_expected), msg)
else:
assert xp.all(
out == expected
Expand Down
2 changes: 2 additions & 0 deletions array_api_tests/test_data_type_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def test_finfo(dtype_name):
assert isinstance(
value, stype
), f"type(out.{attr})={type(value)!r}, but should be {stype.__name__} {f_func}"
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
# TODO: test values


Expand All @@ -179,6 +180,7 @@ def test_iinfo(dtype_name):
assert isinstance(
value, int
), f"type(out.{attr})={type(value)!r}, but should be int {f_func}"
assert hasattr(out, "dtype"), f"out has no attribute 'dtype' {f_func}"
# TODO: test values


Expand Down
2 changes: 1 addition & 1 deletion array_api_tests/test_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_take(x, data):
f_axis_idx = sh.fmt_idx("x", axis_idx)
for i in _indices:
f_take_idx = sh.fmt_idx(f_axis_idx, i)
indexed_x = x[axis_idx][i]
indexed_x = x[axis_idx][i, ...]
for at_idx in sh.ndindex(indexed_x.shape):
out_idx = next(out_indices)
ph.assert_0d_equals(
Expand Down
8 changes: 4 additions & 4 deletions array_api_tests/test_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def test_unique_all(x):

if dh.is_float_dtype(out.values.dtype):
assume(math.prod(x.shape) <= 128) # may not be representable
expected = sum(v for k, v in counts.items() if math.isnan(k))
expected = sum(v for k, v in counts.items() if cmath.isnan(k))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"


Expand All @@ -137,7 +137,7 @@ def test_unique_counts(x):
for idx in sh.ndindex(out.values.shape):
val = scalar_type(out.values[idx])
count = int(out.counts[idx])
if math.isnan(val):
if cmath.isnan(val):
nans += 1
assert count == 1, (
f"out.counts[{idx}]={count} for out.values[{idx}]={val}, "
Expand All @@ -159,7 +159,7 @@ def test_unique_counts(x):
vals_idx[val] = idx
if dh.is_float_dtype(out.values.dtype):
assume(math.prod(x.shape) <= 128) # may not be representable
expected = sum(v for k, v in counts.items() if math.isnan(k))
expected = sum(v for k, v in counts.items() if cmath.isnan(k))
assert nans == expected, f"{nans} NaNs in out, but should be {expected}"


Expand Down Expand Up @@ -188,7 +188,7 @@ def test_unique_inverse(x):
nans = 0
for idx in sh.ndindex(out.values.shape):
val = scalar_type(out.values[idx])
if math.isnan(val):
if cmath.isnan(val):
nans += 1
else:
assert (
Expand Down
22 changes: 19 additions & 3 deletions array_api_tests/test_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from . import hypothesis_helpers as hh
from . import pytest_helpers as ph
from . import shape_helpers as sh
from . import xps
from . import xps, api_version
from ._array_module import _UndefinedStub
from .typing import DataType

Expand Down Expand Up @@ -145,11 +145,19 @@ def test_prod(x, data):
_dtype = x.dtype
else:
_dtype = default_dtype
else:
elif dh.is_float_dtype(x.dtype, include_complex=False):
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
_dtype = x.dtype
else:
_dtype = dh.default_float
elif api_version > "2021.12":
# Complex dtype
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
_dtype = x.dtype
else:
_dtype = dh.default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype
if _dtype is None:
Expand Down Expand Up @@ -253,11 +261,19 @@ def test_sum(x, data):
_dtype = x.dtype
else:
_dtype = default_dtype
else:
elif dh.is_float_dtype(x.dtype, include_complex=False):
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_float]:
_dtype = x.dtype
else:
_dtype = dh.default_float
elif api_version > "2021.12":
# Complex dtype
if dh.dtype_nbits[x.dtype] > dh.dtype_nbits[dh.default_complex]:
_dtype = x.dtype
else:
_dtype = dh.default_complex
else:
raise RuntimeError("Unexpected dtype. This indicates a bug in the test suite.")
else:
_dtype = dtype
if _dtype is None:
Expand Down