Skip to content

Commit 96de932

Browse files
Introduce _is_host_cpu utility predicate used in __dlpack__
_is_host_cpu(dl_device) checks if user request export for host CPU device. Recognized inputs are (1, 0) (_usmarray.DLDeviceType.kDLCPU, 0) ("kDLCPU", 0) Add test to exercise __dlpack__ with non-default dl_device keyword arguments
1 parent 1ec9a76 commit 96de932

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,34 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
8787
view.shape = tuple()
8888
return view
8989

90+
9091
cdef int _copy_writable(int lhs_flags, int rhs_flags):
9192
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
9293
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
9394

95+
96+
cdef bint _is_host_cpu(object dl_device):
97+
"Check if dl_device denotes (kDLCPU, 0)"
98+
cdef object dl_type
99+
cdef object dl_id
100+
cdef Py_ssize_t n_elems = -1
101+
102+
try:
103+
n_elems = len(dl_device)
104+
except TypeError:
105+
pass
106+
107+
if n_elems != 2:
108+
return False
109+
110+
dl_type = dl_device[0]
111+
dl_id = dl_device[1]
112+
if isinstance(dl_type, str):
113+
return (dl_type == "kDLCPU" and dl_id == 0)
114+
115+
return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0)
116+
117+
94118
cdef class usm_ndarray:
95119
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
96120
offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1148,8 +1172,7 @@ cdef class usm_ndarray:
11481172
raise BufferError(
11491173
"array cannot be placed on the requested device without a copy"
11501174
)
1151-
if dl_device[0] == (DLDeviceType.kDLCPU):
1152-
assert dl_device[1] == 0
1175+
if _is_host_cpu(dl_device):
11531176
if stream is not None:
11541177
raise ValueError(
11551178
"`stream` must be `None` when `dl_device` is of type `kDLCPU`"

dpctl/tests/test_usm_ndarray_dlpack.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
import dpctl.tensor as dpt
2525
import dpctl.tensor._dlpack as _dlp
2626

27+
import dpctl.tensor._usmarray as dpt_arr
28+
2729
device_oneAPI = 14 # DLDeviceType.kDLOneAPI
2830

2931
_usm_types_list = ["shared", "device", "host"]
@@ -470,3 +472,21 @@ def test_dlpack_kwargs():
470472
assert y._pointer != x2._pointer
471473
del x2, y
472474
del cap
475+
476+
477+
def test_dlpack_dl_device():
478+
try:
479+
x = dpt.arange(100, dtype="i4")
480+
except dpctl.SyclDeviceCreationError:
481+
pytest.skip("No SYCL devices available")
482+
max_supported_ver = _dlp.get_build_dlpack_version()
483+
cap1 = x.__dlpack__(dl_device=x.__dlpack_device__(), max_version=max_supported_ver)
484+
cap2 = x.__dlpack__(dl_device=(1, 0), max_version=max_supported_ver)
485+
cap3 = x.__dlpack__(dl_device=(dpt_arr.DLDeviceType.kDLCPU, 0), max_version=max_supported_ver)
486+
cap4 = x.__dlpack__(dl_device=("kDLCPU", 0), max_version=max_supported_ver)
487+
with pytest.raises(NotImplementedError):
488+
# pass method instead of return of its __call__ invocation
489+
cap5 = x.__dlpack__(dl_device=x.__dlpack_device__, max_version=max_supported_ver)
490+
with pytest.raises(NotImplementedError):
491+
# exercise check for length
492+
cap6 = x.__dlpack__(dl_device=(3,), max_version=max_supported_ver)

0 commit comments

Comments
 (0)