diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8875f2c2..87f0d766 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -472,10 +472,52 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: res = x1_[..., None, :] @ x2_[..., None] return res[..., 0, 0] +# isdtype is a new function in the 2022.12 array API specification. + +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, + *, _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + Note that outside of this function, this compat library does not yet fully + support complex numbers. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + if isinstance(kind, tuple) and _tuple: + return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype == xp.bool_ + elif kind == 'signed integer': + return xp.issubdtype(dtype, xp.signedinteger) + elif kind == 'unsigned integer': + return xp.issubdtype(dtype, xp.unsignedinteger) + elif kind == 'integral': + return xp.issubdtype(dtype, xp.integer) + elif kind == 'real floating': + return xp.issubdtype(dtype, xp.floating) + elif kind == 'complex floating': + return xp.issubdtype(dtype, xp.complexfloating) + elif kind == 'numeric': + return xp.issubdtype(dtype, xp.number) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + # This will allow things that aren't required by the spec, like + # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be + # more strict here to match the type annotation? Note that the + # numpy.array_api implementation will be very strict. + return dtype == kind + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', - 'matrix_transpose', 'tensordot', 'vecdot'] + 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index ce7f3780..b43c371f 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -61,6 +61,7 @@ matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) vecdot = get_xp(cp)(_aliases.vecdot) +isdtype = get_xp(cp)(_aliases.isdtype) __all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 2022b842..08f4de0b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,7 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) vecdot = get_xp(np)(_aliases.vecdot) +isdtype = get_xp(np)(_aliases.isdtype) __all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index efd86768..c23ba059 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import wraps -from builtins import all as builtin_all +from builtins import all as builtin_all, any as builtin_any from ..common._aliases import (UniqueAllResult, UniqueCountsResult, UniqueInverseResult, @@ -18,13 +18,17 @@ import torch array = torch.Tensor -_array_api_dtypes = { - torch.bool, +_int_dtypes = { torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, +} + +_array_api_dtypes = { + torch.bool, + *_int_dtypes, torch.float32, torch.float64, } @@ -602,6 +606,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.tensordot(x1, x2, dims=axes, **kwargs) + +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + *, _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + Note that outside of this function, this compat library does not yet fully + support complex numbers. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + if isinstance(kind, tuple) and _tuple: + return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype == torch.bool + elif kind == 'signed integer': + return dtype in _int_dtypes and dtype.is_signed + elif kind == 'unsigned integer': + return dtype in _int_dtypes and not dtype.is_signed + elif kind == 'integral': + return dtype in _int_dtypes + elif kind == 'real floating': + return dtype.is_floating_point + elif kind == 'complex floating': + return dtype.is_complex + elif kind == 'numeric': + return isdtype(dtype, ('integral', 'real floating', 'complex floating')) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + return dtype == kind + __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', @@ -612,4 +653,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'matmul', 'matrix_transpose', 'vecdot', 'tensordot'] + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype'] diff --git a/tests/_helpers.py b/tests/_helpers.py new file mode 100644 index 00000000..4066d07a --- /dev/null +++ b/tests/_helpers.py @@ -0,0 +1,8 @@ +from importlib import import_module + +import pytest + +def import_(library): + if 'cupy' in library: + return pytest.importorskip(library) + return import_module(library) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 4b5bb07c..806b1192 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,12 +1,15 @@ import array_api_compat from array_api_compat import array_namespace + +from ._helpers import import_ + import pytest @pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) @pytest.mark.parametrize("api_version", [None, '2021.12']) def test_array_namespace(library, api_version): - lib = pytest.importorskip(library) + lib = import_(library) array = lib.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py new file mode 100644 index 00000000..d164699e --- /dev/null +++ b/tests/test_isdtype.py @@ -0,0 +1,115 @@ +""" +isdtype is not yet tested in the test suite, and it should extend properly to +non-spec dtypes +""" + +from ._helpers import import_ + +import pytest + +# Check the known dtypes by their string names + +def _spec_dtypes(library): + if library == 'torch': + # torch does not have unsigned integer dtypes + return { + 'bool', + 'complex64', + 'complex128', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + 'float32', + 'float64', + } + else: + return { + 'bool', + 'complex64', + 'complex128', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + 'int8', + 'uint16', + 'uint32', + 'uint64', + 'uint8', + } + +dtype_categories = { + 'bool': lambda d: d == 'bool', + 'signed integer': lambda d: d.startswith('int'), + 'unsigned integer': lambda d: d.startswith('uint'), + 'integral': lambda d: dtype_categories['signed integer'](d) or + dtype_categories['unsigned integer'](d), + 'real floating': lambda d: 'float' in d, + 'complex floating': lambda d: d.startswith('complex'), + 'numeric': lambda d: dtype_categories['integral'](d) or + dtype_categories['real floating'](d) or + dtype_categories['complex floating'](d), +} + +def isdtype_(dtype_, kind): + # Check a dtype_ string against kind. Note that 'bool' technically has two + # meanings here but they are both the same. + if kind in dtype_categories: + res = dtype_categories[kind](dtype_) + else: + res = dtype_ == kind + assert type(res) is bool + return res + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +def test_isdtype_spec_dtypes(library): + xp = import_('array_api_compat.' + library) + + isdtype = xp.isdtype + + for dtype_ in _spec_dtypes(library): + for dtype2_ in _spec_dtypes(library): + dtype = getattr(xp, dtype_) + dtype2 = getattr(xp, dtype2_) + res = isdtype_(dtype_, dtype2_) + assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_) + + for cat in dtype_categories: + res = isdtype_(dtype_, cat) + assert isdtype(dtype, cat) == res, (dtype_, cat) + + # Basic tuple testing (the array-api testsuite will be more complete here) + for kind1_ in [*_spec_dtypes(library), *dtype_categories]: + for kind2_ in [*_spec_dtypes(library), *dtype_categories]: + kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_) + kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_) + kind = (kind1, kind2) + + res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_) + assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_)) + +additional_dtypes = [ + 'float16', + 'float128', + 'complex256', + 'bfloat16', +] + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("dtype_", additional_dtypes) +def test_isdtype_additional_dtypes(library, dtype_): + xp = import_('array_api_compat.' + library) + + isdtype = xp.isdtype + + if not hasattr(xp, dtype_): + return + # pytest.skip(f"{library} doesn't have dtype {dtype_}") + + dtype = getattr(xp, dtype_) + for cat in dtype_categories: + res = isdtype_(dtype_, cat) + assert isdtype(dtype, cat) == res, (dtype_, cat)