Skip to content

Commit 7bc3d80

Browse files
authored
Merge pull request #1979 from IntelPython/fix-device-keyword-in-array-api-inspection
Fix array API inspection behavior with `device` keyword
2 parents c5cbb08 + 06f266c commit 7bc3d80

File tree

2 files changed

+86
-32
lines changed

2 files changed

+86
-32
lines changed

dpctl/tensor/_array_api.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,21 @@ def _isdtype_impl(dtype, kind):
4646
elif isinstance(kind, tuple):
4747
return any(_isdtype_impl(dtype, k) for k in kind)
4848
else:
49-
raise TypeError(f"Unsupported data type kind: {kind}")
49+
raise TypeError(f"Unsupported type for dtype kind: {type(kind)}")
50+
51+
52+
def _get_device_impl(d):
53+
if d is None:
54+
return dpctl.select_default_device()
55+
elif isinstance(d, dpctl.SyclDevice):
56+
return d
57+
elif isinstance(d, (dpt.Device, dpctl.SyclQueue)):
58+
return d.sycl_device
59+
else:
60+
try:
61+
return dpctl.SyclDevice(d)
62+
except TypeError:
63+
raise TypeError(f"Unsupported type for device argument: {type(d)}")
5064

5165

5266
__array_api_version__ = "2023.12"
@@ -117,13 +131,13 @@ def default_dtypes(self, *, device=None):
117131
Returns a dictionary of default data types for ``device``.
118132
119133
Args:
120-
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`]):
134+
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
121135
array API concept of device used in getting default data types.
122136
``device`` can be ``None`` (in which case the default device
123-
is used), an instance of :class:`dpctl.SyclDevice` corresponding
124-
to a non-partitioned SYCL device, an instance of
125-
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device`
126-
object returned by :attr:`dpctl.tensor.usm_ndarray.device`.
137+
is used), an instance of :class:`dpctl.SyclDevice`, an instance
138+
of :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
139+
object returned by :attr:`dpctl.tensor.usm_ndarray.device`, or
140+
a filter selector string.
127141
Default: ``None``.
128142
129143
Returns:
@@ -135,10 +149,7 @@ def default_dtypes(self, *, device=None):
135149
- ``"integral"``: dtype
136150
- ``"indexing"``: dtype
137151
"""
138-
if device is None:
139-
device = dpctl.select_default_device()
140-
elif isinstance(device, dpt.Device):
141-
device = device.sycl_device
152+
device = _get_device_impl(device)
142153
return {
143154
"real floating": dpt.dtype(default_device_fp_type(device)),
144155
"complex floating": dpt.dtype(default_device_complex_type(device)),
@@ -161,10 +172,10 @@ def dtypes(self, *, device=None, kind=None):
161172
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
162173
array API concept of device used in getting default data types.
163174
``device`` can be ``None`` (in which case the default device is
164-
used), an instance of :class:`dpctl.SyclDevice` corresponding
165-
to a non-partitioned SYCL device, an instance of
166-
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device`
167-
object returned by :attr:`dpctl.tensor.usm_ndarray.device`.
175+
used), an instance of :class:`dpctl.SyclDevice`, an instance of
176+
:class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
177+
object returned by :attr:`dpctl.tensor.usm_ndarray.device`, or
178+
a filter selector string.
168179
Default: ``None``.
169180
170181
kind (Optional[str, Tuple[str, ...]]):
@@ -196,22 +207,20 @@ def dtypes(self, *, device=None, kind=None):
196207
a dictionary of the supported data types of the specified
197208
``kind``
198209
"""
199-
if device is None:
200-
device = dpctl.select_default_device()
201-
elif isinstance(device, dpt.Device):
202-
device = device.sycl_device
210+
device = _get_device_impl(device)
203211
_fp64 = device.has_aspect_fp64
204212
if kind is None:
205213
return {
206214
key: val
207215
for key, val in self._all_dtypes.items()
208-
if (key != "float64" or _fp64)
216+
if _fp64 or (key != "float64" and key != "complex128")
209217
}
210218
else:
211219
return {
212220
key: val
213221
for key, val in self._all_dtypes.items()
214-
if (key != "float64" or _fp64) and _isdtype_impl(val, kind)
222+
if (_fp64 or (key != "float64" and key != "complex128"))
223+
and _isdtype_impl(val, kind)
215224
}
216225

217226
def devices(self):

dpctl/tests/test_tensor_array_api_inspection.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
"bool": dpt.bool,
3030
"float32": dpt.float32,
3131
"complex64": dpt.complex64,
32-
"complex128": dpt.complex128,
3332
"int8": dpt.int8,
3433
"int16": dpt.int16,
3534
"int32": dpt.int32,
@@ -41,12 +40,6 @@
4140
}
4241

4342

44-
class MockDevice:
45-
def __init__(self, fp16: bool, fp64: bool):
46-
self.has_aspect_fp16 = fp16
47-
self.has_aspect_fp64 = fp64
48-
49-
5043
def test_array_api_inspection_methods():
5144
info = dpt.__array_namespace_info__()
5245
assert info.capabilities()
@@ -125,17 +118,21 @@ def test_array_api_inspection_default_device_dtypes():
125118
dtypes = _dtypes_no_fp16_fp64.copy()
126119
if dev.has_aspect_fp64:
127120
dtypes["float64"] = dpt.float64
121+
dtypes["complex128"] = dpt.complex128
128122

129123
assert dtypes == dpt.__array_namespace_info__().dtypes()
130124

131125

132-
@pytest.mark.parametrize("fp16", [True, False])
133-
@pytest.mark.parametrize("fp64", [True, False])
134-
def test_array_api_inspection_device_dtypes(fp16, fp64):
135-
dev = MockDevice(fp16, fp64)
126+
def test_array_api_inspection_device_dtypes():
127+
info = dpt.__array_namespace_info__()
128+
try:
129+
dev = info.default_device()
130+
except dpctl.SyclDeviceCreationError:
131+
pytest.skip("No default device available")
136132
dtypes = _dtypes_no_fp16_fp64.copy()
137-
if fp64:
133+
if dev.has_aspect_fp64:
138134
dtypes["float64"] = dpt.float64
135+
dtypes["complex128"] = dpt.complex128
139136

140137
assert dtypes == dpt.__array_namespace_info__().dtypes(device=dev)
141138

@@ -179,3 +176,51 @@ def test_array_api_inspection_dtype_kind():
179176
)
180177
== info.dtypes()
181178
)
179+
assert info.dtypes(
180+
kind=("integral", "real floating", "complex floating")
181+
) == info.dtypes(kind="numeric")
182+
183+
184+
def test_array_api_inspection_dtype_kind_errors():
185+
info = dpt.__array_namespace_info__()
186+
try:
187+
info.default_device()
188+
except dpctl.SyclDeviceCreationError:
189+
pytest.skip("No default device available")
190+
191+
with pytest.raises(ValueError):
192+
info.dtypes(kind="error")
193+
194+
with pytest.raises(TypeError):
195+
info.dtypes(kind={0: "real floating"})
196+
197+
198+
def test_array_api_inspection_device_types():
199+
info = dpt.__array_namespace_info__()
200+
try:
201+
dev = info.default_device()
202+
except dpctl.SyclDeviceCreationError:
203+
pytest.skip("No default device available")
204+
205+
q = dpctl.SyclQueue(dev)
206+
assert info.default_dtypes(device=q)
207+
assert info.dtypes(device=q)
208+
209+
dev_dpt = dpt.Device.create_device(dev)
210+
assert info.default_dtypes(device=dev_dpt)
211+
assert info.dtypes(device=dev_dpt)
212+
213+
filter = dev.get_filter_string()
214+
assert info.default_dtypes(device=filter)
215+
assert info.dtypes(device=filter)
216+
217+
218+
def test_array_api_inspection_device_errors():
219+
info = dpt.__array_namespace_info__()
220+
221+
bad_dev = dict()
222+
with pytest.raises(TypeError):
223+
info.dtypes(device=bad_dev)
224+
225+
with pytest.raises(TypeError):
226+
info.default_dtypes(device=bad_dev)

0 commit comments

Comments
 (0)