@@ -46,7 +46,21 @@ def _isdtype_impl(dtype, kind):
46
46
elif isinstance (kind , tuple ):
47
47
return any (_isdtype_impl (dtype , k ) for k in kind )
48
48
else :
49
- raise TypeError (f"Unsupported data type kind: { kind } " )
49
+ raise TypeError (f"Unsupported type for dtype kind: { type (kind )} " )
50
+
51
+
52
+ def _get_device_impl (d ):
53
+ if d is None :
54
+ return dpctl .select_default_device ()
55
+ elif isinstance (d , dpctl .SyclDevice ):
56
+ return d
57
+ elif isinstance (d , (dpt .Device , dpctl .SyclQueue )):
58
+ return d .sycl_device
59
+ else :
60
+ try :
61
+ return dpctl .SyclDevice (d )
62
+ except TypeError :
63
+ raise TypeError (f"Unsupported type for device argument: { type (d )} " )
50
64
51
65
52
66
__array_api_version__ = "2023.12"
@@ -117,13 +131,13 @@ def default_dtypes(self, *, device=None):
117
131
Returns a dictionary of default data types for ``device``.
118
132
119
133
Args:
120
- device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`]):
134
+ device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str ]):
121
135
array API concept of device used in getting default data types.
122
136
``device`` can be ``None`` (in which case the default device
123
- is used), an instance of :class:`dpctl.SyclDevice` corresponding
124
- to a non-partitioned SYCL device, an instance of
125
- :class:`dpctl.SyclQueue`, or a :class :`dpctl.tensor.Device`
126
- object returned by :attr:`dpctl.tensor.usm_ndarray.device` .
137
+ is used), an instance of :class:`dpctl.SyclDevice`, an instance
138
+ of :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
139
+ object returned by :attr :`dpctl.tensor.usm_ndarray.device`, or
140
+ a filter selector string .
127
141
Default: ``None``.
128
142
129
143
Returns:
@@ -135,10 +149,7 @@ def default_dtypes(self, *, device=None):
135
149
- ``"integral"``: dtype
136
150
- ``"indexing"``: dtype
137
151
"""
138
- if device is None :
139
- device = dpctl .select_default_device ()
140
- elif isinstance (device , dpt .Device ):
141
- device = device .sycl_device
152
+ device = _get_device_impl (device )
142
153
return {
143
154
"real floating" : dpt .dtype (default_device_fp_type (device )),
144
155
"complex floating" : dpt .dtype (default_device_complex_type (device )),
@@ -161,10 +172,10 @@ def dtypes(self, *, device=None, kind=None):
161
172
device (Optional[:class:`dpctl.SyclDevice`, :class:`dpctl.SyclQueue`, :class:`dpctl.tensor.Device`, str]):
162
173
array API concept of device used in getting default data types.
163
174
``device`` can be ``None`` (in which case the default device is
164
- used), an instance of :class:`dpctl.SyclDevice` corresponding
165
- to a non-partitioned SYCL device, an instance of
166
- :class:`dpctl.SyclQueue`, or a :class :`dpctl.tensor.Device`
167
- object returned by :attr:`dpctl.tensor.usm_ndarray.device` .
175
+ used), an instance of :class:`dpctl.SyclDevice`, an instance of
176
+ :class:`dpctl.SyclQueue`, a :class:`dpctl.tensor.Device`
177
+ object returned by :attr :`dpctl.tensor.usm_ndarray.device`, or
178
+ a filter selector string .
168
179
Default: ``None``.
169
180
170
181
kind (Optional[str, Tuple[str, ...]]):
@@ -196,22 +207,20 @@ def dtypes(self, *, device=None, kind=None):
196
207
a dictionary of the supported data types of the specified
197
208
``kind``
198
209
"""
199
- if device is None :
200
- device = dpctl .select_default_device ()
201
- elif isinstance (device , dpt .Device ):
202
- device = device .sycl_device
210
+ device = _get_device_impl (device )
203
211
_fp64 = device .has_aspect_fp64
204
212
if kind is None :
205
213
return {
206
214
key : val
207
215
for key , val in self ._all_dtypes .items ()
208
- if (key != "float64" or _fp64 )
216
+ if _fp64 or (key != "float64" and key != "complex128" )
209
217
}
210
218
else :
211
219
return {
212
220
key : val
213
221
for key , val in self ._all_dtypes .items ()
214
- if (key != "float64" or _fp64 ) and _isdtype_impl (val , kind )
222
+ if (_fp64 or (key != "float64" and key != "complex128" ))
223
+ and _isdtype_impl (val , kind )
215
224
}
216
225
217
226
def devices (self ):
0 commit comments