From 0bef9825c4035ec306459043d7f0eafa34446bfa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 13 Mar 2023 18:08:26 -0600 Subject: [PATCH 1/6] Add isdtype implementation for numpy/cupy with some initial testing --- array_api_compat/common/_aliases.py | 41 +++++++++++++- array_api_compat/cupy/_aliases.py | 1 + array_api_compat/numpy/_aliases.py | 1 + tests/_helpers.py | 8 +++ tests/test_array_namespace.py | 5 +- tests/test_isdtype.py | 88 +++++++++++++++++++++++++++++ 6 files changed, 142 insertions(+), 2 deletions(-) create mode 100644 tests/_helpers.py create mode 100644 tests/test_isdtype.py diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 8875f2c2..e89be621 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -472,10 +472,49 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray: res = x1_[..., None, :] @ x2_[..., None] return res[..., 0, 0] +# isdtype is a new function in the 2022.12 array API specification. + +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp, + *, _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + if isinstance(kind, tuple) and _tuple: + return any(isdtype(dtype, k, xp, _tuple=False) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype == xp.bool_ + elif kind == 'signed integer': + return xp.issubdtype(dtype, xp.signedinteger) + elif kind == 'unsigned integer': + return xp.issubdtype(dtype, xp.unsignedinteger) + elif kind == 'integral': + return xp.issubdtype(dtype, xp.integer) + elif kind == 'real floating': + return xp.issubdtype(dtype, xp.floating) + elif kind == 'complex floating': + return xp.issubdtype(dtype, xp.complexfloating) + elif kind == 'numeric': + return xp.issubdtype(dtype, xp.number) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + # This will allow things that aren't required by the spec, like + # isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be + # more strict here to match the type annotation? Note that the + # numpy.array_api implementation will be very strict. + return dtype == kind + __all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort', 'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul', - 'matrix_transpose', 'tensordot', 'vecdot'] + 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype'] diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index ce7f3780..b43c371f 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -61,6 +61,7 @@ matrix_transpose = get_xp(cp)(_aliases.matrix_transpose) tensordot = get_xp(cp)(_aliases.tensordot) vecdot = get_xp(cp)(_aliases.vecdot) +isdtype = get_xp(cp)(_aliases.isdtype) __all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', diff --git a/array_api_compat/numpy/_aliases.py b/array_api_compat/numpy/_aliases.py index 2022b842..08f4de0b 100644 --- a/array_api_compat/numpy/_aliases.py +++ b/array_api_compat/numpy/_aliases.py @@ -61,6 +61,7 @@ matrix_transpose = get_xp(np)(_aliases.matrix_transpose) tensordot = get_xp(np)(_aliases.tensordot) vecdot = get_xp(np)(_aliases.vecdot) +isdtype = get_xp(np)(_aliases.isdtype) __all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', diff --git a/tests/_helpers.py b/tests/_helpers.py new file mode 100644 index 00000000..bd0685c8 --- /dev/null +++ b/tests/_helpers.py @@ -0,0 +1,8 @@ +from importlib import import_module + +import pytest + +def import_(library): + if library == 'cupy': + return pytest.importorskip(library) + return import_module(library) diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 4b5bb07c..806b1192 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -1,12 +1,15 @@ import array_api_compat from array_api_compat import array_namespace + +from ._helpers import import_ + import pytest @pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) @pytest.mark.parametrize("api_version", [None, '2021.12']) def test_array_namespace(library, api_version): - lib = pytest.importorskip(library) + lib = import_(library) array = lib.asarray([1.0, 2.0, 3.0]) namespace = array_api_compat.array_namespace(array, api_version=api_version) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py new file mode 100644 index 00000000..058930b2 --- /dev/null +++ b/tests/test_isdtype.py @@ -0,0 +1,88 @@ +""" +isdtype is not yet tested in the test suite, and it should extend properly to +non-spec dtypes +""" + +from ._helpers import import_ + +import pytest + +# Check the known dtypes by their string names + +def _spec_dtypes(library): + if library == 'torch': + # torch does not have unsigned integer dtypes + return { + 'bool', + 'uint8', + 'int8', + 'int16', + 'int32', + 'int64', + 'float32', + 'float64', + } + else: + return { + 'bool', + 'float32', + 'float64', + 'int16', + 'int32', + 'int64', + 'int8', + 'uint16', + 'uint32', + 'uint64', + 'uint8', + } + +dtype_categories = { + 'bool': lambda d: d == 'bool', + 'signed integer': lambda d: d.startswith('int'), + 'unsigned integer': lambda d: d.startswith('uint'), + 'integral': lambda d: dtype_categories['signed integer'](d) or + dtype_categories['unsigned integer'](d), + 'real floating': lambda d: d.startswith('float'), + 'complex floating': lambda d: d.startswith('complex'), + 'numeric': lambda d: dtype_categories['integral'](d) or + dtype_categories['real floating'](d) or + dtype_categories['complex floating'](d), +} + +def isdtype_(dtype_, kind): + # Check a dtype_ string against kind. Note that 'bool' technically has two + # meanings here but they are both the same. + if kind in dtype_categories: + res = dtype_categories[kind](dtype_) + else: + res = dtype_ == kind + assert type(res) is bool + return res + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +def test_isdtype_spec_dtypes(library): + xp = import_('array_api_compat.' + library) + + isdtype = xp.isdtype + + for dtype_ in _spec_dtypes(library): + for dtype2_ in _spec_dtypes(library): + dtype = getattr(xp, dtype_) + dtype2 = getattr(xp, dtype2_) + res = isdtype_(dtype_, dtype2_) + assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_) + + for cat in dtype_categories: + res = isdtype_(dtype_, cat) + assert isdtype(dtype, cat) == res, (dtype_, cat) + + # Basic tuple testing (the array-api testsuite will be more complete here) + for kind1_ in [*_spec_dtypes(library), *dtype_categories]: + for kind2_ in [*_spec_dtypes(library), *dtype_categories]: + kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_) + kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_) + kind = (kind1, kind2) + + res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_) + assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_)) From fa772de34fc906ca650e7e8bdeec2d192c9dc5b5 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 13 Mar 2023 18:25:36 -0600 Subject: [PATCH 2/6] Fix test skipping for cupy --- tests/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/_helpers.py b/tests/_helpers.py index bd0685c8..4066d07a 100644 --- a/tests/_helpers.py +++ b/tests/_helpers.py @@ -3,6 +3,6 @@ import pytest def import_(library): - if library == 'cupy': + if 'cupy' in library: return pytest.importorskip(library) return import_module(library) From a375e9df8c3e6d6b8139c4bbf66f85499e287b4e Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 13 Mar 2023 18:30:08 -0600 Subject: [PATCH 3/6] Test complex dtypes in test_isdtype --- array_api_compat/common/_aliases.py | 3 +++ tests/test_isdtype.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index e89be621..87f0d766 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -481,6 +481,9 @@ def isdtype( """ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + Note that outside of this function, this compat library does not yet fully + support complex numbers. + See https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html for more details diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 058930b2..36fc3941 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -14,6 +14,8 @@ def _spec_dtypes(library): # torch does not have unsigned integer dtypes return { 'bool', + 'complex64', + 'complex128', 'uint8', 'int8', 'int16', @@ -25,6 +27,8 @@ def _spec_dtypes(library): else: return { 'bool', + 'complex64', + 'complex128', 'float32', 'float64', 'int16', From 7c7c02e394505e4fc888a2a615f50b79a79388aa Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 13 Mar 2023 18:30:22 -0600 Subject: [PATCH 4/6] Test additional dtypes in test_isdtype --- tests/test_isdtype.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index 36fc3941..aa937c3f 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -47,7 +47,7 @@ def _spec_dtypes(library): 'unsigned integer': lambda d: d.startswith('uint'), 'integral': lambda d: dtype_categories['signed integer'](d) or dtype_categories['unsigned integer'](d), - 'real floating': lambda d: d.startswith('float'), + 'real floating': lambda d: 'float' in d, 'complex floating': lambda d: d.startswith('complex'), 'numeric': lambda d: dtype_categories['integral'](d) or dtype_categories['real floating'](d) or @@ -90,3 +90,25 @@ def test_isdtype_spec_dtypes(library): res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_) assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_)) + +additional_dtypes = [ + 'float16', + 'float128', + 'complex256', + 'bfloat16', +] + +@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"]) +@pytest.mark.parametrize("dtype_", additional_dtypes) +def test_isdtype_additional_dtypes(library, dtype_): + xp = import_('array_api_compat.' + library) + + isdtype = xp.isdtype + + if not hasattr(xp, dtype_): + pytest.skip(f"{library} doesn't have dtype {dtype_}") + + dtype = getattr(xp, dtype_) + for cat in dtype_categories: + res = isdtype_(dtype_, cat) + assert isdtype(dtype, cat) == res, (dtype_, cat) From fc90e1b6103812d4ee30a52e7f6711ebf3b6b275 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 13 Mar 2023 18:34:59 -0600 Subject: [PATCH 5/6] Add isdtype support for pytorch --- array_api_compat/torch/_aliases.py | 49 +++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index efd86768..c23ba059 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -1,7 +1,7 @@ from __future__ import annotations from functools import wraps -from builtins import all as builtin_all +from builtins import all as builtin_all, any as builtin_any from ..common._aliases import (UniqueAllResult, UniqueCountsResult, UniqueInverseResult, @@ -18,13 +18,17 @@ import torch array = torch.Tensor -_array_api_dtypes = { - torch.bool, +_int_dtypes = { torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, +} + +_array_api_dtypes = { + torch.bool, + *_int_dtypes, torch.float32, torch.float64, } @@ -602,6 +606,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], x1, x2 = _fix_promotion(x1, x2, only_scalar=False) return torch.tensordot(x1, x2, dims=axes, **kwargs) + +def isdtype( + dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], + *, _tuple=True, # Disallow nested tuples +) -> bool: + """ + Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``. + + Note that outside of this function, this compat library does not yet fully + support complex numbers. + + See + https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html + for more details + """ + if isinstance(kind, tuple) and _tuple: + return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind) + elif isinstance(kind, str): + if kind == 'bool': + return dtype == torch.bool + elif kind == 'signed integer': + return dtype in _int_dtypes and dtype.is_signed + elif kind == 'unsigned integer': + return dtype in _int_dtypes and not dtype.is_signed + elif kind == 'integral': + return dtype in _int_dtypes + elif kind == 'real floating': + return dtype.is_floating_point + elif kind == 'complex floating': + return dtype.is_complex + elif kind == 'numeric': + return isdtype(dtype, ('integral', 'real floating', 'complex floating')) + else: + raise ValueError(f"Unrecognized data type kind: {kind!r}") + else: + return dtype == kind + __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal', @@ -612,4 +653,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int], 'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones', 'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', - 'matmul', 'matrix_transpose', 'vecdot', 'tensordot'] + 'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype'] From a8c92e50e075a09cf02b976de216bffa8a61c19f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Mon, 13 Mar 2023 18:47:20 -0600 Subject: [PATCH 6/6] Return instead of skipping (so there are no skips in the test output) --- tests/test_isdtype.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_isdtype.py b/tests/test_isdtype.py index aa937c3f..d164699e 100644 --- a/tests/test_isdtype.py +++ b/tests/test_isdtype.py @@ -106,7 +106,8 @@ def test_isdtype_additional_dtypes(library, dtype_): isdtype = xp.isdtype if not hasattr(xp, dtype_): - pytest.skip(f"{library} doesn't have dtype {dtype_}") + return + # pytest.skip(f"{library} doesn't have dtype {dtype_}") dtype = getattr(xp, dtype_) for cat in dtype_categories: