Skip to content

Versioning support, bulk of complex testing #165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
82e6312
Rudimentary complex testing for unary elwise functions
honno Dec 7, 2022
7a1e48e
Define elwise filters only for component dtypes
honno Dec 7, 2022
5889a7e
Complex testing for all elwise funcs
honno Dec 8, 2022
ff865bc
TODOs for `test_divide`
honno Dec 10, 2022
9f03c92
Loose assertion of infinities to very large floats
honno Dec 11, 2022
1ba2efd
Complex-related updates
honno Dec 12, 2022
2d6b2d8
`test_conj`
honno Dec 12, 2022
f5723ad
`min_version()` marker and other versioning nicities
honno Jan 27, 2023
f34bbf5
Try inferring `api_version` from `xp.__array_api_version__`
honno Jan 27, 2023
e1d56a0
Bump Hypothesis to `>=6.68.0`
honno Feb 14, 2023
03b735f
Remove `COMPLEX_VER`
honno Feb 14, 2023
78c57d0
`hypothesis_helper` dtype strats just alias `xps` strats
honno Feb 24, 2023
84bd3ef
Move oneway strategies to `hypothesis_helpers.py`
honno Feb 27, 2023
a9506a8
Use `cmath` where obvious
honno Feb 27, 2023
12c3aa2
Stop testing complex in `test_arange`
honno Feb 27, 2023
74101de
Remove unnecessary use of `hh.shared_dtypes()` in `test_empty`
honno Feb 27, 2023
b41d447
Support complex in `test_full`, complex dtype utilities
honno Feb 27, 2023
90e7837
Skip even not-so-very-large distances in `test_linspace`
honno Feb 27, 2023
05f2cf9
Skip testing complex dtypes in `test_data_type_functions.py` for now
honno Feb 27, 2023
8922b83
Skip testing complex numbers in `test_linalg.py` for now
honno Feb 27, 2023
33a0f6c
Remove debug `print` statement in `test_manipulation_functions.py`
honno Feb 27, 2023
d2267e4
Change dtype helpers behaviour depending on `api_version`
honno Feb 27, 2023
ef0e3b1
Update type hints relating to `complex`
honno Feb 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,13 @@ library to fail.

### Configuration

#### API version

You can specify the API version to use when testing via the
`ARRAY_API_TESTS_VERSION` environment variable. Currently this defaults to the
array module's `__array_api_version__` value, and if that attribute doesn't
exist then we fallback to `"2021.12"`.

#### CI flag

Use the `--ci` flag to run only the primary and special cases tests. You can
Expand Down
12 changes: 7 additions & 5 deletions array_api_tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from functools import wraps
from os import getenv

from hypothesis import strategies as st
from hypothesis.extra import array_api

from . import _version
from ._array_module import mod as _xp

__all__ = ["xps"]
__all__ = ["api_version", "xps"]


# We monkey patch floats() to always disable subnormals as they are out-of-scope
Expand Down Expand Up @@ -41,9 +43,9 @@ def _from_dtype(*a, **kw):
pass


xps = array_api.make_strategies_namespace(_xp, api_version="2021.12")


from . import _version
api_version = getenv(
"ARRAY_API_TESTS_VERSION", getattr(_xp, "__array_api_version__", "2021.12")
)
xps = array_api.make_strategies_namespace(_xp, api_version=api_version)

__version__ = _version.get_versions()["version"]
1 change: 1 addition & 0 deletions array_api_tests/_array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __repr__(self):
"uint8", "uint16", "uint32", "uint64",
"int8", "int16", "int32", "int64",
"float32", "float64",
"complex64", "complex128",
]
_constants = ["e", "inf", "nan", "pi"]
_funcs = [f.__name__ for funcs in stubs.category_to_funcs.values() for f in funcs]
Expand Down
50 changes: 45 additions & 5 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
from warnings import warn

from . import api_version
from . import _array_module as xp
from ._array_module import _UndefinedStub
from .stubs import name_to_func
Expand All @@ -15,10 +16,12 @@
"uint_dtypes",
"all_int_dtypes",
"float_dtypes",
"real_dtypes",
"numeric_dtypes",
"all_dtypes",
"dtype_to_name",
"all_float_dtypes",
"bool_and_all_int_dtypes",
"dtype_to_name",
"dtype_to_scalars",
"is_int_dtype",
"is_float_dtype",
Expand All @@ -27,9 +30,11 @@
"default_int",
"default_uint",
"default_float",
"default_complex",
"promotion_table",
"dtype_nbits",
"dtype_signed",
"dtype_components",
"func_in_dtypes",
"func_returns_bool",
"binary_op_to_symbol",
Expand Down Expand Up @@ -86,15 +91,25 @@ def __repr__(self):
_uint_names = ("uint8", "uint16", "uint32", "uint64")
_int_names = ("int8", "int16", "int32", "int64")
_float_names = ("float32", "float64")
_dtype_names = ("bool",) + _uint_names + _int_names + _float_names
_real_names = _uint_names + _int_names + _float_names
_complex_names = ("complex64", "complex128")
_numeric_names = _real_names + _complex_names
_dtype_names = ("bool",) + _numeric_names


uint_dtypes = tuple(getattr(xp, name) for name in _uint_names)
int_dtypes = tuple(getattr(xp, name) for name in _int_names)
float_dtypes = tuple(getattr(xp, name) for name in _float_names)
all_int_dtypes = uint_dtypes + int_dtypes
numeric_dtypes = all_int_dtypes + float_dtypes
real_dtypes = all_int_dtypes + float_dtypes
complex_dtypes = tuple(getattr(xp, name) for name in _complex_names)
numeric_dtypes = real_dtypes
if api_version > "2021.12":
numeric_dtypes += complex_dtypes
all_dtypes = (xp.bool,) + numeric_dtypes
all_float_dtypes = float_dtypes
if api_version > "2021.12":
all_float_dtypes += complex_dtypes
bool_and_all_int_dtypes = (xp.bool,) + all_int_dtypes


Expand All @@ -121,14 +136,19 @@ def is_float_dtype(dtype):
# See https://github.com/numpy/numpy/issues/18434
if dtype is None:
return False
return dtype in float_dtypes
valid_dtypes = float_dtypes
if api_version > "2021.12":
valid_dtypes += complex_dtypes
return dtype in valid_dtypes


def get_scalar_type(dtype: DataType) -> ScalarType:
if is_int_dtype(dtype):
return int
elif is_float_dtype(dtype):
return float
elif dtype in complex_dtypes:
return complex
else:
return bool

Expand Down Expand Up @@ -157,7 +177,8 @@ class MinMax(NamedTuple):
[(d, 8) for d in [xp.int8, xp.uint8]]
+ [(d, 16) for d in [xp.int16, xp.uint16]]
+ [(d, 32) for d in [xp.int32, xp.uint32, xp.float32]]
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64]]
+ [(d, 64) for d in [xp.int64, xp.uint64, xp.float64, xp.complex64]]
+ [(xp.complex128, 128)]
)


Expand All @@ -166,6 +187,11 @@ class MinMax(NamedTuple):
)


dtype_components = EqualityMapping(
[(xp.complex64, xp.float32), (xp.complex128, xp.float64)]
)


if isinstance(xp.asarray, _UndefinedStub):
default_int = xp.int32
default_float = xp.float32
Expand All @@ -180,6 +206,15 @@ class MinMax(NamedTuple):
default_float = xp.asarray(float()).dtype
if default_float not in float_dtypes:
warn(f"inferred default float is {default_float!r}, which is not a float")
if api_version > "2021.12":
default_complex = xp.asarray(complex()).dtype
if default_complex not in complex_dtypes:
warn(
f"inferred default complex is {default_complex!r}, "
"which is not a complex"
)
else:
default_complex = None
if dtype_nbits[default_int] == 32:
default_uint = xp.uint32
else:
Expand Down Expand Up @@ -226,6 +261,11 @@ class MinMax(NamedTuple):
((xp.float32, xp.float32), xp.float32),
((xp.float32, xp.float64), xp.float64),
((xp.float64, xp.float64), xp.float64),
# complex
((xp.complex64, xp.complex64), xp.complex64),
((xp.complex64, xp.complex128), xp.complex128),
((xp.complex128, xp.complex128), xp.complex128),

]
_numeric_promotions += [((d2, d1), res) for (d1, d2), res in _numeric_promotions]
_promotion_table = list(set(_numeric_promotions))
Expand Down
67 changes: 50 additions & 17 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from operator import mul
from typing import Any, List, NamedTuple, Optional, Sequence, Tuple, Union

from hypothesis import assume
from hypothesis import assume, reject
from hypothesis.strategies import (SearchStrategy, booleans, composite, floats,
integers, just, lists, none, one_of,
sampled_from, shared)
Expand All @@ -26,27 +26,20 @@
# work for floating point dtypes as those are assumed to be defined in other
# places in the tests.
FILTER_UNDEFINED_DTYPES = True
# TODO: currently we assume this to be true - we probably can remove this completely
assert FILTER_UNDEFINED_DTYPES

integer_dtypes = sampled_from(dh.all_int_dtypes)
floating_dtypes = sampled_from(dh.float_dtypes)
numeric_dtypes = sampled_from(dh.numeric_dtypes)
integer_or_boolean_dtypes = sampled_from(dh.bool_and_all_int_dtypes)
boolean_dtypes = just(xp.bool)
dtypes = sampled_from(dh.all_dtypes)

if FILTER_UNDEFINED_DTYPES:
integer_dtypes = integer_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
floating_dtypes = floating_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
numeric_dtypes = numeric_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
integer_or_boolean_dtypes = integer_or_boolean_dtypes.filter(lambda x: not
isinstance(x, _UndefinedStub))
boolean_dtypes = boolean_dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
dtypes = dtypes.filter(lambda x: not isinstance(x, _UndefinedStub))
integer_dtypes = xps.integer_dtypes() | xps.unsigned_integer_dtypes()
floating_dtypes = xps.floating_dtypes()
numeric_dtypes = xps.numeric_dtypes()
integer_or_boolean_dtypes = xps.boolean_dtypes() | integer_dtypes
boolean_dtypes = xps.boolean_dtypes()
dtypes = xps.scalar_dtypes()

shared_dtypes = shared(dtypes, key="dtype")
shared_floating_dtypes = shared(floating_dtypes, key="dtype")

_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes]
_dtype_categories = [(xp.bool,), dh.uint_dtypes, dh.int_dtypes, dh.float_dtypes, dh.complex_dtypes]
_sorted_dtypes = [d for category in _dtype_categories for d in category]

def _dtypes_sorter(dtype_pair: Tuple[DataType, DataType]):
Expand Down Expand Up @@ -106,6 +99,46 @@ def mutually_promotable_dtypes(
return one_of(strats).map(tuple)


class OnewayPromotableDtypes(NamedTuple):
input_dtype: DataType
result_dtype: DataType


@composite
def oneway_promotable_dtypes(
draw, dtypes: Sequence[DataType]
) -> SearchStrategy[OnewayPromotableDtypes]:
"""Return a strategy for input dtypes that promote to result dtypes."""
d1, d2 = draw(mutually_promotable_dtypes(dtypes=dtypes))
result_dtype = dh.result_type(d1, d2)
if d1 == result_dtype:
return OnewayPromotableDtypes(d2, d1)
elif d2 == result_dtype:
return OnewayPromotableDtypes(d1, d2)
else:
reject()


class OnewayBroadcastableShapes(NamedTuple):
input_shape: Shape
result_shape: Shape


@composite
def oneway_broadcastable_shapes(draw) -> SearchStrategy[OnewayBroadcastableShapes]:
"""Return a strategy for input shapes that broadcast to result shapes."""
result_shape = draw(shapes(min_side=1))
input_shape = draw(
xps.broadcastable_shapes(
result_shape,
# Override defaults so bad shapes are less likely to be generated.
max_side=None if result_shape == () else max(result_shape),
max_dims=len(result_shape),
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
)
return OnewayBroadcastableShapes(input_shape, result_shape)


# shared() allows us to draw either the function or the function name and they
# will both correspond to the same function.

Expand Down
11 changes: 4 additions & 7 deletions array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@

from .. import _array_module as xp
from .. import dtype_helpers as dh
from .. import hypothesis_helpers as hh
from .. import shape_helpers as sh
from .. import xps
from ..test_creation_functions import frange
from ..test_manipulation_functions import roll_ndindex
from ..test_operators_and_elementwise_functions import (
mock_int_dtype,
oneway_broadcastable_shapes,
oneway_promotable_dtypes,
)
from ..test_operators_and_elementwise_functions import mock_int_dtype


@pytest.mark.parametrize(
Expand Down Expand Up @@ -115,11 +112,11 @@ def test_int_to_dtype(x, dtype):
assert mock_int_dtype(x, dtype) == d


@given(oneway_promotable_dtypes(dh.all_dtypes))
@given(hh.oneway_promotable_dtypes(dh.all_dtypes))
def test_oneway_promotable_dtypes(D):
assert D.result_dtype == dh.result_type(*D)


@given(oneway_broadcastable_shapes())
@given(hh.oneway_broadcastable_shapes())
def test_oneway_broadcastable_shapes(S):
assert S.result_shape == sh.broadcast_shapes(*S)
Loading