Skip to content

Commit 348dd3d

Browse files
committed
Array API inspection utilities interfacing with devices fixed
* Fixes a bug where `"complex128"` was present in the list of dtypes on devices without fp64 support * Fixes functionality of device keyword for SyclQueue and strings * Improves error message when device keyword cannot be used to construct a SyclDevice instance * Tweaks docstring for device keywords in inspection utilities
1 parent c5cbb08 commit 348dd3d

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

dpctl/tensor/_array_api.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,20 @@ def _isdtype_impl(dtype, kind):
4949
raise TypeError(f"Unsupported data type kind: {kind}")
5050

5151

52+
def _get_device_impl(d):
53+
if d is None:
54+
return dpctl.select_default_device()
55+
elif isinstance(d, dpctl.SyclDevice):
56+
return d
57+
elif isinstance(d, (dpt.Device, dpctl.SyclQueue)):
58+
return d.sycl_device
59+
else:
60+
try:
61+
return dpctl.SyclDevice(d)
62+
except TypeError:
63+
raise TypeError(f"Unsupported type for device argument: {type(d)}")
64+
65+
5266
__array_api_version__ = "2023.12"
5367

5468

@@ -117,13 +131,13 @@ def default_dtypes(self, *, device=None):
117131
Returns a dictionary of default data types for ``device``.
118132
119133
Args:
120-
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`]):
134+
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
121135
array API concept of device used in getting default data types.
122136
``device`` can be ``None`` (in which case the default device
123-
is used), an instance of :class:`dpctl.SyclDevice` corresponding
124-
to a non-partitioned SYCL device, an instance of
125-
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device`
126-
object returned by :attr:`dpctl.tensor.usm_ndarray.device`.
137+
is used), an instance of :class:`dpctl.SyclDevice`, an instance
138+
of :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
139+
object returned by :attr:`dpctl.tensor.usm_ndarray.device`, or
140+
a filter selector string.
127141
Default: ``None``.
128142
129143
Returns:
@@ -135,10 +149,7 @@ def default_dtypes(self, *, device=None):
135149
- ``"integral"``: dtype
136150
- ``"indexing"``: dtype
137151
"""
138-
if device is None:
139-
device = dpctl.select_default_device()
140-
elif isinstance(device, dpt.Device):
141-
device = device.sycl_device
152+
device = _get_device_impl(device)
142153
return {
143154
"real floating": dpt.dtype(default_device_fp_type(device)),
144155
"complex floating": dpt.dtype(default_device_complex_type(device)),
@@ -161,10 +172,10 @@ def dtypes(self, *, device=None, kind=None):
161172
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
162173
array API concept of device used in getting default data types.
163174
``device`` can be ``None`` (in which case the default device is
164-
used), an instance of :class:`dpctl.SyclDevice` corresponding
165-
to a non-partitioned SYCL device, an instance of
166-
:class:`dpctl.SyclQueue`, or a :class:`dpctl.tensor.Device`
167-
object returned by :attr:`dpctl.tensor.usm_ndarray.device`.
175+
used), an instance of :class:`dpctl.SyclDevice`, an instance of
176+
:class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
177+
object returned by :attr:`dpctl.tensor.usm_ndarray.device`, or
178+
a filter selector string.
168179
Default: ``None``.
169180
170181
kind (Optional[str, Tuple[str, ...]]):
@@ -196,22 +207,20 @@ def dtypes(self, *, device=None, kind=None):
196207
a dictionary of the supported data types of the specified
197208
``kind``
198209
"""
199-
if device is None:
200-
device = dpctl.select_default_device()
201-
elif isinstance(device, dpt.Device):
202-
device = device.sycl_device
210+
device = _get_device_impl(device)
203211
_fp64 = device.has_aspect_fp64
204212
if kind is None:
205213
return {
206214
key: val
207215
for key, val in self._all_dtypes.items()
208-
if (key != "float64" or _fp64)
216+
if _fp64 or (key != "float64" and key != "complex128")
209217
}
210218
else:
211219
return {
212220
key: val
213221
for key, val in self._all_dtypes.items()
214-
if (key != "float64" or _fp64) and _isdtype_impl(val, kind)
222+
if (_fp64 or (key != "float64" and key != "complex128"))
223+
and _isdtype_impl(val, kind)
215224
}
216225

217226
def devices(self):

0 commit comments

Comments
 (0)