Skip to content

Commit 0bef982

Browse files
committed
Add isdtype implementation for numpy/cupy with some initial testing
1 parent 58cd3b9 commit 0bef982

File tree

6 files changed

+142
-2
lines changed

6 files changed

+142
-2
lines changed

array_api_compat/common/_aliases.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,49 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
472472
res = x1_[..., None, :] @ x2_[..., None]
473473
return res[..., 0, 0]
474474

475+
# isdtype is a new function in the 2022.12 array API specification.
476+
477+
def isdtype(
478+
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
479+
*, _tuple=True, # Disallow nested tuples
480+
) -> bool:
481+
"""
482+
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
483+
484+
See
485+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
486+
for more details
487+
"""
488+
if isinstance(kind, tuple) and _tuple:
489+
return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
490+
elif isinstance(kind, str):
491+
if kind == 'bool':
492+
return dtype == xp.bool_
493+
elif kind == 'signed integer':
494+
return xp.issubdtype(dtype, xp.signedinteger)
495+
elif kind == 'unsigned integer':
496+
return xp.issubdtype(dtype, xp.unsignedinteger)
497+
elif kind == 'integral':
498+
return xp.issubdtype(dtype, xp.integer)
499+
elif kind == 'real floating':
500+
return xp.issubdtype(dtype, xp.floating)
501+
elif kind == 'complex floating':
502+
return xp.issubdtype(dtype, xp.complexfloating)
503+
elif kind == 'numeric':
504+
return xp.issubdtype(dtype, xp.number)
505+
else:
506+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
507+
else:
508+
# This will allow things that aren't required by the spec, like
509+
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
510+
# more strict here to match the type annotation? Note that the
511+
# numpy.array_api implementation will be very strict.
512+
return dtype == kind
513+
475514
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
476515
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
477516
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
478517
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
479518
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
480519
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
481-
'matrix_transpose', 'tensordot', 'vecdot']
520+
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6262
tensordot = get_xp(cp)(_aliases.tensordot)
6363
vecdot = get_xp(cp)(_aliases.vecdot)
64+
isdtype = get_xp(cp)(_aliases.isdtype)
6465

6566
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
6667
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6262
tensordot = get_xp(np)(_aliases.tensordot)
6363
vecdot = get_xp(np)(_aliases.vecdot)
64+
isdtype = get_xp(np)(_aliases.isdtype)
6465

6566
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
6667
'acosh', 'asin', 'asinh', 'atan', 'atan2',

tests/_helpers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from importlib import import_module
2+
3+
import pytest
4+
5+
def import_(library):
6+
if library == 'cupy':
7+
return pytest.importorskip(library)
8+
return import_module(library)

tests/test_array_namespace.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import array_api_compat
22
from array_api_compat import array_namespace
3+
4+
from ._helpers import import_
5+
36
import pytest
47

58

69
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
710
@pytest.mark.parametrize("api_version", [None, '2021.12'])
811
def test_array_namespace(library, api_version):
9-
lib = pytest.importorskip(library)
12+
lib = import_(library)
1013

1114
array = lib.asarray([1.0, 2.0, 3.0])
1215
namespace = array_api_compat.array_namespace(array, api_version=api_version)

tests/test_isdtype.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
"""
2+
isdtype is not yet tested in the test suite, and it should extend properly to
3+
non-spec dtypes
4+
"""
5+
6+
from ._helpers import import_
7+
8+
import pytest
9+
10+
# Check the known dtypes by their string names
11+
12+
def _spec_dtypes(library):
13+
if library == 'torch':
14+
# torch does not have unsigned integer dtypes
15+
return {
16+
'bool',
17+
'uint8',
18+
'int8',
19+
'int16',
20+
'int32',
21+
'int64',
22+
'float32',
23+
'float64',
24+
}
25+
else:
26+
return {
27+
'bool',
28+
'float32',
29+
'float64',
30+
'int16',
31+
'int32',
32+
'int64',
33+
'int8',
34+
'uint16',
35+
'uint32',
36+
'uint64',
37+
'uint8',
38+
}
39+
40+
dtype_categories = {
41+
'bool': lambda d: d == 'bool',
42+
'signed integer': lambda d: d.startswith('int'),
43+
'unsigned integer': lambda d: d.startswith('uint'),
44+
'integral': lambda d: dtype_categories['signed integer'](d) or
45+
dtype_categories['unsigned integer'](d),
46+
'real floating': lambda d: d.startswith('float'),
47+
'complex floating': lambda d: d.startswith('complex'),
48+
'numeric': lambda d: dtype_categories['integral'](d) or
49+
dtype_categories['real floating'](d) or
50+
dtype_categories['complex floating'](d),
51+
}
52+
53+
def isdtype_(dtype_, kind):
54+
# Check a dtype_ string against kind. Note that 'bool' technically has two
55+
# meanings here but they are both the same.
56+
if kind in dtype_categories:
57+
res = dtype_categories[kind](dtype_)
58+
else:
59+
res = dtype_ == kind
60+
assert type(res) is bool
61+
return res
62+
63+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
64+
def test_isdtype_spec_dtypes(library):
65+
xp = import_('array_api_compat.' + library)
66+
67+
isdtype = xp.isdtype
68+
69+
for dtype_ in _spec_dtypes(library):
70+
for dtype2_ in _spec_dtypes(library):
71+
dtype = getattr(xp, dtype_)
72+
dtype2 = getattr(xp, dtype2_)
73+
res = isdtype_(dtype_, dtype2_)
74+
assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_)
75+
76+
for cat in dtype_categories:
77+
res = isdtype_(dtype_, cat)
78+
assert isdtype(dtype, cat) == res, (dtype_, cat)
79+
80+
# Basic tuple testing (the array-api testsuite will be more complete here)
81+
for kind1_ in [*_spec_dtypes(library), *dtype_categories]:
82+
for kind2_ in [*_spec_dtypes(library), *dtype_categories]:
83+
kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_)
84+
kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_)
85+
kind = (kind1, kind2)
86+
87+
res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_)
88+
assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_))

0 commit comments

Comments
 (0)