@@ -47,7 +47,7 @@ def _spec_dtypes(library):
47
47
'unsigned integer' : lambda d : d .startswith ('uint' ),
48
48
'integral' : lambda d : dtype_categories ['signed integer' ](d ) or
49
49
dtype_categories ['unsigned integer' ](d ),
50
- 'real floating' : lambda d : d . startswith ( 'float' ) ,
50
+ 'real floating' : lambda d : 'float' in d ,
51
51
'complex floating' : lambda d : d .startswith ('complex' ),
52
52
'numeric' : lambda d : dtype_categories ['integral' ](d ) or
53
53
dtype_categories ['real floating' ](d ) or
@@ -90,3 +90,25 @@ def test_isdtype_spec_dtypes(library):
90
90
91
91
res = isdtype_ (dtype_ , kind1_ ) or isdtype_ (dtype_ , kind2_ )
92
92
assert isdtype (dtype , kind ) == res , (dtype_ , (kind1_ , kind2_ ))
93
+
94
+ additional_dtypes = [
95
+ 'float16' ,
96
+ 'float128' ,
97
+ 'complex256' ,
98
+ 'bfloat16' ,
99
+ ]
100
+
101
+ @pytest .mark .parametrize ("library" , ["cupy" , "numpy" , "torch" ])
102
+ @pytest .mark .parametrize ("dtype_" , additional_dtypes )
103
+ def test_isdtype_additional_dtypes (library , dtype_ ):
104
+ xp = import_ ('array_api_compat.' + library )
105
+
106
+ isdtype = xp .isdtype
107
+
108
+ if not hasattr (xp , dtype_ ):
109
+ pytest .skip (f"{ library } doesn't have dtype { dtype_ } " )
110
+
111
+ dtype = getattr (xp , dtype_ )
112
+ for cat in dtype_categories :
113
+ res = isdtype_ (dtype_ , cat )
114
+ assert isdtype (dtype , cat ) == res , (dtype_ , cat )
0 commit comments