Skip to content

Commit a86da4b

Browse files
authored
Merge pull request #1953 from IntelPython/dlpack-enum-sycl-device-interop
Improve interoperability between `SyclDevice` and DLPack devices
2 parents 2e8c9c0 + b1a5ecd commit a86da4b

File tree

9 files changed

+210
-55
lines changed

9 files changed

+210
-55
lines changed

dpctl/_sycl_device.pyx

Lines changed: 73 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,8 @@ cdef class SyclDevice(_SyclDevice):
283283
284284
Args:
285285
arg (str, optional):
286-
The argument can be a selector string or ``None``.
286+
The argument can be a selector string, another
287+
:class:`dpctl.SyclDevice`, or ``None``.
287288
Defaults to ``None``.
288289
289290
Raises:
@@ -293,9 +294,7 @@ cdef class SyclDevice(_SyclDevice):
293294
SyclDeviceCreationError:
294295
If the :class:`dpctl.SyclDevice` object creation failed.
295296
TypeError:
296-
If the list of :class:`dpctl.SyclDevice` objects was empty,
297-
or the input capsule contained a null pointer or could not
298-
be renamed.
297+
If the argument is not a :class:`dpctl.SyclDevice` or string.
299298
"""
300299
@staticmethod
301300
cdef SyclDevice _create(DPCTLSyclDeviceRef dref):
@@ -363,9 +362,9 @@ cdef class SyclDevice(_SyclDevice):
363362
"Could not create a SyclDevice from default selector"
364363
)
365364
else:
366-
raise ValueError(
365+
raise TypeError(
367366
"Invalid argument. Argument should be a str object specifying "
368-
"a SYCL filter selector string."
367+
"a SYCL filter selector string or another SyclDevice."
369368
)
370369

371370
def print_device_info(self):
@@ -1557,7 +1556,7 @@ cdef class SyclDevice(_SyclDevice):
15571556
cdef int i
15581557

15591558
if ncounts == 0:
1560-
raise TypeError(
1559+
raise ValueError(
15611560
"Non-empty object representing list of counts is expected."
15621561
)
15631562
counts_buff = <size_t *> malloc((<size_t> ncounts) * sizeof(size_t))
@@ -1659,7 +1658,7 @@ cdef class SyclDevice(_SyclDevice):
16591658
Created sub-devices.
16601659
16611660
Raises:
1662-
TypeError:
1661+
ValueError:
16631662
If the ``partition`` keyword argument is not specified or
16641663
the affinity domain string is not legal or is not one of the
16651664
three supported options.
@@ -1695,7 +1694,7 @@ cdef class SyclDevice(_SyclDevice):
16951694
_partition_affinity_domain_type._next_partitionable
16961695
)
16971696
else:
1698-
raise TypeError(
1697+
raise ValueError(
16991698
"Partition affinity domain {} is not understood.".format(
17001699
partition
17011700
)
@@ -1708,11 +1707,11 @@ cdef class SyclDevice(_SyclDevice):
17081707
else:
17091708
try:
17101709
partition = int(partition)
1711-
return self.create_sub_devices_equally(partition)
17121710
except Exception as e:
17131711
raise TypeError(
17141712
"Unsupported type of sub-device argument"
17151713
) from e
1714+
return self.create_sub_devices_equally(partition)
17161715

17171716
@property
17181717
def parent_device(self):
@@ -1877,7 +1876,7 @@ cdef class SyclDevice(_SyclDevice):
18771876
A Python string representing a filter selector string.
18781877
18791878
Raises:
1880-
TypeError:
1879+
ValueError:
18811880
If the device is a sub-device.
18821881
18831882
:Example:
@@ -1902,7 +1901,7 @@ cdef class SyclDevice(_SyclDevice):
19021901
else:
19031902
# this a sub-device, free it, and raise an exception
19041903
DPCTLDevice_Delete(pDRef)
1905-
raise TypeError("This SyclDevice is not a root device")
1904+
raise ValueError("This SyclDevice is not a root device")
19061905

19071906
cdef int get_backend_and_device_type_ordinal(self):
19081907
""" If this device is a root ``sycl::device``, returns the ordinal
@@ -1950,9 +1949,7 @@ cdef class SyclDevice(_SyclDevice):
19501949

19511950
cdef int get_overall_ordinal(self):
19521951
""" If this device is a root ``sycl::device``, returns the ordinal
1953-
position of this device in the vector ``sycl::device::get_devices()``
1954-
filtered to contain only devices with the same backend as this
1955-
device.
1952+
position of this device in the vector ``sycl::device::get_devices()``.
19561953
19571954
Returns -1 if the device is a sub-device, or the device could not
19581955
be found in the vector.
@@ -1985,9 +1982,9 @@ cdef class SyclDevice(_SyclDevice):
19851982
A Python string representing a filter selector string.
19861983
19871984
Raises:
1988-
TypeError:
1989-
If the device is a sub-device.
19901985
ValueError:
1986+
If the device is a sub-device.
1987+
19911988
If no match for the device was found in the vector
19921989
returned by ``sycl::device::get_devices()``
19931990
@@ -2026,7 +2023,7 @@ cdef class SyclDevice(_SyclDevice):
20262023
else:
20272024
# this a sub-device, free it, and raise an exception
20282025
DPCTLDevice_Delete(pDRef)
2029-
raise TypeError("This SyclDevice is not a root device")
2026+
raise ValueError("This SyclDevice is not a root device")
20302027
else:
20312028
if include_backend:
20322029
BTy = DPCTLDevice_GetBackend(self._device_ref)
@@ -2045,6 +2042,64 @@ cdef class SyclDevice(_SyclDevice):
20452042
else:
20462043
return str(relId)
20472044

2045+
def get_unpartitioned_parent_device(self):
2046+
""" get_unpartitioned_parent_device()
2047+
2048+
Returns the unpartitioned parent device of this device.
2049+
2050+
If this device is already an unpartitioned, root device,
2051+
the same device is returned.
2052+
2053+
Returns:
2054+
dpctl.SyclDevice:
2055+
A parent, unpartitioned :class:`dpctl.SyclDevice` instance, or
2056+
``self`` if already a root device.
2057+
"""
2058+
cdef DPCTLSyclDeviceRef pDRef = NULL
2059+
cdef DPCTLSyclDeviceRef tDRef = NULL
2060+
pDRef = DPCTLDevice_GetParentDevice(self._device_ref)
2061+
if pDRef is NULL:
2062+
return self
2063+
else:
2064+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
2065+
while tDRef is not NULL:
2066+
DPCTLDevice_Delete(pDRef)
2067+
pDRef = tDRef
2068+
tDRef = DPCTLDevice_GetParentDevice(pDRef)
2069+
return SyclDevice._create(pDRef)
2070+
2071+
def get_device_id(self):
2072+
""" get_device_id()
2073+
For an unpartitioned device, returns the canonical index of this device
2074+
in the list of devices visible to dpctl.
2075+
2076+
Returns:
2077+
int:
2078+
The index of the device.
2079+
2080+
Raises:
2081+
ValueError:
2082+
If the device could not be found.
2083+
2084+
:Example:
2085+
.. code-block:: python
2086+
2087+
import dpctl
2088+
gpu_dev = dpctl.SyclDevice("gpu")
2089+
i = gpu_dev.get_device_id
2090+
devs = dpctl.get_devices()
2091+
assert devs[i] == gpu_dev
2092+
"""
2093+
cdef int dev_id = -1
2094+
cdef SyclDevice dev
2095+
2096+
dev = self.get_unpartitioned_parent_device()
2097+
dev_id = dev.get_overall_ordinal()
2098+
if dev_id < 0:
2099+
raise ValueError("device could not be found")
2100+
return dev_id
2101+
2102+
20482103
cdef api DPCTLSyclDeviceRef SyclDevice_GetDeviceRef(SyclDevice dev):
20492104
"""
20502105
C-API function to get opaque device reference from

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@
5959
uint64,
6060
)
6161
from dpctl.tensor._device import Device
62+
from dpctl.tensor._dldevice_conversions import (
63+
dldevice_to_sycl_device,
64+
sycl_device_to_dldevice,
65+
)
6266
from dpctl.tensor._dlpack import from_dlpack
6367
from dpctl.tensor._indexing_functions import (
6468
extract,
@@ -388,4 +392,6 @@
388392
"take_along_axis",
389393
"put_along_axis",
390394
"top_k",
395+
"dldevice_to_sycl_device",
396+
"sycl_device_to_dldevice",
391397
]

dpctl/tensor/_dldevice_conversions.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2025 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from .._sycl_device import SyclDevice
18+
from ._usmarray import DLDeviceType
19+
20+
21+
def dldevice_to_sycl_device(dl_dev: tuple):
22+
if isinstance(dl_dev, tuple):
23+
if len(dl_dev) != 2:
24+
raise ValueError("dldevice tuple must have length 2")
25+
else:
26+
raise TypeError(
27+
f"dl_dev is expected to be a 2-tuple, got " f"{type(dl_dev)}"
28+
)
29+
if dl_dev[0] != DLDeviceType.kDLOneAPI:
30+
raise ValueError("dldevice type must be kDLOneAPI")
31+
return SyclDevice(str(dl_dev[1]))
32+
33+
34+
def sycl_device_to_dldevice(dev: SyclDevice):
35+
if not isinstance(dev, SyclDevice):
36+
raise TypeError(
37+
"dev is expected to be a SyclDevice, got " f"{type(dev)}"
38+
)
39+
return (DLDeviceType.kDLOneAPI, dev.get_device_id())

dpctl/tensor/_dlpack.pxd

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ cpdef object to_dlpack_versioned_capsule(usm_ndarray array, bint copied) except
4747
cpdef object numpy_to_dlpack_versioned_capsule(ndarray array, bint copied) except +
4848
cpdef object from_dlpack_capsule(object dltensor) except +
4949

50-
cdef int get_parent_device_ordinal_id(SyclDevice dev) except *
51-
5250
cdef class DLPackCreationError(Exception):
5351
"""
5452
A DLPackCreateError exception is raised when constructing

dpctl/tensor/_dlpack.pyx

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -177,28 +177,6 @@ cdef object _get_default_context(c_dpctl.SyclDevice dev):
177177

178178
return default_context
179179

180-
181-
cdef int get_parent_device_ordinal_id(c_dpctl.SyclDevice dev) except -1:
182-
cdef DPCTLSyclDeviceRef pDRef = NULL
183-
cdef DPCTLSyclDeviceRef tDRef = NULL
184-
cdef c_dpctl.SyclDevice p_dev
185-
186-
pDRef = DPCTLDevice_GetParentDevice(dev.get_device_ref())
187-
if pDRef is not NULL:
188-
# if dev is a sub-device, find its parent
189-
# and return its overall ordinal id
190-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
191-
while tDRef is not NULL:
192-
DPCTLDevice_Delete(pDRef)
193-
pDRef = tDRef
194-
tDRef = DPCTLDevice_GetParentDevice(pDRef)
195-
p_dev = c_dpctl.SyclDevice._create(pDRef)
196-
return p_dev.get_overall_ordinal()
197-
198-
# return overall ordinal id of argument device
199-
return dev.get_overall_ordinal()
200-
201-
202180
cdef int get_array_dlpack_device_id(
203181
usm_ndarray usm_ary
204182
) except -1:
@@ -224,14 +202,13 @@ cdef int get_array_dlpack_device_id(
224202
"on non-partitioned SYCL devices on platforms where "
225203
"default_context oneAPI extension is not supported."
226204
)
227-
device_id = ary_sycl_device.get_overall_ordinal()
228205
else:
229206
if not usm_ary.sycl_context == default_context:
230207
raise DLPackCreationError(
231208
"to_dlpack_capsule: DLPack can only export arrays based on USM "
232209
"allocations bound to a default platform SYCL context"
233210
)
234-
device_id = get_parent_device_ordinal_id(ary_sycl_device)
211+
device_id = ary_sycl_device.get_device_id()
235212

236213
if device_id < 0:
237214
raise DLPackCreationError(
@@ -1086,7 +1063,7 @@ def from_dlpack(x, /, *, device=None, copy=None):
10861063
d = device.sycl_device
10871064
else:
10881065
d = device
1089-
dl_device = (device_OneAPI, get_parent_device_ordinal_id(<c_dpctl.SyclDevice>d))
1066+
dl_device = (device_OneAPI, d.get_device_id())
10901067
if dl_device is not None:
10911068
if (dl_device[0] not in [device_OneAPI, device_CPU]):
10921069
raise ValueError(

dpctl/tensor/_usmarray.pyx

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,16 +1304,16 @@ cdef class usm_ndarray:
13041304
DLPackCreationError:
13051305
when the ``device_id`` could not be determined.
13061306
"""
1307-
cdef int dev_id = c_dlpack.get_parent_device_ordinal_id(<c_dpctl.SyclDevice>self.sycl_device)
1308-
if dev_id < 0:
1307+
try:
1308+
dev_id = self.sycl_device.get_device_id()
1309+
except ValueError as e:
13091310
raise c_dlpack.DLPackCreationError(
13101311
"Could not determine id of the device where array was allocated."
13111312
)
1312-
else:
1313-
return (
1314-
DLDeviceType.kDLOneAPI,
1315-
dev_id,
1316-
)
1313+
return (
1314+
DLDeviceType.kDLOneAPI,
1315+
dev_id,
1316+
)
13171317

13181318
def __eq__(self, other):
13191319
return dpctl.tensor.equal(self, other)

dpctl/tests/test_sycl_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def test_cpython_api_SyclContext_Make():
263263

264264
def test_invalid_capsule():
265265
cap = create_invalid_capsule()
266-
with pytest.raises(ValueError):
266+
with pytest.raises(TypeError):
267267
dpctl.SyclContext(cap)
268268

269269

0 commit comments

Comments
 (0)