Skip to content

Commit f125be4

Browse files
committed
Adds factory functions to convert between dlpack devices and dpctl.SyclDevice
1 parent 23fcd62 commit f125be4

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

dpctl/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,10 @@
5959
uint64,
6060
)
6161
from dpctl.tensor._device import Device
62+
from dpctl.tensor._dldevice_conversions import (
63+
dldevice_to_sycldevice,
64+
sycldevice_to_dldevice,
65+
)
6266
from dpctl.tensor._dlpack import from_dlpack
6367
from dpctl.tensor._indexing_functions import (
6468
extract,
@@ -388,4 +392,6 @@
388392
"take_along_axis",
389393
"put_along_axis",
390394
"top_k",
395+
"dldevice_to_sycldevice",
396+
"sycldevice_to_dldevice",
391397
]

dpctl/tensor/_dldevice_conversions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
import dpctl
18+
19+
from ._usmarray import DLDeviceType
20+
21+
22+
def dldevice_to_sycldevice(dl_dev: tuple):
23+
if isinstance(dl_dev, tuple):
24+
if len(dl_dev) != 2:
25+
raise ValueError("dldevice tuple must have length 2")
26+
else:
27+
raise TypeError(
28+
f"dl_dev is expected to be a 2-tuple, got " f"{type(dl_dev)}"
29+
)
30+
if dl_dev[0] != DLDeviceType.kDLOneAPI:
31+
raise ValueError("dldevice type must be kDLOneAPI")
32+
return dpctl.SyclDevice(str(dl_dev[1]))
33+
34+
35+
def sycldevice_to_dldevice(dev: dpctl.SyclDevice):
36+
if not isinstance(dev, dpctl.SyclDevice):
37+
raise TypeError(
38+
"dev is expected to be a dpctl.SyclDevice, got " f"{type(dev)}"
39+
)
40+
return (DLDeviceType.kDLOneAPI, dev.get_device_id())

0 commit comments

Comments
 (0)