Skip to content

Implement isdtype() #32

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,10 +472,52 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
res = x1_[..., None, :] @ x2_[..., None]
return res[..., 0, 0]

# isdtype is a new function in the 2022.12 array API specification.

def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
*, _tuple=True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

Note that outside of this function, this compat library does not yet fully
support complex numbers.

See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype == xp.bool_
elif kind == 'signed integer':
return xp.issubdtype(dtype, xp.signedinteger)
elif kind == 'unsigned integer':
return xp.issubdtype(dtype, xp.unsignedinteger)
elif kind == 'integral':
return xp.issubdtype(dtype, xp.integer)
elif kind == 'real floating':
return xp.issubdtype(dtype, xp.floating)
elif kind == 'complex floating':
return xp.issubdtype(dtype, xp.complexfloating)
elif kind == 'numeric':
return xp.issubdtype(dtype, xp.number)
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
# This will allow things that aren't required by the spec, like
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
# more strict here to match the type annotation? Note that the
# numpy.array_api implementation will be very strict.
return dtype == kind

__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
'matrix_transpose', 'tensordot', 'vecdot']
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']
1 change: 1 addition & 0 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
tensordot = get_xp(cp)(_aliases.tensordot)
vecdot = get_xp(cp)(_aliases.vecdot)
isdtype = get_xp(cp)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
Expand Down
1 change: 1 addition & 0 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
tensordot = get_xp(np)(_aliases.tensordot)
vecdot = get_xp(np)(_aliases.vecdot)
isdtype = get_xp(np)(_aliases.isdtype)

__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
'acosh', 'asin', 'asinh', 'atan', 'atan2',
Expand Down
49 changes: 45 additions & 4 deletions array_api_compat/torch/_aliases.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import wraps
from builtins import all as builtin_all
from builtins import all as builtin_all, any as builtin_any

from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
UniqueInverseResult,
Expand All @@ -18,13 +18,17 @@
import torch
array = torch.Tensor

_array_api_dtypes = {
torch.bool,
_int_dtypes = {
torch.uint8,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
}

_array_api_dtypes = {
torch.bool,
*_int_dtypes,
torch.float32,
torch.float64,
}
Expand Down Expand Up @@ -602,6 +606,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
return torch.tensordot(x1, x2, dims=axes, **kwargs)


def isdtype(
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
*, _tuple=True, # Disallow nested tuples
) -> bool:
"""
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.

Note that outside of this function, this compat library does not yet fully
support complex numbers.

See
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
for more details
"""
if isinstance(kind, tuple) and _tuple:
return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
elif isinstance(kind, str):
if kind == 'bool':
return dtype == torch.bool
elif kind == 'signed integer':
return dtype in _int_dtypes and dtype.is_signed
Copy link
Member Author

@asmeurer asmeurer Mar 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to test if a torch dtype is integral?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There isn't. Based on dtype properties it would be:

not dtype.is_floating_point and not dtype.is_complex and dtype.is_signed and dtype != dtype.bool

which looks worse and may also not be foolproof. The check you have will need explicit updating if for example uint16 is added to PyTorch, but that's okay (nothing like that is in the pipeline AFAIK).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully this function itself would be added to pytorch by then.

elif kind == 'unsigned integer':
return dtype in _int_dtypes and not dtype.is_signed
elif kind == 'integral':
return dtype in _int_dtypes
elif kind == 'real floating':
return dtype.is_floating_point
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the correct way to check for torch dtype categories?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be good, yes.

elif kind == 'complex floating':
return dtype.is_complex
elif kind == 'numeric':
return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
else:
raise ValueError(f"Unrecognized data type kind: {kind!r}")
else:
return dtype == kind

__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
Expand All @@ -612,4 +653,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype']
8 changes: 8 additions & 0 deletions tests/_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from importlib import import_module

import pytest

def import_(library):
if 'cupy' in library:
return pytest.importorskip(library)
return import_module(library)
5 changes: 4 additions & 1 deletion tests/test_array_namespace.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import array_api_compat
from array_api_compat import array_namespace

from ._helpers import import_

import pytest


@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("api_version", [None, '2021.12'])
def test_array_namespace(library, api_version):
lib = pytest.importorskip(library)
lib = import_(library)

array = lib.asarray([1.0, 2.0, 3.0])
namespace = array_api_compat.array_namespace(array, api_version=api_version)
Expand Down
115 changes: 115 additions & 0 deletions tests/test_isdtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""
isdtype is not yet tested in the test suite, and it should extend properly to
non-spec dtypes
"""

from ._helpers import import_

import pytest

# Check the known dtypes by their string names

def _spec_dtypes(library):
if library == 'torch':
# torch does not have unsigned integer dtypes
return {
'bool',
'complex64',
'complex128',
'uint8',
'int8',
'int16',
'int32',
'int64',
'float32',
'float64',
}
else:
return {
'bool',
'complex64',
'complex128',
'float32',
'float64',
'int16',
'int32',
'int64',
'int8',
'uint16',
'uint32',
'uint64',
'uint8',
}

dtype_categories = {
'bool': lambda d: d == 'bool',
'signed integer': lambda d: d.startswith('int'),
'unsigned integer': lambda d: d.startswith('uint'),
'integral': lambda d: dtype_categories['signed integer'](d) or
dtype_categories['unsigned integer'](d),
'real floating': lambda d: 'float' in d,
'complex floating': lambda d: d.startswith('complex'),
'numeric': lambda d: dtype_categories['integral'](d) or
dtype_categories['real floating'](d) or
dtype_categories['complex floating'](d),
}

def isdtype_(dtype_, kind):
# Check a dtype_ string against kind. Note that 'bool' technically has two
# meanings here but they are both the same.
if kind in dtype_categories:
res = dtype_categories[kind](dtype_)
else:
res = dtype_ == kind
assert type(res) is bool
return res

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
def test_isdtype_spec_dtypes(library):
xp = import_('array_api_compat.' + library)

isdtype = xp.isdtype

for dtype_ in _spec_dtypes(library):
for dtype2_ in _spec_dtypes(library):
dtype = getattr(xp, dtype_)
dtype2 = getattr(xp, dtype2_)
res = isdtype_(dtype_, dtype2_)
assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_)

for cat in dtype_categories:
res = isdtype_(dtype_, cat)
assert isdtype(dtype, cat) == res, (dtype_, cat)

# Basic tuple testing (the array-api testsuite will be more complete here)
for kind1_ in [*_spec_dtypes(library), *dtype_categories]:
for kind2_ in [*_spec_dtypes(library), *dtype_categories]:
kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_)
kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_)
kind = (kind1, kind2)

res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_)
assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_))

additional_dtypes = [
'float16',
'float128',
'complex256',
'bfloat16',
]

@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
@pytest.mark.parametrize("dtype_", additional_dtypes)
def test_isdtype_additional_dtypes(library, dtype_):
xp = import_('array_api_compat.' + library)

isdtype = xp.isdtype

if not hasattr(xp, dtype_):
return
# pytest.skip(f"{library} doesn't have dtype {dtype_}")

dtype = getattr(xp, dtype_)
for cat in dtype_categories:
res = isdtype_(dtype_, cat)
assert isdtype(dtype, cat) == res, (dtype_, cat)