Skip to content

Introduce _is_host_cpu utility predicate used in __dlpack__ #1784

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions dpctl/tensor/_usmarray.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,34 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
view.shape = tuple()
return view


cdef int _copy_writable(int lhs_flags, int rhs_flags):
"Copy the WRITABLE flag to lhs_flags from rhs_flags"
return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)


cdef bint _is_host_cpu(object dl_device):
"Check if dl_device denotes (kDLCPU, 0)"
cdef object dl_type
cdef object dl_id
cdef Py_ssize_t n_elems = -1

try:
n_elems = len(dl_device)
except TypeError:
pass

if n_elems != 2:
return False

dl_type = dl_device[0]
dl_id = dl_device[1]
if isinstance(dl_type, str):
return (dl_type == "kDLCPU" and dl_id == 0)

return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0)


cdef class usm_ndarray:
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
offset=0, order="C", buffer_ctor_kwargs=dict(), \
Expand Down Expand Up @@ -1148,8 +1172,7 @@ cdef class usm_ndarray:
raise BufferError(
"array cannot be placed on the requested device without a copy"
)
if dl_device[0] == (DLDeviceType.kDLCPU):
assert dl_device[1] == 0
if _is_host_cpu(dl_device):
if stream is not None:
raise ValueError(
"`stream` must be `None` when `dl_device` is of type `kDLCPU`"
Expand Down
35 changes: 35 additions & 0 deletions dpctl/tests/test_usm_ndarray_dlpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import dpctl
import dpctl.tensor as dpt
import dpctl.tensor._dlpack as _dlp
import dpctl.tensor._usmarray as dpt_arr

device_oneAPI = 14 # DLDeviceType.kDLOneAPI

Expand Down Expand Up @@ -470,3 +471,37 @@ def test_dlpack_kwargs():
assert y._pointer != x2._pointer
del x2, y
del cap


def _is_capsule(o):
t = type(o)
return t.__module__ == "builtins" and t.__name__ == "PyCapsule"


def test_dlpack_dl_device():
try:
x = dpt.arange(100, dtype="i4")
except dpctl.SyclDeviceCreationError:
pytest.skip("No SYCL devices available")
max_supported_ver = _dlp.get_build_dlpack_version()
cap1 = x.__dlpack__(
dl_device=x.__dlpack_device__(), max_version=max_supported_ver
)
assert _is_capsule(cap1)
cap2 = x.__dlpack__(dl_device=(1, 0), max_version=max_supported_ver)
assert _is_capsule(cap2)
cap3 = x.__dlpack__(
dl_device=(dpt_arr.DLDeviceType.kDLCPU, 0),
max_version=max_supported_ver,
)
assert _is_capsule(cap3)
cap4 = x.__dlpack__(dl_device=("kDLCPU", 0), max_version=max_supported_ver)
assert _is_capsule(cap4)
with pytest.raises(NotImplementedError):
# pass method instead of return of its __call__ invocation
x.__dlpack__(
dl_device=x.__dlpack_device__, max_version=max_supported_ver
)
with pytest.raises(NotImplementedError):
# exercise check for length
x.__dlpack__(dl_device=(3,), max_version=max_supported_ver)
Loading