1
1
from __future__ import annotations
2
2
3
3
from functools import wraps
4
- from builtins import all as builtin_all
4
+ from builtins import all as builtin_all , any as builtin_any
5
5
6
6
from ..common ._aliases import (UniqueAllResult , UniqueCountsResult ,
7
7
UniqueInverseResult ,
18
18
import torch
19
19
array = torch .Tensor
20
20
21
- _array_api_dtypes = {
22
- torch .bool ,
21
+ _int_dtypes = {
23
22
torch .uint8 ,
24
23
torch .int8 ,
25
24
torch .int16 ,
26
25
torch .int32 ,
27
26
torch .int64 ,
27
+ }
28
+
29
+ _array_api_dtypes = {
30
+ torch .bool ,
31
+ * _int_dtypes ,
28
32
torch .float32 ,
29
33
torch .float64 ,
30
34
}
@@ -602,6 +606,43 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
602
606
x1 , x2 = _fix_promotion (x1 , x2 , only_scalar = False )
603
607
return torch .tensordot (x1 , x2 , dims = axes , ** kwargs )
604
608
609
+
610
+ def isdtype (
611
+ dtype : Dtype , kind : Union [Dtype , str , Tuple [Union [Dtype , str ], ...]],
612
+ * , _tuple = True , # Disallow nested tuples
613
+ ) -> bool :
614
+ """
615
+ Returns a boolean indicating whether a provided dtype is of a specified data type ``kind``.
616
+
617
+ Note that outside of this function, this compat library does not yet fully
618
+ support complex numbers.
619
+
620
+ See
621
+ https://data-apis.org/array-api/latest/API_specification/generated/array_api.isdtype.html
622
+ for more details
623
+ """
624
+ if isinstance (kind , tuple ) and _tuple :
625
+ return builtin_any (isdtype (dtype , k , _tuple = False ) for k in kind )
626
+ elif isinstance (kind , str ):
627
+ if kind == 'bool' :
628
+ return dtype == torch .bool
629
+ elif kind == 'signed integer' :
630
+ return dtype in _int_dtypes and dtype .is_signed
631
+ elif kind == 'unsigned integer' :
632
+ return dtype in _int_dtypes and not dtype .is_signed
633
+ elif kind == 'integral' :
634
+ return dtype in _int_dtypes
635
+ elif kind == 'real floating' :
636
+ return dtype .is_floating_point
637
+ elif kind == 'complex floating' :
638
+ return dtype .is_complex
639
+ elif kind == 'numeric' :
640
+ return isdtype (dtype , ('integral' , 'real floating' , 'complex floating' ))
641
+ else :
642
+ raise ValueError (f"Unrecognized data type kind: { kind !r} " )
643
+ else :
644
+ return dtype == kind
645
+
605
646
__all__ = ['result_type' , 'can_cast' , 'permute_dims' , 'bitwise_invert' , 'add' ,
606
647
'atan2' , 'bitwise_and' , 'bitwise_left_shift' , 'bitwise_or' ,
607
648
'bitwise_right_shift' , 'bitwise_xor' , 'divide' , 'equal' ,
@@ -612,4 +653,4 @@ def tensordot(x1: array, x2: array, /, *, axes: Union[int, Tuple[Sequence[int],
612
653
'nonzero' , 'where' , 'arange' , 'eye' , 'linspace' , 'full' , 'ones' ,
613
654
'zeros' , 'empty' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
614
655
'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
615
- 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' ]
656
+ 'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ]
0 commit comments