Skip to content

Commit eb66b69

Browse files
authored
Merge branch 'main' into torch-linalg
2 parents 148a892 + e6007fa commit eb66b69

File tree

7 files changed

+217
-6
lines changed

7 files changed

+217
-6
lines changed

array_api_compat/common/_aliases.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,52 @@ def vecdot(x1: ndarray, x2: ndarray, /, xp, *, axis: int = -1) -> ndarray:
472472
res = x1_[..., None, :] @ x2_[..., None]
473473
return res[..., 0, 0]
474474

475+
# isdtype is a new function in the 2022.12 array API specification.
476+
477+
def isdtype(
478+
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]], xp,
479+
*, _tuple=True, # Disallow nested tuples
480+
) -> bool:
481+
"""
482+
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
483+
484+
Note that outside of this function, this compat library does not yet fully
485+
support complex numbers.
486+
487+
See
488+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
489+
for more details
490+
"""
491+
if isinstance(kind, tuple) and _tuple:
492+
return any(isdtype(dtype, k, xp, _tuple=False) for k in kind)
493+
elif isinstance(kind, str):
494+
if kind == 'bool':
495+
return dtype == xp.bool_
496+
elif kind == 'signed integer':
497+
return xp.issubdtype(dtype, xp.signedinteger)
498+
elif kind == 'unsigned integer':
499+
return xp.issubdtype(dtype, xp.unsignedinteger)
500+
elif kind == 'integral':
501+
return xp.issubdtype(dtype, xp.integer)
502+
elif kind == 'real floating':
503+
return xp.issubdtype(dtype, xp.floating)
504+
elif kind == 'complex floating':
505+
return xp.issubdtype(dtype, xp.complexfloating)
506+
elif kind == 'numeric':
507+
return xp.issubdtype(dtype, xp.number)
508+
else:
509+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
510+
else:
511+
# This will allow things that aren't required by the spec, like
512+
# isdtype(np.float64, float) or isdtype(np.int64, 'l'). Should we be
513+
# more strict here to match the type annotation? Note that the
514+
# numpy.array_api implementation will be very strict.
515+
return dtype == kind
516+
475517
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like',
476518
'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like',
477519
'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult',
478520
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
479521
'astype', 'std', 'var', 'permute_dims', 'reshape', 'argsort',
480522
'sort', 'sum', 'prod', 'ceil', 'floor', 'trunc', 'matmul',
481-
'matrix_transpose', 'tensordot', 'vecdot']
523+
'matrix_transpose', 'tensordot', 'vecdot', 'isdtype']

array_api_compat/cupy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
matrix_transpose = get_xp(cp)(_aliases.matrix_transpose)
6262
tensordot = get_xp(cp)(_aliases.tensordot)
6363
vecdot = get_xp(cp)(_aliases.vecdot)
64+
isdtype = get_xp(cp)(_aliases.isdtype)
6465

6566
__all__ = _aliases.__all__ + ['asarray', 'asarray_cupy', 'bool', 'acos',
6667
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/numpy/_aliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
matrix_transpose = get_xp(np)(_aliases.matrix_transpose)
6262
tensordot = get_xp(np)(_aliases.tensordot)
6363
vecdot = get_xp(np)(_aliases.vecdot)
64+
isdtype = get_xp(np)(_aliases.isdtype)
6465

6566
__all__ = _aliases.__all__ + ['asarray', 'asarray_numpy', 'bool', 'acos',
6667
'acosh', 'asin', 'asinh', 'atan', 'atan2',

array_api_compat/torch/_aliases.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import wraps
4-
from builtins import all as builtin_all
4+
from builtins import all as builtin_all, any as builtin_any
55

66
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
77
UniqueInverseResult,
@@ -19,13 +19,17 @@
1919

2020
array = torch.Tensor
2121

22-
_array_api_dtypes = {
23-
torch.bool,
22+
_int_dtypes = {
2423
torch.uint8,
2524
torch.int8,
2625
torch.int16,
2726
torch.int32,
2827
torch.int64,
28+
}
29+
30+
_array_api_dtypes = {
31+
torch.bool,
32+
*_int_dtypes,
2933
torch.float32,
3034
torch.float64,
3135
}
@@ -611,6 +615,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
611615
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
612616
return torch.tensordot(x1, x2, dims=axes, **kwargs)
613617

618+
619+
def isdtype(
620+
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
621+
*, _tuple=True, # Disallow nested tuples
622+
) -> bool:
623+
"""
624+
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
625+
626+
Note that outside of this function, this compat library does not yet fully
627+
support complex numbers.
628+
629+
See
630+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
631+
for more details
632+
"""
633+
if isinstance(kind, tuple) and _tuple:
634+
return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
635+
elif isinstance(kind, str):
636+
if kind == 'bool':
637+
return dtype == torch.bool
638+
elif kind == 'signed integer':
639+
return dtype in _int_dtypes and dtype.is_signed
640+
elif kind == 'unsigned integer':
641+
return dtype in _int_dtypes and not dtype.is_signed
642+
elif kind == 'integral':
643+
return dtype in _int_dtypes
644+
elif kind == 'real floating':
645+
return dtype.is_floating_point
646+
elif kind == 'complex floating':
647+
return dtype.is_complex
648+
elif kind == 'numeric':
649+
return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
650+
else:
651+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
652+
else:
653+
return dtype == kind
654+
614655
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
615656
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
616657
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
@@ -622,4 +663,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
622663
'zeros', 'empty', 'tril', 'triu', 'expand_dims', 'astype',
623664
'broadcast_arrays', 'unique_all', 'unique_counts',
624665
'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose',
625-
'vecdot', 'tensordot']
666+
'vecdot', 'tensordot', 'isdtype']

tests/_helpers.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from importlib import import_module
2+
3+
import pytest
4+
5+
def import_(library):
6+
if 'cupy' in library:
7+
return pytest.importorskip(library)
8+
return import_module(library)

tests/test_array_namespace.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import array_api_compat
22
from array_api_compat import array_namespace
3+
4+
from ._helpers import import_
5+
36
import pytest
47

58

69
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
710
@pytest.mark.parametrize("api_version", [None, '2021.12'])
811
def test_array_namespace(library, api_version):
9-
lib = pytest.importorskip(library)
12+
lib = import_(library)
1013

1114
array = lib.asarray([1.0, 2.0, 3.0])
1215
namespace = array_api_compat.array_namespace(array, api_version=api_version)

tests/test_isdtype.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
"""
2+
isdtype is not yet tested in the test suite, and it should extend properly to
3+
non-spec dtypes
4+
"""
5+
6+
from ._helpers import import_
7+
8+
import pytest
9+
10+
# Check the known dtypes by their string names
11+
12+
def _spec_dtypes(library):
13+
if library == 'torch':
14+
# torch does not have unsigned integer dtypes
15+
return {
16+
'bool',
17+
'complex64',
18+
'complex128',
19+
'uint8',
20+
'int8',
21+
'int16',
22+
'int32',
23+
'int64',
24+
'float32',
25+
'float64',
26+
}
27+
else:
28+
return {
29+
'bool',
30+
'complex64',
31+
'complex128',
32+
'float32',
33+
'float64',
34+
'int16',
35+
'int32',
36+
'int64',
37+
'int8',
38+
'uint16',
39+
'uint32',
40+
'uint64',
41+
'uint8',
42+
}
43+
44+
dtype_categories = {
45+
'bool': lambda d: d == 'bool',
46+
'signed integer': lambda d: d.startswith('int'),
47+
'unsigned integer': lambda d: d.startswith('uint'),
48+
'integral': lambda d: dtype_categories['signed integer'](d) or
49+
dtype_categories['unsigned integer'](d),
50+
'real floating': lambda d: 'float' in d,
51+
'complex floating': lambda d: d.startswith('complex'),
52+
'numeric': lambda d: dtype_categories['integral'](d) or
53+
dtype_categories['real floating'](d) or
54+
dtype_categories['complex floating'](d),
55+
}
56+
57+
def isdtype_(dtype_, kind):
58+
# Check a dtype_ string against kind. Note that 'bool' technically has two
59+
# meanings here but they are both the same.
60+
if kind in dtype_categories:
61+
res = dtype_categories[kind](dtype_)
62+
else:
63+
res = dtype_ == kind
64+
assert type(res) is bool
65+
return res
66+
67+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
68+
def test_isdtype_spec_dtypes(library):
69+
xp = import_('array_api_compat.' + library)
70+
71+
isdtype = xp.isdtype
72+
73+
for dtype_ in _spec_dtypes(library):
74+
for dtype2_ in _spec_dtypes(library):
75+
dtype = getattr(xp, dtype_)
76+
dtype2 = getattr(xp, dtype2_)
77+
res = isdtype_(dtype_, dtype2_)
78+
assert isdtype(dtype, dtype2) is res, (dtype_, dtype2_)
79+
80+
for cat in dtype_categories:
81+
res = isdtype_(dtype_, cat)
82+
assert isdtype(dtype, cat) == res, (dtype_, cat)
83+
84+
# Basic tuple testing (the array-api testsuite will be more complete here)
85+
for kind1_ in [*_spec_dtypes(library), *dtype_categories]:
86+
for kind2_ in [*_spec_dtypes(library), *dtype_categories]:
87+
kind1 = kind1_ if kind1_ in dtype_categories else getattr(xp, kind1_)
88+
kind2 = kind2_ if kind2_ in dtype_categories else getattr(xp, kind2_)
89+
kind = (kind1, kind2)
90+
91+
res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_)
92+
assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_))
93+
94+
additional_dtypes = [
95+
'float16',
96+
'float128',
97+
'complex256',
98+
'bfloat16',
99+
]
100+
101+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
102+
@pytest.mark.parametrize("dtype_", additional_dtypes)
103+
def test_isdtype_additional_dtypes(library, dtype_):
104+
xp = import_('array_api_compat.' + library)
105+
106+
isdtype = xp.isdtype
107+
108+
if not hasattr(xp, dtype_):
109+
return
110+
# pytest.skip(f"{library} doesn't have dtype {dtype_}")
111+
112+
dtype = getattr(xp, dtype_)
113+
for cat in dtype_categories:
114+
res = isdtype_(dtype_, cat)
115+
assert isdtype(dtype, cat) == res, (dtype_, cat)

0 commit comments

Comments
 (0)