Skip to content

Commit fc90e1b

Browse files
committed
Add isdtype support for pytorch
1 parent 7c7c02e commit fc90e1b

File tree

1 file changed

+45
-4
lines changed

1 file changed

+45
-4
lines changed

array_api_compat/torch/_aliases.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from functools import wraps
4-
from builtins import all as builtin_all
4+
from builtins import all as builtin_all, any as builtin_any
55

66
from ..common._aliases import (UniqueAllResult, UniqueCountsResult,
77
UniqueInverseResult,
@@ -18,13 +18,17 @@
1818
import torch
1919
array = torch.Tensor
2020

21-
_array_api_dtypes = {
22-
torch.bool,
21+
_int_dtypes = {
2322
torch.uint8,
2423
torch.int8,
2524
torch.int16,
2625
torch.int32,
2726
torch.int64,
27+
}
28+
29+
_array_api_dtypes = {
30+
torch.bool,
31+
*_int_dtypes,
2832
torch.float32,
2933
torch.float64,
3034
}
@@ -602,6 +606,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
602606
x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
603607
return torch.tensordot(x1, x2, dims=axes, **kwargs)
604608

609+
610+
def isdtype(
611+
dtype: Dtype, kind: Union[Dtype, str, Tuple[Union[Dtype, str], ...]],
612+
*, _tuple=True, # Disallow nested tuples
613+
) -> bool:
614+
"""
615+
Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
616+
617+
Note that outside of this function, this compat library does not yet fully
618+
support complex numbers.
619+
620+
See
621+
https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
622+
for more details
623+
"""
624+
if isinstance(kind, tuple) and _tuple:
625+
return builtin_any(isdtype(dtype, k, _tuple=False) for k in kind)
626+
elif isinstance(kind, str):
627+
if kind == 'bool':
628+
return dtype == torch.bool
629+
elif kind == 'signed integer':
630+
return dtype in _int_dtypes and dtype.is_signed
631+
elif kind == 'unsigned integer':
632+
return dtype in _int_dtypes and not dtype.is_signed
633+
elif kind == 'integral':
634+
return dtype in _int_dtypes
635+
elif kind == 'real floating':
636+
return dtype.is_floating_point
637+
elif kind == 'complex floating':
638+
return dtype.is_complex
639+
elif kind == 'numeric':
640+
return isdtype(dtype, ('integral', 'real floating', 'complex floating'))
641+
else:
642+
raise ValueError(f"Unrecognized data type kind: {kind!r}")
643+
else:
644+
return dtype == kind
645+
605646
__all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'add',
606647
'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or',
607648
'bitwise_right_shift', 'bitwise_xor', 'divide', 'equal',
@@ -612,4 +653,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
612653
'nonzero', 'where', 'arange', 'eye', 'linspace', 'full', 'ones',
613654
'zeros', 'empty', 'expand_dims', 'astype', 'broadcast_arrays',
614655
'unique_all', 'unique_counts', 'unique_inverse', 'unique_values',
615-
'matmul', 'matrix_transpose', 'vecdot', 'tensordot']
656+
'matmul', 'matrix_transpose', 'vecdot', 'tensordot', 'isdtype']

0 commit comments

Comments
 (0)