diff --git a/dpctl/tensor/_usmarray.pyx b/dpctl/tensor/_usmarray.pyx
index 7266efa3e6..6a9b775e80 100644
--- a/dpctl/tensor/_usmarray.pyx
+++ b/dpctl/tensor/_usmarray.pyx
@@ -87,10 +87,34 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
     view.shape = tuple()
     return view
 
+
 cdef int _copy_writable(int lhs_flags, int rhs_flags):
     "Copy the WRITABLE flag to lhs_flags from rhs_flags"
     return (lhs_flags & ~USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
 
+
+cdef bint _is_host_cpu(object dl_device):
+    "True iff dl_device is a 2-item (kDLCPU, 0); type as enum or 'kDLCPU'"
+    cdef object dl_type
+    cdef object dl_id
+    cdef Py_ssize_t n_elems = -1
+
+    try:
+        n_elems = len(dl_device)
+    except TypeError:
+        pass
+
+    if n_elems != 2:
+        return False
+
+    dl_type = dl_device[0]
+    dl_id = dl_device[1]
+    if isinstance(dl_type, str):
+        return (dl_type == "kDLCPU" and dl_id == 0)
+
+    return (dl_type == DLDeviceType.kDLCPU) and (dl_id == 0)
+
+
 cdef class usm_ndarray:
     """ usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
            offset=0, order="C", buffer_ctor_kwargs=dict(), \
@@ -1148,8 +1172,7 @@ cdef class usm_ndarray:
                             raise BufferError(
                                 "array cannot be placed on the requested device without a copy"
                             )
-                        if dl_device[0] == (DLDeviceType.kDLCPU):
-                            assert dl_device[1] == 0
+                        if _is_host_cpu(dl_device):
                             if stream is not None:
                                 raise ValueError(
                                     "`stream` must be `None` when `dl_device` is of type `kDLCPU`"
diff --git a/dpctl/tests/test_usm_ndarray_dlpack.py b/dpctl/tests/test_usm_ndarray_dlpack.py
index 527116b1d0..e8d7885b1a 100644
--- a/dpctl/tests/test_usm_ndarray_dlpack.py
+++ b/dpctl/tests/test_usm_ndarray_dlpack.py
@@ -23,6 +23,7 @@
 import dpctl
 import dpctl.tensor as dpt
 import dpctl.tensor._dlpack as _dlp
+import dpctl.tensor._usmarray as dpt_arr
 
 device_oneAPI = 14  # DLDeviceType.kDLOneAPI
 
@@ -470,3 +471,37 @@ def test_dlpack_kwargs():
     assert y._pointer != x2._pointer
     del x2, y
     del cap
+
+
+def _is_capsule(o):
+    t = type(o)
+    return t.__module__ == "builtins" and t.__name__ == "PyCapsule"
+
+
+def test_dlpack_dl_device():
+    try:
+        x = dpt.arange(100, dtype="i4")
+    except dpctl.SyclDeviceCreationError:
+        pytest.skip("No SYCL devices available")
+    max_supported_ver = _dlp.get_build_dlpack_version()
+    cap1 = x.__dlpack__(
+        dl_device=x.__dlpack_device__(), max_version=max_supported_ver
+    )
+    assert _is_capsule(cap1)
+    cap2 = x.__dlpack__(dl_device=(1, 0), max_version=max_supported_ver)
+    assert _is_capsule(cap2)
+    cap3 = x.__dlpack__(
+        dl_device=(dpt_arr.DLDeviceType.kDLCPU, 0),
+        max_version=max_supported_ver,
+    )
+    assert _is_capsule(cap3)
+    cap4 = x.__dlpack__(dl_device=("kDLCPU", 0), max_version=max_supported_ver)
+    assert _is_capsule(cap4)
+    with pytest.raises(NotImplementedError):
+        # pass the bound method itself, not the value returned by calling it
+        x.__dlpack__(
+            dl_device=x.__dlpack_device__, max_version=max_supported_ver
+        )
+    with pytest.raises(NotImplementedError):
+        # a 1-tuple exercises the check that dl_device has two elements
+        x.__dlpack__(dl_device=(3,), max_version=max_supported_ver)