Skip to content

Commit 7c7c02e

Browse files
committed
Test additional dtypes in test_isdtype
1 parent a375e9d commit 7c7c02e

File tree

1 file changed

+23
-1
lines changed

1 file changed

+23
-1
lines changed

tests/test_isdtype.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def _spec_dtypes(library):
4747
'unsigned integer': lambda d: d.startswith('uint'),
4848
'integral': lambda d: dtype_categories['signed integer'](d) or
4949
dtype_categories['unsigned integer'](d),
50-
'real floating': lambda d: d.startswith('float'),
50+
'real floating': lambda d: 'float' in d,
5151
'complex floating': lambda d: d.startswith('complex'),
5252
'numeric': lambda d: dtype_categories['integral'](d) or
5353
dtype_categories['real floating'](d) or
@@ -90,3 +90,25 @@ def test_isdtype_spec_dtypes(library):
9090

9191
res = isdtype_(dtype_, kind1_) or isdtype_(dtype_, kind2_)
9292
assert isdtype(dtype, kind) == res, (dtype_, (kind1_, kind2_))
93+
94+
additional_dtypes = [
95+
'float16',
96+
'float128',
97+
'complex256',
98+
'bfloat16',
99+
]
100+
101+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
102+
@pytest.mark.parametrize("dtype_", additional_dtypes)
103+
def test_isdtype_additional_dtypes(library, dtype_):
104+
xp = import_('array_api_compat.' + library)
105+
106+
isdtype = xp.isdtype
107+
108+
if not hasattr(xp, dtype_):
109+
pytest.skip(f"{library} doesn't have dtype {dtype_}")
110+
111+
dtype = getattr(xp, dtype_)
112+
for cat in dtype_categories:
113+
res = isdtype_(dtype_, cat)
114+
assert isdtype(dtype, cat) == res, (dtype_, cat)

0 commit comments

Comments
 (0)